channeldb/migration_01_to_11: remove unused code
This commit is contained in:
parent
f5191440c5
commit
60503d6c44
@ -1,24 +0,0 @@
|
|||||||
channeldb
|
|
||||||
==========
|
|
||||||
|
|
||||||
[![Build Status](http://img.shields.io/travis/lightningnetwork/lnd.svg)](https://travis-ci.org/lightningnetwork/lnd)
|
|
||||||
[![MIT licensed](https://img.shields.io/badge/license-MIT-blue.svg)](https://github.com/lightningnetwork/lnd/blob/master/LICENSE)
|
|
||||||
[![GoDoc](https://img.shields.io/badge/godoc-reference-blue.svg)](http://godoc.org/github.com/lightningnetwork/lnd/channeldb)
|
|
||||||
|
|
||||||
The channeldb implements the persistent storage engine for `lnd` and
|
|
||||||
generically a data storage layer for the required state within the Lightning
|
|
||||||
Network. The backing storage engine is
|
|
||||||
[boltdb](https://github.com/coreos/bbolt), an embedded pure-go key-value store
|
|
||||||
based off of LMDB.
|
|
||||||
|
|
||||||
The package implements an object-oriented storage model with queries and
|
|
||||||
mutations flowing through a particular object instance rather than the database
|
|
||||||
itself. The storage implemented by the objects includes: open channels, past
|
|
||||||
commitment revocation states, the channel graph which includes authenticated
|
|
||||||
node and channel announcements, outgoing payments, and invoices
|
|
||||||
|
|
||||||
## Installation and Updating
|
|
||||||
|
|
||||||
```bash
|
|
||||||
$ go get -u github.com/lightningnetwork/lnd/channeldb
|
|
||||||
```
|
|
@ -1,149 +0,0 @@
|
|||||||
package migration_01_to_11
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"net"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/lightningnetwork/lnd/tor"
|
|
||||||
)
|
|
||||||
|
|
||||||
type unknownAddrType struct{}
|
|
||||||
|
|
||||||
func (t unknownAddrType) Network() string { return "unknown" }
|
|
||||||
func (t unknownAddrType) String() string { return "unknown" }
|
|
||||||
|
|
||||||
var testIP4 = net.ParseIP("192.168.1.1")
|
|
||||||
var testIP6 = net.ParseIP("2001:0db8:0000:0000:0000:ff00:0042:8329")
|
|
||||||
|
|
||||||
var addrTests = []struct {
|
|
||||||
expAddr net.Addr
|
|
||||||
serErr string
|
|
||||||
}{
|
|
||||||
// Valid addresses.
|
|
||||||
{
|
|
||||||
expAddr: &net.TCPAddr{
|
|
||||||
IP: testIP4,
|
|
||||||
Port: 12345,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
expAddr: &net.TCPAddr{
|
|
||||||
IP: testIP6,
|
|
||||||
Port: 65535,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
expAddr: &tor.OnionAddr{
|
|
||||||
OnionService: "3g2upl4pq6kufc4m.onion",
|
|
||||||
Port: 9735,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
expAddr: &tor.OnionAddr{
|
|
||||||
OnionService: "vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyyd.onion",
|
|
||||||
Port: 80,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
|
|
||||||
// Invalid addresses.
|
|
||||||
{
|
|
||||||
expAddr: unknownAddrType{},
|
|
||||||
serErr: ErrUnknownAddressType.Error(),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
expAddr: &net.TCPAddr{
|
|
||||||
// Remove last byte of IPv4 address.
|
|
||||||
IP: testIP4[:len(testIP4)-1],
|
|
||||||
Port: 12345,
|
|
||||||
},
|
|
||||||
serErr: "unable to encode",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
expAddr: &net.TCPAddr{
|
|
||||||
// Add an extra byte of IPv4 address.
|
|
||||||
IP: append(testIP4, 0xff),
|
|
||||||
Port: 12345,
|
|
||||||
},
|
|
||||||
serErr: "unable to encode",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
expAddr: &net.TCPAddr{
|
|
||||||
// Remove last byte of IPv6 address.
|
|
||||||
IP: testIP6[:len(testIP6)-1],
|
|
||||||
Port: 65535,
|
|
||||||
},
|
|
||||||
serErr: "unable to encode",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
expAddr: &net.TCPAddr{
|
|
||||||
// Add an extra byte to the IPv6 address.
|
|
||||||
IP: append(testIP6, 0xff),
|
|
||||||
Port: 65535,
|
|
||||||
},
|
|
||||||
serErr: "unable to encode",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
expAddr: &tor.OnionAddr{
|
|
||||||
// Invalid suffix.
|
|
||||||
OnionService: "vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyyd.inion",
|
|
||||||
Port: 80,
|
|
||||||
},
|
|
||||||
serErr: "invalid suffix",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
expAddr: &tor.OnionAddr{
|
|
||||||
// Invalid length.
|
|
||||||
OnionService: "vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyy.onion",
|
|
||||||
Port: 80,
|
|
||||||
},
|
|
||||||
serErr: "unknown onion service length",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
expAddr: &tor.OnionAddr{
|
|
||||||
// Invalid encoding.
|
|
||||||
OnionService: "vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyyA.onion",
|
|
||||||
Port: 80,
|
|
||||||
},
|
|
||||||
serErr: "illegal base32",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestAddrSerialization tests that the serialization method used by channeldb
|
|
||||||
// for net.Addr's works as intended.
|
|
||||||
func TestAddrSerialization(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
var b bytes.Buffer
|
|
||||||
for _, test := range addrTests {
|
|
||||||
err := serializeAddr(&b, test.expAddr)
|
|
||||||
switch {
|
|
||||||
case err == nil && test.serErr != "":
|
|
||||||
t.Fatalf("expected serialization err for addr %v",
|
|
||||||
test.expAddr)
|
|
||||||
|
|
||||||
case err != nil && test.serErr == "":
|
|
||||||
t.Fatalf("unexpected serialization err for addr %v: %v",
|
|
||||||
test.expAddr, err)
|
|
||||||
|
|
||||||
case err != nil && !strings.Contains(err.Error(), test.serErr):
|
|
||||||
t.Fatalf("unexpected serialization err for addr %v, "+
|
|
||||||
"want: %v, got %v", test.expAddr, test.serErr,
|
|
||||||
err)
|
|
||||||
|
|
||||||
case err != nil:
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
addr, err := deserializeAddr(&b)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to deserialize address: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if addr.String() != test.expAddr.String() {
|
|
||||||
t.Fatalf("expected address %v after serialization, "+
|
|
||||||
"got %v", addr, test.expAddr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,12 +1,9 @@
|
|||||||
package migration_01_to_11
|
package migration_01_to_11
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@ -15,8 +12,6 @@ import (
|
|||||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||||
"github.com/btcsuite/btcd/wire"
|
"github.com/btcsuite/btcd/wire"
|
||||||
"github.com/btcsuite/btcutil"
|
"github.com/btcsuite/btcutil"
|
||||||
"github.com/coreos/bbolt"
|
|
||||||
"github.com/lightningnetwork/lnd/input"
|
|
||||||
"github.com/lightningnetwork/lnd/keychain"
|
"github.com/lightningnetwork/lnd/keychain"
|
||||||
"github.com/lightningnetwork/lnd/lnwire"
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
"github.com/lightningnetwork/lnd/shachain"
|
"github.com/lightningnetwork/lnd/shachain"
|
||||||
@ -36,90 +31,6 @@ var (
|
|||||||
//
|
//
|
||||||
// TODO(roasbeef): flesh out comment
|
// TODO(roasbeef): flesh out comment
|
||||||
openChannelBucket = []byte("open-chan-bucket")
|
openChannelBucket = []byte("open-chan-bucket")
|
||||||
|
|
||||||
// chanInfoKey can be accessed within the bucket for a channel
|
|
||||||
// (identified by its chanPoint). This key stores all the static
|
|
||||||
// information for a channel which is decided at the end of the
|
|
||||||
// funding flow.
|
|
||||||
chanInfoKey = []byte("chan-info-key")
|
|
||||||
|
|
||||||
// chanCommitmentKey can be accessed within the sub-bucket for a
|
|
||||||
// particular channel. This key stores the up to date commitment state
|
|
||||||
// for a particular channel party. Appending a 0 to the end of this key
|
|
||||||
// indicates it's the commitment for the local party, and appending a 1
|
|
||||||
// to the end of this key indicates it's the commitment for the remote
|
|
||||||
// party.
|
|
||||||
chanCommitmentKey = []byte("chan-commitment-key")
|
|
||||||
|
|
||||||
// revocationStateKey stores their current revocation hash, our
|
|
||||||
// preimage producer and their preimage store.
|
|
||||||
revocationStateKey = []byte("revocation-state-key")
|
|
||||||
|
|
||||||
// dataLossCommitPointKey stores the commitment point received from the
|
|
||||||
// remote peer during a channel sync in case we have lost channel state.
|
|
||||||
dataLossCommitPointKey = []byte("data-loss-commit-point-key")
|
|
||||||
|
|
||||||
// closingTxKey points to a the closing tx that we broadcasted when
|
|
||||||
// moving the channel to state CommitBroadcasted.
|
|
||||||
closingTxKey = []byte("closing-tx-key")
|
|
||||||
|
|
||||||
// commitDiffKey stores the current pending commitment state we've
|
|
||||||
// extended to the remote party (if any). Each time we propose a new
|
|
||||||
// state, we store the information necessary to reconstruct this state
|
|
||||||
// from the prior commitment. This allows us to resync the remote party
|
|
||||||
// to their expected state in the case of message loss.
|
|
||||||
//
|
|
||||||
// TODO(roasbeef): rename to commit chain?
|
|
||||||
commitDiffKey = []byte("commit-diff-key")
|
|
||||||
|
|
||||||
// revocationLogBucket is dedicated for storing the necessary delta
|
|
||||||
// state between channel updates required to re-construct a past state
|
|
||||||
// in order to punish a counterparty attempting a non-cooperative
|
|
||||||
// channel closure. This key should be accessed from within the
|
|
||||||
// sub-bucket of a target channel, identified by its channel point.
|
|
||||||
revocationLogBucket = []byte("revocation-log-key")
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
// ErrNoCommitmentsFound is returned when a channel has not set
|
|
||||||
// commitment states.
|
|
||||||
ErrNoCommitmentsFound = fmt.Errorf("no commitments found")
|
|
||||||
|
|
||||||
// ErrNoChanInfoFound is returned when a particular channel does not
|
|
||||||
// have any channels state.
|
|
||||||
ErrNoChanInfoFound = fmt.Errorf("no chan info found")
|
|
||||||
|
|
||||||
// ErrNoRevocationsFound is returned when revocation state for a
|
|
||||||
// particular channel cannot be found.
|
|
||||||
ErrNoRevocationsFound = fmt.Errorf("no revocations found")
|
|
||||||
|
|
||||||
// ErrNoPendingCommit is returned when there is not a pending
|
|
||||||
// commitment for a remote party. A new commitment is written to disk
|
|
||||||
// each time we write a new state in order to be properly fault
|
|
||||||
// tolerant.
|
|
||||||
ErrNoPendingCommit = fmt.Errorf("no pending commits found")
|
|
||||||
|
|
||||||
// ErrInvalidCircuitKeyLen signals that a circuit key could not be
|
|
||||||
// decoded because the byte slice is of an invalid length.
|
|
||||||
ErrInvalidCircuitKeyLen = fmt.Errorf(
|
|
||||||
"length of serialized circuit key must be 16 bytes")
|
|
||||||
|
|
||||||
// ErrNoCommitPoint is returned when no data loss commit point is found
|
|
||||||
// in the database.
|
|
||||||
ErrNoCommitPoint = fmt.Errorf("no commit point found")
|
|
||||||
|
|
||||||
// ErrNoCloseTx is returned when no closing tx is found for a channel
|
|
||||||
// in the state CommitBroadcasted.
|
|
||||||
ErrNoCloseTx = fmt.Errorf("no closing tx found")
|
|
||||||
|
|
||||||
// ErrNoRestoredChannelMutation is returned when a caller attempts to
|
|
||||||
// mutate a channel that's been recovered.
|
|
||||||
ErrNoRestoredChannelMutation = fmt.Errorf("cannot mutate restored " +
|
|
||||||
"channel state")
|
|
||||||
|
|
||||||
// ErrChanBorked is returned when a caller attempts to mutate a borked
|
|
||||||
// channel.
|
|
||||||
ErrChanBorked = fmt.Errorf("cannot mutate borked channel")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ChannelType is an enum-like type that describes one of several possible
|
// ChannelType is an enum-like type that describes one of several possible
|
||||||
@ -136,30 +47,8 @@ const (
|
|||||||
// SingleFunder represents a channel wherein one party solely funds the
|
// SingleFunder represents a channel wherein one party solely funds the
|
||||||
// entire capacity of the channel.
|
// entire capacity of the channel.
|
||||||
SingleFunder ChannelType = 0
|
SingleFunder ChannelType = 0
|
||||||
|
|
||||||
// DualFunder represents a channel wherein both parties contribute
|
|
||||||
// funds towards the total capacity of the channel. The channel may be
|
|
||||||
// funded symmetrically or asymmetrically.
|
|
||||||
DualFunder ChannelType = 1
|
|
||||||
|
|
||||||
// SingleFunderTweakless is similar to the basic SingleFunder channel
|
|
||||||
// type, but it omits the tweak for one's key in the commitment
|
|
||||||
// transaction of the remote party.
|
|
||||||
SingleFunderTweakless ChannelType = 2
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// IsSingleFunder returns true if the channel type if one of the known single
|
|
||||||
// funder variants.
|
|
||||||
func (c ChannelType) IsSingleFunder() bool {
|
|
||||||
return c == SingleFunder || c == SingleFunderTweakless
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsTweakless returns true if the target channel uses a commitment that
|
|
||||||
// doesn't tweak the key for the remote party.
|
|
||||||
func (c ChannelType) IsTweakless() bool {
|
|
||||||
return c == SingleFunderTweakless
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChannelConstraints represents a set of constraints meant to allow a node to
|
// ChannelConstraints represents a set of constraints meant to allow a node to
|
||||||
// limit their exposure, enact flow control and ensure that all HTLCs are
|
// limit their exposure, enact flow control and ensure that all HTLCs are
|
||||||
// economically relevant. This struct will be mirrored for both sides of the
|
// economically relevant. This struct will be mirrored for both sides of the
|
||||||
@ -444,10 +333,6 @@ type OpenChannel struct {
|
|||||||
// negotiate fees, or close the channel.
|
// negotiate fees, or close the channel.
|
||||||
IsInitiator bool
|
IsInitiator bool
|
||||||
|
|
||||||
// chanStatus is the current status of this channel. If it is not in
|
|
||||||
// the state Default, it should not be used for forwarding payments.
|
|
||||||
chanStatus ChannelStatus
|
|
||||||
|
|
||||||
// FundingBroadcastHeight is the height in which the funding
|
// FundingBroadcastHeight is the height in which the funding
|
||||||
// transaction was broadcast. This value can be used by higher level
|
// transaction was broadcast. This value can be used by higher level
|
||||||
// sub-systems to determine if a channel is stale and/or should have
|
// sub-systems to determine if a channel is stale and/or should have
|
||||||
@ -519,11 +404,6 @@ type OpenChannel struct {
|
|||||||
// implementation of secret store is shachain store.
|
// implementation of secret store is shachain store.
|
||||||
RevocationStore shachain.Store
|
RevocationStore shachain.Store
|
||||||
|
|
||||||
// Packager is used to create and update forwarding packages for this
|
|
||||||
// channel, which encodes all necessary information to recover from
|
|
||||||
// failures and reforward HTLCs that were not fully processed.
|
|
||||||
Packager FwdPackager
|
|
||||||
|
|
||||||
// FundingTxn is the transaction containing this channel's funding
|
// FundingTxn is the transaction containing this channel's funding
|
||||||
// outpoint. Upon restarts, this txn will be rebroadcast if the channel
|
// outpoint. Upon restarts, this txn will be rebroadcast if the channel
|
||||||
// is found to be pending.
|
// is found to be pending.
|
||||||
@ -548,657 +428,6 @@ func (c *OpenChannel) ShortChanID() lnwire.ShortChannelID {
|
|||||||
return c.ShortChannelID
|
return c.ShortChannelID
|
||||||
}
|
}
|
||||||
|
|
||||||
// ChanStatus returns the current ChannelStatus of this channel.
|
|
||||||
func (c *OpenChannel) ChanStatus() ChannelStatus {
|
|
||||||
c.RLock()
|
|
||||||
defer c.RUnlock()
|
|
||||||
|
|
||||||
return c.chanStatus
|
|
||||||
}
|
|
||||||
|
|
||||||
// ApplyChanStatus allows the caller to modify the internal channel state in a
|
|
||||||
// thead-safe manner.
|
|
||||||
func (c *OpenChannel) ApplyChanStatus(status ChannelStatus) error {
|
|
||||||
c.Lock()
|
|
||||||
defer c.Unlock()
|
|
||||||
|
|
||||||
return c.putChanStatus(status)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClearChanStatus allows the caller to clear a particular channel status from
|
|
||||||
// the primary channel status bit field. After this method returns, a call to
|
|
||||||
// HasChanStatus(status) should return false.
|
|
||||||
func (c *OpenChannel) ClearChanStatus(status ChannelStatus) error {
|
|
||||||
c.Lock()
|
|
||||||
defer c.Unlock()
|
|
||||||
|
|
||||||
return c.clearChanStatus(status)
|
|
||||||
}
|
|
||||||
|
|
||||||
// HasChanStatus returns true if the internal bitfield channel status of the
|
|
||||||
// target channel has the specified status bit set.
|
|
||||||
func (c *OpenChannel) HasChanStatus(status ChannelStatus) bool {
|
|
||||||
c.RLock()
|
|
||||||
defer c.RUnlock()
|
|
||||||
|
|
||||||
return c.hasChanStatus(status)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *OpenChannel) hasChanStatus(status ChannelStatus) bool {
|
|
||||||
return c.chanStatus&status == status
|
|
||||||
}
|
|
||||||
|
|
||||||
// RefreshShortChanID updates the in-memory short channel ID using the latest
|
|
||||||
// value observed on disk.
|
|
||||||
func (c *OpenChannel) RefreshShortChanID() error {
|
|
||||||
c.Lock()
|
|
||||||
defer c.Unlock()
|
|
||||||
|
|
||||||
var sid lnwire.ShortChannelID
|
|
||||||
err := c.Db.View(func(tx *bbolt.Tx) error {
|
|
||||||
chanBucket, err := fetchChanBucket(
|
|
||||||
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
channel, err := fetchOpenChannel(chanBucket, &c.FundingOutpoint)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
sid = channel.ShortChannelID
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.ShortChannelID = sid
|
|
||||||
c.Packager = NewChannelPackager(sid)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// fetchChanBucket is a helper function that returns the bucket where a
|
|
||||||
// channel's data resides in given: the public key for the node, the outpoint,
|
|
||||||
// and the chainhash that the channel resides on.
|
|
||||||
func fetchChanBucket(tx *bbolt.Tx, nodeKey *btcec.PublicKey,
|
|
||||||
outPoint *wire.OutPoint, chainHash chainhash.Hash) (*bbolt.Bucket, error) {
|
|
||||||
|
|
||||||
// First fetch the top level bucket which stores all data related to
|
|
||||||
// current, active channels.
|
|
||||||
openChanBucket := tx.Bucket(openChannelBucket)
|
|
||||||
if openChanBucket == nil {
|
|
||||||
return nil, ErrNoChanDBExists
|
|
||||||
}
|
|
||||||
|
|
||||||
// Within this top level bucket, fetch the bucket dedicated to storing
|
|
||||||
// open channel data specific to the remote node.
|
|
||||||
nodePub := nodeKey.SerializeCompressed()
|
|
||||||
nodeChanBucket := openChanBucket.Bucket(nodePub)
|
|
||||||
if nodeChanBucket == nil {
|
|
||||||
return nil, ErrNoActiveChannels
|
|
||||||
}
|
|
||||||
|
|
||||||
// We'll then recurse down an additional layer in order to fetch the
|
|
||||||
// bucket for this particular chain.
|
|
||||||
chainBucket := nodeChanBucket.Bucket(chainHash[:])
|
|
||||||
if chainBucket == nil {
|
|
||||||
return nil, ErrNoActiveChannels
|
|
||||||
}
|
|
||||||
|
|
||||||
// With the bucket for the node and chain fetched, we can now go down
|
|
||||||
// another level, for this channel itself.
|
|
||||||
var chanPointBuf bytes.Buffer
|
|
||||||
if err := writeOutpoint(&chanPointBuf, outPoint); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
chanBucket := chainBucket.Bucket(chanPointBuf.Bytes())
|
|
||||||
if chanBucket == nil {
|
|
||||||
return nil, ErrChannelNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
return chanBucket, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// fullSync syncs the contents of an OpenChannel while re-using an existing
|
|
||||||
// database transaction.
|
|
||||||
func (c *OpenChannel) fullSync(tx *bbolt.Tx) error {
|
|
||||||
// First fetch the top level bucket which stores all data related to
|
|
||||||
// current, active channels.
|
|
||||||
openChanBucket, err := tx.CreateBucketIfNotExists(openChannelBucket)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Within this top level bucket, fetch the bucket dedicated to storing
|
|
||||||
// open channel data specific to the remote node.
|
|
||||||
nodePub := c.IdentityPub.SerializeCompressed()
|
|
||||||
nodeChanBucket, err := openChanBucket.CreateBucketIfNotExists(nodePub)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// We'll then recurse down an additional layer in order to fetch the
|
|
||||||
// bucket for this particular chain.
|
|
||||||
chainBucket, err := nodeChanBucket.CreateBucketIfNotExists(c.ChainHash[:])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// With the bucket for the node fetched, we can now go down another
|
|
||||||
// level, creating the bucket for this channel itself.
|
|
||||||
var chanPointBuf bytes.Buffer
|
|
||||||
if err := writeOutpoint(&chanPointBuf, &c.FundingOutpoint); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
chanBucket, err := chainBucket.CreateBucket(
|
|
||||||
chanPointBuf.Bytes(),
|
|
||||||
)
|
|
||||||
switch {
|
|
||||||
case err == bbolt.ErrBucketExists:
|
|
||||||
// If this channel already exists, then in order to avoid
|
|
||||||
// overriding it, we'll return an error back up to the caller.
|
|
||||||
return ErrChanAlreadyExists
|
|
||||||
case err != nil:
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return putOpenChannel(chanBucket, c)
|
|
||||||
}
|
|
||||||
|
|
||||||
// MarkAsOpen marks a channel as fully open given a locator that uniquely
|
|
||||||
// describes its location within the chain.
|
|
||||||
func (c *OpenChannel) MarkAsOpen(openLoc lnwire.ShortChannelID) error {
|
|
||||||
c.Lock()
|
|
||||||
defer c.Unlock()
|
|
||||||
|
|
||||||
if err := c.Db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
chanBucket, err := fetchChanBucket(
|
|
||||||
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
channel, err := fetchOpenChannel(chanBucket, &c.FundingOutpoint)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
channel.IsPending = false
|
|
||||||
channel.ShortChannelID = openLoc
|
|
||||||
|
|
||||||
return putOpenChannel(chanBucket, channel)
|
|
||||||
}); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.IsPending = false
|
|
||||||
c.ShortChannelID = openLoc
|
|
||||||
c.Packager = NewChannelPackager(openLoc)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// MarkDataLoss marks sets the channel status to LocalDataLoss and stores the
|
|
||||||
// passed commitPoint for use to retrieve funds in case the remote force closes
|
|
||||||
// the channel.
|
|
||||||
func (c *OpenChannel) MarkDataLoss(commitPoint *btcec.PublicKey) error {
|
|
||||||
c.Lock()
|
|
||||||
defer c.Unlock()
|
|
||||||
|
|
||||||
var b bytes.Buffer
|
|
||||||
if err := WriteElement(&b, commitPoint); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
putCommitPoint := func(chanBucket *bbolt.Bucket) error {
|
|
||||||
return chanBucket.Put(dataLossCommitPointKey, b.Bytes())
|
|
||||||
}
|
|
||||||
|
|
||||||
return c.putChanStatus(ChanStatusLocalDataLoss, putCommitPoint)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DataLossCommitPoint retrieves the stored commit point set during
|
|
||||||
// MarkDataLoss. If not found ErrNoCommitPoint is returned.
|
|
||||||
func (c *OpenChannel) DataLossCommitPoint() (*btcec.PublicKey, error) {
|
|
||||||
var commitPoint *btcec.PublicKey
|
|
||||||
|
|
||||||
err := c.Db.View(func(tx *bbolt.Tx) error {
|
|
||||||
chanBucket, err := fetchChanBucket(
|
|
||||||
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
|
|
||||||
)
|
|
||||||
switch err {
|
|
||||||
case nil:
|
|
||||||
case ErrNoChanDBExists, ErrNoActiveChannels, ErrChannelNotFound:
|
|
||||||
return ErrNoCommitPoint
|
|
||||||
default:
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
bs := chanBucket.Get(dataLossCommitPointKey)
|
|
||||||
if bs == nil {
|
|
||||||
return ErrNoCommitPoint
|
|
||||||
}
|
|
||||||
r := bytes.NewReader(bs)
|
|
||||||
if err := ReadElements(r, &commitPoint); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return commitPoint, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// MarkBorked marks the event when the channel as reached an irreconcilable
|
|
||||||
// state, such as a channel breach or state desynchronization. Borked channels
|
|
||||||
// should never be added to the switch.
|
|
||||||
func (c *OpenChannel) MarkBorked() error {
|
|
||||||
c.Lock()
|
|
||||||
defer c.Unlock()
|
|
||||||
|
|
||||||
return c.putChanStatus(ChanStatusBorked)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChanSyncMsg returns the ChannelReestablish message that should be sent upon
|
|
||||||
// reconnection with the remote peer that we're maintaining this channel with.
|
|
||||||
// The information contained within this message is necessary to re-sync our
|
|
||||||
// commitment chains in the case of a last or only partially processed message.
|
|
||||||
// When the remote party receiver this message one of three things may happen:
|
|
||||||
//
|
|
||||||
// 1. We're fully synced and no messages need to be sent.
|
|
||||||
// 2. We didn't get the last CommitSig message they sent, to they'll re-send
|
|
||||||
// it.
|
|
||||||
// 3. We didn't get the last RevokeAndAck message they sent, so they'll
|
|
||||||
// re-send it.
|
|
||||||
//
|
|
||||||
// If this is a restored channel, having status ChanStatusRestored, then we'll
|
|
||||||
// modify our typical chan sync message to ensure they force close even if
|
|
||||||
// we're on the very first state.
|
|
||||||
func (c *OpenChannel) ChanSyncMsg() (*lnwire.ChannelReestablish, error) {
|
|
||||||
c.Lock()
|
|
||||||
defer c.Unlock()
|
|
||||||
|
|
||||||
// The remote commitment height that we'll send in the
|
|
||||||
// ChannelReestablish message is our current commitment height plus
|
|
||||||
// one. If the receiver thinks that our commitment height is actually
|
|
||||||
// *equal* to this value, then they'll re-send the last commitment that
|
|
||||||
// they sent but we never fully processed.
|
|
||||||
localHeight := c.LocalCommitment.CommitHeight
|
|
||||||
nextLocalCommitHeight := localHeight + 1
|
|
||||||
|
|
||||||
// The second value we'll send is the height of the remote commitment
|
|
||||||
// from our PoV. If the receiver thinks that their height is actually
|
|
||||||
// *one plus* this value, then they'll re-send their last revocation.
|
|
||||||
remoteChainTipHeight := c.RemoteCommitment.CommitHeight
|
|
||||||
|
|
||||||
// If this channel has undergone a commitment update, then in order to
|
|
||||||
// prove to the remote party our knowledge of their prior commitment
|
|
||||||
// state, we'll also send over the last commitment secret that the
|
|
||||||
// remote party sent.
|
|
||||||
var lastCommitSecret [32]byte
|
|
||||||
if remoteChainTipHeight != 0 {
|
|
||||||
remoteSecret, err := c.RevocationStore.LookUp(
|
|
||||||
remoteChainTipHeight - 1,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
lastCommitSecret = [32]byte(*remoteSecret)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Additionally, we'll send over the current unrevoked commitment on
|
|
||||||
// our local commitment transaction.
|
|
||||||
currentCommitSecret, err := c.RevocationProducer.AtIndex(
|
|
||||||
localHeight,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we've restored this channel, then we'll purposefully give them an
|
|
||||||
// invalid LocalUnrevokedCommitPoint so they'll force close the channel
|
|
||||||
// allowing us to sweep our funds.
|
|
||||||
if c.hasChanStatus(ChanStatusRestored) {
|
|
||||||
currentCommitSecret[0] ^= 1
|
|
||||||
|
|
||||||
// If this is a tweakless channel, then we'll purposefully send
|
|
||||||
// a next local height taht's invalid to trigger a force close
|
|
||||||
// on their end. We do this as tweakless channels don't require
|
|
||||||
// that the commitment point is valid, only that it's present.
|
|
||||||
if c.ChanType.IsTweakless() {
|
|
||||||
nextLocalCommitHeight = 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &lnwire.ChannelReestablish{
|
|
||||||
ChanID: lnwire.NewChanIDFromOutPoint(
|
|
||||||
&c.FundingOutpoint,
|
|
||||||
),
|
|
||||||
NextLocalCommitHeight: nextLocalCommitHeight,
|
|
||||||
RemoteCommitTailHeight: remoteChainTipHeight,
|
|
||||||
LastRemoteCommitSecret: lastCommitSecret,
|
|
||||||
LocalUnrevokedCommitPoint: input.ComputeCommitmentPoint(
|
|
||||||
currentCommitSecret[:],
|
|
||||||
),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// isBorked returns true if the channel has been marked as borked in the
|
|
||||||
// database. This requires an existing database transaction to already be
|
|
||||||
// active.
|
|
||||||
//
|
|
||||||
// NOTE: The primary mutex should already be held before this method is called.
|
|
||||||
func (c *OpenChannel) isBorked(chanBucket *bbolt.Bucket) (bool, error) {
|
|
||||||
channel, err := fetchOpenChannel(chanBucket, &c.FundingOutpoint)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return channel.chanStatus != ChanStatusDefault, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// MarkCommitmentBroadcasted marks the channel as a commitment transaction has
|
|
||||||
// been broadcast, either our own or the remote, and we should watch the chain
|
|
||||||
// for it to confirm before taking any further action. It takes as argument the
|
|
||||||
// closing tx _we believe_ will appear in the chain. This is only used to
|
|
||||||
// republish this tx at startup to ensure propagation, and we should still
|
|
||||||
// handle the case where a different tx actually hits the chain.
|
|
||||||
func (c *OpenChannel) MarkCommitmentBroadcasted(closeTx *wire.MsgTx) error {
|
|
||||||
c.Lock()
|
|
||||||
defer c.Unlock()
|
|
||||||
|
|
||||||
var b bytes.Buffer
|
|
||||||
if err := WriteElement(&b, closeTx); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
putClosingTx := func(chanBucket *bbolt.Bucket) error {
|
|
||||||
return chanBucket.Put(closingTxKey, b.Bytes())
|
|
||||||
}
|
|
||||||
|
|
||||||
return c.putChanStatus(ChanStatusCommitBroadcasted, putClosingTx)
|
|
||||||
}
|
|
||||||
|
|
||||||
// BroadcastedCommitment retrieves the stored closing tx set during
|
|
||||||
// MarkCommitmentBroadcasted. If not found ErrNoCloseTx is returned.
|
|
||||||
func (c *OpenChannel) BroadcastedCommitment() (*wire.MsgTx, error) {
|
|
||||||
var closeTx *wire.MsgTx
|
|
||||||
|
|
||||||
err := c.Db.View(func(tx *bbolt.Tx) error {
|
|
||||||
chanBucket, err := fetchChanBucket(
|
|
||||||
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
|
|
||||||
)
|
|
||||||
switch err {
|
|
||||||
case nil:
|
|
||||||
case ErrNoChanDBExists, ErrNoActiveChannels, ErrChannelNotFound:
|
|
||||||
return ErrNoCloseTx
|
|
||||||
default:
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
bs := chanBucket.Get(closingTxKey)
|
|
||||||
if bs == nil {
|
|
||||||
return ErrNoCloseTx
|
|
||||||
}
|
|
||||||
r := bytes.NewReader(bs)
|
|
||||||
return ReadElement(r, &closeTx)
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return closeTx, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// putChanStatus appends the given status to the channel. fs is an optional
|
|
||||||
// list of closures that are given the chanBucket in order to atomically add
|
|
||||||
// extra information together with the new status.
|
|
||||||
func (c *OpenChannel) putChanStatus(status ChannelStatus,
|
|
||||||
fs ...func(*bbolt.Bucket) error) error {
|
|
||||||
|
|
||||||
if err := c.Db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
chanBucket, err := fetchChanBucket(
|
|
||||||
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
channel, err := fetchOpenChannel(chanBucket, &c.FundingOutpoint)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add this status to the existing bitvector found in the DB.
|
|
||||||
status = channel.chanStatus | status
|
|
||||||
channel.chanStatus = status
|
|
||||||
|
|
||||||
if err := putOpenChannel(chanBucket, channel); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, f := range fs {
|
|
||||||
if err := f(chanBucket); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update the in-memory representation to keep it in sync with the DB.
|
|
||||||
c.chanStatus = status
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *OpenChannel) clearChanStatus(status ChannelStatus) error {
|
|
||||||
if err := c.Db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
chanBucket, err := fetchChanBucket(
|
|
||||||
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
channel, err := fetchOpenChannel(chanBucket, &c.FundingOutpoint)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unset this bit in the bitvector on disk.
|
|
||||||
status = channel.chanStatus & ^status
|
|
||||||
channel.chanStatus = status
|
|
||||||
|
|
||||||
return putOpenChannel(chanBucket, channel)
|
|
||||||
}); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update the in-memory representation to keep it in sync with the DB.
|
|
||||||
c.chanStatus = status
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// putChannel serializes, and stores the current state of the channel in its
|
|
||||||
// entirety.
|
|
||||||
func putOpenChannel(chanBucket *bbolt.Bucket, channel *OpenChannel) error {
|
|
||||||
// First, we'll write out all the relatively static fields, that are
|
|
||||||
// decided upon initial channel creation.
|
|
||||||
if err := putChanInfo(chanBucket, channel); err != nil {
|
|
||||||
return fmt.Errorf("unable to store chan info: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// With the static channel info written out, we'll now write out the
|
|
||||||
// current commitment state for both parties.
|
|
||||||
if err := putChanCommitments(chanBucket, channel); err != nil {
|
|
||||||
return fmt.Errorf("unable to store chan commitments: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Finally, we'll write out the revocation state for both parties
|
|
||||||
// within a distinct key space.
|
|
||||||
if err := putChanRevocationState(chanBucket, channel); err != nil {
|
|
||||||
return fmt.Errorf("unable to store chan revocations: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// fetchOpenChannel retrieves, and deserializes (including decrypting
|
|
||||||
// sensitive) the complete channel currently active with the passed nodeID.
|
|
||||||
func fetchOpenChannel(chanBucket *bbolt.Bucket,
|
|
||||||
chanPoint *wire.OutPoint) (*OpenChannel, error) {
|
|
||||||
|
|
||||||
channel := &OpenChannel{
|
|
||||||
FundingOutpoint: *chanPoint,
|
|
||||||
}
|
|
||||||
|
|
||||||
// First, we'll read all the static information that changes less
|
|
||||||
// frequently from disk.
|
|
||||||
if err := fetchChanInfo(chanBucket, channel); err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to fetch chan info: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// With the static information read, we'll now read the current
|
|
||||||
// commitment state for both sides of the channel.
|
|
||||||
if err := fetchChanCommitments(chanBucket, channel); err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to fetch chan commitments: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Finally, we'll retrieve the current revocation state so we can
|
|
||||||
// properly
|
|
||||||
if err := fetchChanRevocationState(chanBucket, channel); err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to fetch chan revocations: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
channel.Packager = NewChannelPackager(channel.ShortChannelID)
|
|
||||||
|
|
||||||
return channel, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SyncPending writes the contents of the channel to the database while it's in
|
|
||||||
// the pending (waiting for funding confirmation) state. The IsPending flag
|
|
||||||
// will be set to true. When the channel's funding transaction is confirmed,
|
|
||||||
// the channel should be marked as "open" and the IsPending flag set to false.
|
|
||||||
// Note that this function also creates a LinkNode relationship between this
|
|
||||||
// newly created channel and a new LinkNode instance. This allows listing all
|
|
||||||
// channels in the database globally, or according to the LinkNode they were
|
|
||||||
// created with.
|
|
||||||
//
|
|
||||||
// TODO(roasbeef): addr param should eventually be an lnwire.NetAddress type
|
|
||||||
// that includes service bits.
|
|
||||||
func (c *OpenChannel) SyncPending(addr net.Addr, pendingHeight uint32) error {
|
|
||||||
c.Lock()
|
|
||||||
defer c.Unlock()
|
|
||||||
|
|
||||||
c.FundingBroadcastHeight = pendingHeight
|
|
||||||
|
|
||||||
return c.Db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
return syncNewChannel(tx, c, []net.Addr{addr})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// syncNewChannel will write the passed channel to disk, and also create a
|
|
||||||
// LinkNode (if needed) for the channel peer.
|
|
||||||
func syncNewChannel(tx *bbolt.Tx, c *OpenChannel, addrs []net.Addr) error {
|
|
||||||
// First, sync all the persistent channel state to disk.
|
|
||||||
if err := c.fullSync(tx); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
nodeInfoBucket, err := tx.CreateBucketIfNotExists(nodeInfoBucket)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// If a LinkNode for this identity public key already exists,
|
|
||||||
// then we can exit early.
|
|
||||||
nodePub := c.IdentityPub.SerializeCompressed()
|
|
||||||
if nodeInfoBucket.Get(nodePub) != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Next, we need to establish a (possibly) new LinkNode relationship
|
|
||||||
// for this channel. The LinkNode metadata contains reachability,
|
|
||||||
// up-time, and service bits related information.
|
|
||||||
linkNode := c.Db.NewLinkNode(wire.MainNet, c.IdentityPub, addrs...)
|
|
||||||
|
|
||||||
// TODO(roasbeef): do away with link node all together?
|
|
||||||
|
|
||||||
return putLinkNode(nodeInfoBucket, linkNode)
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateCommitment updates the commitment state for the specified party
|
|
||||||
// (remote or local). The commitment stat completely describes the balance
|
|
||||||
// state at this point in the commitment chain. This method its to be called on
|
|
||||||
// two occasions: when we revoke our prior commitment state, and when the
|
|
||||||
// remote party revokes their prior commitment state.
|
|
||||||
func (c *OpenChannel) UpdateCommitment(newCommitment *ChannelCommitment) error {
|
|
||||||
c.Lock()
|
|
||||||
defer c.Unlock()
|
|
||||||
|
|
||||||
// If this is a restored channel, then we want to avoid mutating the
|
|
||||||
// state as all, as it's impossible to do so in a protocol compliant
|
|
||||||
// manner.
|
|
||||||
if c.hasChanStatus(ChanStatusRestored) {
|
|
||||||
return ErrNoRestoredChannelMutation
|
|
||||||
}
|
|
||||||
|
|
||||||
err := c.Db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
chanBucket, err := fetchChanBucket(
|
|
||||||
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the channel is marked as borked, then for safety reasons,
|
|
||||||
// we shouldn't attempt any further updates.
|
|
||||||
isBorked, err := c.isBorked(chanBucket)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if isBorked {
|
|
||||||
return ErrChanBorked
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = putChanInfo(chanBucket, c); err != nil {
|
|
||||||
return fmt.Errorf("unable to store chan info: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// With the proper bucket fetched, we'll now write the latest
|
|
||||||
// commitment state to disk for the target party.
|
|
||||||
err = putChanCommitment(
|
|
||||||
chanBucket, newCommitment, true,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to store chan "+
|
|
||||||
"revocations: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.LocalCommitment = *newCommitment
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// HTLC is the on-disk representation of a hash time-locked contract. HTLCs are
|
// HTLC is the on-disk representation of a hash time-locked contract. HTLCs are
|
||||||
// contained within ChannelDeltas which encode the current state of the
|
// contained within ChannelDeltas which encode the current state of the
|
||||||
// commitment between state updates.
|
// commitment between state updates.
|
||||||
@ -1247,101 +476,6 @@ type HTLC struct {
|
|||||||
LogIndex uint64
|
LogIndex uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
// SerializeHtlcs writes out the passed set of HTLC's into the passed writer
|
|
||||||
// using the current default on-disk serialization format.
|
|
||||||
//
|
|
||||||
// NOTE: This API is NOT stable, the on-disk format will likely change in the
|
|
||||||
// future.
|
|
||||||
func SerializeHtlcs(b io.Writer, htlcs ...HTLC) error {
|
|
||||||
numHtlcs := uint16(len(htlcs))
|
|
||||||
if err := WriteElement(b, numHtlcs); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, htlc := range htlcs {
|
|
||||||
if err := WriteElements(b,
|
|
||||||
htlc.Signature, htlc.RHash, htlc.Amt, htlc.RefundTimeout,
|
|
||||||
htlc.OutputIndex, htlc.Incoming, htlc.OnionBlob[:],
|
|
||||||
htlc.HtlcIndex, htlc.LogIndex,
|
|
||||||
); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeserializeHtlcs attempts to read out a slice of HTLC's from the passed
|
|
||||||
// io.Reader. The bytes within the passed reader MUST have been previously
|
|
||||||
// written to using the SerializeHtlcs function.
|
|
||||||
//
|
|
||||||
// NOTE: This API is NOT stable, the on-disk format will likely change in the
|
|
||||||
// future.
|
|
||||||
func DeserializeHtlcs(r io.Reader) ([]HTLC, error) {
|
|
||||||
var numHtlcs uint16
|
|
||||||
if err := ReadElement(r, &numHtlcs); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var htlcs []HTLC
|
|
||||||
if numHtlcs == 0 {
|
|
||||||
return htlcs, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
htlcs = make([]HTLC, numHtlcs)
|
|
||||||
for i := uint16(0); i < numHtlcs; i++ {
|
|
||||||
if err := ReadElements(r,
|
|
||||||
&htlcs[i].Signature, &htlcs[i].RHash, &htlcs[i].Amt,
|
|
||||||
&htlcs[i].RefundTimeout, &htlcs[i].OutputIndex,
|
|
||||||
&htlcs[i].Incoming, &htlcs[i].OnionBlob,
|
|
||||||
&htlcs[i].HtlcIndex, &htlcs[i].LogIndex,
|
|
||||||
); err != nil {
|
|
||||||
return htlcs, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return htlcs, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Copy returns a full copy of the target HTLC.
|
|
||||||
func (h *HTLC) Copy() HTLC {
|
|
||||||
clone := HTLC{
|
|
||||||
Incoming: h.Incoming,
|
|
||||||
Amt: h.Amt,
|
|
||||||
RefundTimeout: h.RefundTimeout,
|
|
||||||
OutputIndex: h.OutputIndex,
|
|
||||||
}
|
|
||||||
copy(clone.Signature[:], h.Signature)
|
|
||||||
copy(clone.RHash[:], h.RHash[:])
|
|
||||||
|
|
||||||
return clone
|
|
||||||
}
|
|
||||||
|
|
||||||
// LogUpdate represents a pending update to the remote commitment chain. The
|
|
||||||
// log update may be an add, fail, or settle entry. We maintain this data in
|
|
||||||
// order to be able to properly retransmit our proposed
|
|
||||||
// state if necessary.
|
|
||||||
type LogUpdate struct {
|
|
||||||
// LogIndex is the log index of this proposed commitment update entry.
|
|
||||||
LogIndex uint64
|
|
||||||
|
|
||||||
// UpdateMsg is the update message that was included within the our
|
|
||||||
// local update log. The LogIndex value denotes the log index of this
|
|
||||||
// update which will be used when restoring our local update log if
|
|
||||||
// we're left with a dangling update on restart.
|
|
||||||
UpdateMsg lnwire.Message
|
|
||||||
}
|
|
||||||
|
|
||||||
// Encode writes a log update to the provided io.Writer.
|
|
||||||
func (l *LogUpdate) Encode(w io.Writer) error {
|
|
||||||
return WriteElements(w, l.LogIndex, l.UpdateMsg)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Decode reads a log update from the provided io.Reader.
|
|
||||||
func (l *LogUpdate) Decode(r io.Reader) error {
|
|
||||||
return ReadElements(r, &l.LogIndex, &l.UpdateMsg)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CircuitKey is used by a channel to uniquely identify the HTLCs it receives
|
// CircuitKey is used by a channel to uniquely identify the HTLCs it receives
|
||||||
// from the switch, and is used to purge our in-memory state of HTLCs that have
|
// from the switch, and is used to purge our in-memory state of HTLCs that have
|
||||||
// already been processed by a link. Two list of CircuitKeys are included in
|
// already been processed by a link. Two list of CircuitKeys are included in
|
||||||
@ -1360,723 +494,20 @@ type CircuitKey struct {
|
|||||||
HtlcID uint64
|
HtlcID uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetBytes deserializes the given bytes into this CircuitKey.
|
|
||||||
func (k *CircuitKey) SetBytes(bs []byte) error {
|
|
||||||
if len(bs) != 16 {
|
|
||||||
return ErrInvalidCircuitKeyLen
|
|
||||||
}
|
|
||||||
|
|
||||||
k.ChanID = lnwire.NewShortChanIDFromInt(
|
|
||||||
binary.BigEndian.Uint64(bs[:8]))
|
|
||||||
k.HtlcID = binary.BigEndian.Uint64(bs[8:])
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Bytes returns the serialized bytes for this circuit key.
|
|
||||||
func (k CircuitKey) Bytes() []byte {
|
|
||||||
var bs = make([]byte, 16)
|
|
||||||
binary.BigEndian.PutUint64(bs[:8], k.ChanID.ToUint64())
|
|
||||||
binary.BigEndian.PutUint64(bs[8:], k.HtlcID)
|
|
||||||
return bs
|
|
||||||
}
|
|
||||||
|
|
||||||
// Encode writes a CircuitKey to the provided io.Writer.
|
|
||||||
func (k *CircuitKey) Encode(w io.Writer) error {
|
|
||||||
var scratch [16]byte
|
|
||||||
binary.BigEndian.PutUint64(scratch[:8], k.ChanID.ToUint64())
|
|
||||||
binary.BigEndian.PutUint64(scratch[8:], k.HtlcID)
|
|
||||||
|
|
||||||
_, err := w.Write(scratch[:])
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Decode reads a CircuitKey from the provided io.Reader.
|
|
||||||
func (k *CircuitKey) Decode(r io.Reader) error {
|
|
||||||
var scratch [16]byte
|
|
||||||
|
|
||||||
if _, err := io.ReadFull(r, scratch[:]); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
k.ChanID = lnwire.NewShortChanIDFromInt(
|
|
||||||
binary.BigEndian.Uint64(scratch[:8]))
|
|
||||||
k.HtlcID = binary.BigEndian.Uint64(scratch[8:])
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// String returns a string representation of the CircuitKey.
|
// String returns a string representation of the CircuitKey.
|
||||||
func (k CircuitKey) String() string {
|
func (k CircuitKey) String() string {
|
||||||
return fmt.Sprintf("(Chan ID=%s, HTLC ID=%d)", k.ChanID, k.HtlcID)
|
return fmt.Sprintf("(Chan ID=%s, HTLC ID=%d)", k.ChanID, k.HtlcID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CommitDiff represents the delta needed to apply the state transition between
|
|
||||||
// two subsequent commitment states. Given state N and state N+1, one is able
|
|
||||||
// to apply the set of messages contained within the CommitDiff to N to arrive
|
|
||||||
// at state N+1. Each time a new commitment is extended, we'll write a new
|
|
||||||
// commitment (along with the full commitment state) to disk so we can
|
|
||||||
// re-transmit the state in the case of a connection loss or message drop.
|
|
||||||
type CommitDiff struct {
|
|
||||||
// ChannelCommitment is the full commitment state that one would arrive
|
|
||||||
// at by applying the set of messages contained in the UpdateDiff to
|
|
||||||
// the prior accepted commitment.
|
|
||||||
Commitment ChannelCommitment
|
|
||||||
|
|
||||||
// LogUpdates is the set of messages sent prior to the commitment state
|
|
||||||
// transition in question. Upon reconnection, if we detect that they
|
|
||||||
// don't have the commitment, then we re-send this along with the
|
|
||||||
// proper signature.
|
|
||||||
LogUpdates []LogUpdate
|
|
||||||
|
|
||||||
// CommitSig is the exact CommitSig message that should be sent after
|
|
||||||
// the set of LogUpdates above has been retransmitted. The signatures
|
|
||||||
// within this message should properly cover the new commitment state
|
|
||||||
// and also the HTLC's within the new commitment state.
|
|
||||||
CommitSig *lnwire.CommitSig
|
|
||||||
|
|
||||||
// OpenedCircuitKeys is a set of unique identifiers for any downstream
|
|
||||||
// Add packets included in this commitment txn. After a restart, this
|
|
||||||
// set of htlcs is acked from the link's incoming mailbox to ensure
|
|
||||||
// there isn't an attempt to re-add them to this commitment txn.
|
|
||||||
OpenedCircuitKeys []CircuitKey
|
|
||||||
|
|
||||||
// ClosedCircuitKeys records the unique identifiers for any settle/fail
|
|
||||||
// packets that were resolved by this commitment txn. After a restart,
|
|
||||||
// this is used to ensure those circuits are removed from the circuit
|
|
||||||
// map, and the downstream packets in the link's mailbox are removed.
|
|
||||||
ClosedCircuitKeys []CircuitKey
|
|
||||||
|
|
||||||
// AddAcks specifies the locations (commit height, pkg index) of any
|
|
||||||
// Adds that were failed/settled in this commit diff. This will ack
|
|
||||||
// entries in *this* channel's forwarding packages.
|
|
||||||
//
|
|
||||||
// NOTE: This value is not serialized, it is used to atomically mark the
|
|
||||||
// resolution of adds, such that they will not be reprocessed after a
|
|
||||||
// restart.
|
|
||||||
AddAcks []AddRef
|
|
||||||
|
|
||||||
// SettleFailAcks specifies the locations (chan id, commit height, pkg
|
|
||||||
// index) of any Settles or Fails that were locked into this commit
|
|
||||||
// diff, and originate from *another* channel, i.e. the outgoing link.
|
|
||||||
//
|
|
||||||
// NOTE: This value is not serialized, it is used to atomically acks
|
|
||||||
// settles and fails from the forwarding packages of other channels,
|
|
||||||
// such that they will not be reforwarded internally after a restart.
|
|
||||||
SettleFailAcks []SettleFailRef
|
|
||||||
}
|
|
||||||
|
|
||||||
func serializeCommitDiff(w io.Writer, diff *CommitDiff) error {
|
|
||||||
if err := serializeChanCommit(w, &diff.Commitment); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := diff.CommitSig.Encode(w, 0); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
numUpdates := uint16(len(diff.LogUpdates))
|
|
||||||
if err := binary.Write(w, byteOrder, numUpdates); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, diff := range diff.LogUpdates {
|
|
||||||
err := WriteElements(w, diff.LogIndex, diff.UpdateMsg)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
numOpenRefs := uint16(len(diff.OpenedCircuitKeys))
|
|
||||||
if err := binary.Write(w, byteOrder, numOpenRefs); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, openRef := range diff.OpenedCircuitKeys {
|
|
||||||
err := WriteElements(w, openRef.ChanID, openRef.HtlcID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
numClosedRefs := uint16(len(diff.ClosedCircuitKeys))
|
|
||||||
if err := binary.Write(w, byteOrder, numClosedRefs); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, closedRef := range diff.ClosedCircuitKeys {
|
|
||||||
err := WriteElements(w, closedRef.ChanID, closedRef.HtlcID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func deserializeCommitDiff(r io.Reader) (*CommitDiff, error) {
|
|
||||||
var (
|
|
||||||
d CommitDiff
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
|
|
||||||
d.Commitment, err = deserializeChanCommit(r)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
d.CommitSig = &lnwire.CommitSig{}
|
|
||||||
if err := d.CommitSig.Decode(r, 0); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var numUpdates uint16
|
|
||||||
if err := binary.Read(r, byteOrder, &numUpdates); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
d.LogUpdates = make([]LogUpdate, numUpdates)
|
|
||||||
for i := 0; i < int(numUpdates); i++ {
|
|
||||||
err := ReadElements(r,
|
|
||||||
&d.LogUpdates[i].LogIndex, &d.LogUpdates[i].UpdateMsg,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var numOpenRefs uint16
|
|
||||||
if err := binary.Read(r, byteOrder, &numOpenRefs); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
d.OpenedCircuitKeys = make([]CircuitKey, numOpenRefs)
|
|
||||||
for i := 0; i < int(numOpenRefs); i++ {
|
|
||||||
err := ReadElements(r,
|
|
||||||
&d.OpenedCircuitKeys[i].ChanID,
|
|
||||||
&d.OpenedCircuitKeys[i].HtlcID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var numClosedRefs uint16
|
|
||||||
if err := binary.Read(r, byteOrder, &numClosedRefs); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
d.ClosedCircuitKeys = make([]CircuitKey, numClosedRefs)
|
|
||||||
for i := 0; i < int(numClosedRefs); i++ {
|
|
||||||
err := ReadElements(r,
|
|
||||||
&d.ClosedCircuitKeys[i].ChanID,
|
|
||||||
&d.ClosedCircuitKeys[i].HtlcID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &d, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AppendRemoteCommitChain appends a new CommitDiff to the end of the
|
|
||||||
// commitment chain for the remote party. This method is to be used once we
|
|
||||||
// have prepared a new commitment state for the remote party, but before we
|
|
||||||
// transmit it to the remote party. The contents of the argument should be
|
|
||||||
// sufficient to retransmit the updates and signature needed to reconstruct the
|
|
||||||
// state in full, in the case that we need to retransmit.
|
|
||||||
func (c *OpenChannel) AppendRemoteCommitChain(diff *CommitDiff) error {
|
|
||||||
c.Lock()
|
|
||||||
defer c.Unlock()
|
|
||||||
|
|
||||||
// If this is a restored channel, then we want to avoid mutating the
|
|
||||||
// state at all, as it's impossible to do so in a protocol compliant
|
|
||||||
// manner.
|
|
||||||
if c.hasChanStatus(ChanStatusRestored) {
|
|
||||||
return ErrNoRestoredChannelMutation
|
|
||||||
}
|
|
||||||
|
|
||||||
return c.Db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
// First, we'll grab the writable bucket where this channel's
|
|
||||||
// data resides.
|
|
||||||
chanBucket, err := fetchChanBucket(
|
|
||||||
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the channel is marked as borked, then for safety reasons,
|
|
||||||
// we shouldn't attempt any further updates.
|
|
||||||
isBorked, err := c.isBorked(chanBucket)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if isBorked {
|
|
||||||
return ErrChanBorked
|
|
||||||
}
|
|
||||||
|
|
||||||
// Any outgoing settles and fails necessarily have a
|
|
||||||
// corresponding adds in this channel's forwarding packages.
|
|
||||||
// Mark all of these as being fully processed in our forwarding
|
|
||||||
// package, which prevents us from reprocessing them after
|
|
||||||
// startup.
|
|
||||||
err = c.Packager.AckAddHtlcs(tx, diff.AddAcks...)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Additionally, we ack from any fails or settles that are
|
|
||||||
// persisted in another channel's forwarding package. This
|
|
||||||
// prevents the same fails and settles from being retransmitted
|
|
||||||
// after restarts. The actual fail or settle we need to
|
|
||||||
// propagate to the remote party is now in the commit diff.
|
|
||||||
err = c.Packager.AckSettleFails(tx, diff.SettleFailAcks...)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(roasbeef): use seqno to derive key for later LCP
|
|
||||||
|
|
||||||
// With the bucket retrieved, we'll now serialize the commit
|
|
||||||
// diff itself, and write it to disk.
|
|
||||||
var b bytes.Buffer
|
|
||||||
if err := serializeCommitDiff(&b, diff); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return chanBucket.Put(commitDiffKey, b.Bytes())
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoteCommitChainTip returns the "tip" of the current remote commitment
|
|
||||||
// chain. This value will be non-nil iff, we've created a new commitment for
|
|
||||||
// the remote party that they haven't yet ACK'd. In this case, their commitment
|
|
||||||
// chain will have a length of two: their current unrevoked commitment, and
|
|
||||||
// this new pending commitment. Once they revoked their prior state, we'll swap
|
|
||||||
// these pointers, causing the tip and the tail to point to the same entry.
|
|
||||||
func (c *OpenChannel) RemoteCommitChainTip() (*CommitDiff, error) {
|
|
||||||
var cd *CommitDiff
|
|
||||||
err := c.Db.View(func(tx *bbolt.Tx) error {
|
|
||||||
chanBucket, err := fetchChanBucket(
|
|
||||||
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
|
|
||||||
)
|
|
||||||
switch err {
|
|
||||||
case nil:
|
|
||||||
case ErrNoChanDBExists, ErrNoActiveChannels, ErrChannelNotFound:
|
|
||||||
return ErrNoPendingCommit
|
|
||||||
default:
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
tipBytes := chanBucket.Get(commitDiffKey)
|
|
||||||
if tipBytes == nil {
|
|
||||||
return ErrNoPendingCommit
|
|
||||||
}
|
|
||||||
|
|
||||||
tipReader := bytes.NewReader(tipBytes)
|
|
||||||
dcd, err := deserializeCommitDiff(tipReader)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
cd = dcd
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return cd, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// InsertNextRevocation inserts the _next_ commitment point (revocation) into
|
|
||||||
// the database, and also modifies the internal RemoteNextRevocation attribute
|
|
||||||
// to point to the passed key. This method is to be using during final channel
|
|
||||||
// set up, _after_ the channel has been fully confirmed.
|
|
||||||
//
|
|
||||||
// NOTE: If this method isn't called, then the target channel won't be able to
|
|
||||||
// propose new states for the commitment state of the remote party.
|
|
||||||
func (c *OpenChannel) InsertNextRevocation(revKey *btcec.PublicKey) error {
|
|
||||||
c.Lock()
|
|
||||||
defer c.Unlock()
|
|
||||||
|
|
||||||
c.RemoteNextRevocation = revKey
|
|
||||||
|
|
||||||
err := c.Db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
chanBucket, err := fetchChanBucket(
|
|
||||||
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return putChanRevocationState(chanBucket, c)
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AdvanceCommitChainTail records the new state transition within an on-disk
|
|
||||||
// append-only log which records all state transitions by the remote peer. In
|
|
||||||
// the case of an uncooperative broadcast of a prior state by the remote peer,
|
|
||||||
// this log can be consulted in order to reconstruct the state needed to
|
|
||||||
// rectify the situation. This method will add the current commitment for the
|
|
||||||
// remote party to the revocation log, and promote the current pending
|
|
||||||
// commitment to the current remote commitment.
|
|
||||||
func (c *OpenChannel) AdvanceCommitChainTail(fwdPkg *FwdPkg) error {
|
|
||||||
c.Lock()
|
|
||||||
defer c.Unlock()
|
|
||||||
|
|
||||||
// If this is a restored channel, then we want to avoid mutating the
|
|
||||||
// state at all, as it's impossible to do so in a protocol compliant
|
|
||||||
// manner.
|
|
||||||
if c.hasChanStatus(ChanStatusRestored) {
|
|
||||||
return ErrNoRestoredChannelMutation
|
|
||||||
}
|
|
||||||
|
|
||||||
var newRemoteCommit *ChannelCommitment
|
|
||||||
|
|
||||||
err := c.Db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
chanBucket, err := fetchChanBucket(
|
|
||||||
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the channel is marked as borked, then for safety reasons,
|
|
||||||
// we shouldn't attempt any further updates.
|
|
||||||
isBorked, err := c.isBorked(chanBucket)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if isBorked {
|
|
||||||
return ErrChanBorked
|
|
||||||
}
|
|
||||||
|
|
||||||
// Persist the latest preimage state to disk as the remote peer
|
|
||||||
// has just added to our local preimage store, and given us a
|
|
||||||
// new pending revocation key.
|
|
||||||
if err := putChanRevocationState(chanBucket, c); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// With the current preimage producer/store state updated,
|
|
||||||
// append a new log entry recording this the delta of this
|
|
||||||
// state transition.
|
|
||||||
//
|
|
||||||
// TODO(roasbeef): could make the deltas relative, would save
|
|
||||||
// space, but then tradeoff for more disk-seeks to recover the
|
|
||||||
// full state.
|
|
||||||
logKey := revocationLogBucket
|
|
||||||
logBucket, err := chanBucket.CreateBucketIfNotExists(logKey)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Before we append this revoked state to the revocation log,
|
|
||||||
// we'll swap out what's currently the tail of the commit tip,
|
|
||||||
// with the current locked-in commitment for the remote party.
|
|
||||||
tipBytes := chanBucket.Get(commitDiffKey)
|
|
||||||
tipReader := bytes.NewReader(tipBytes)
|
|
||||||
newCommit, err := deserializeCommitDiff(tipReader)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = putChanCommitment(
|
|
||||||
chanBucket, &newCommit.Commitment, false,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := chanBucket.Delete(commitDiffKey); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// With the commitment pointer swapped, we can now add the
|
|
||||||
// revoked (prior) state to the revocation log.
|
|
||||||
//
|
|
||||||
// TODO(roasbeef): store less
|
|
||||||
err = appendChannelLogEntry(logBucket, &c.RemoteCommitment)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Lastly, we write the forwarding package to disk so that we
|
|
||||||
// can properly recover from failures and reforward HTLCs that
|
|
||||||
// have not received a corresponding settle/fail.
|
|
||||||
if err := c.Packager.AddFwdPkg(tx, fwdPkg); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
newRemoteCommit = &newCommit.Commitment
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// With the db transaction complete, we'll swap over the in-memory
|
|
||||||
// pointer of the new remote commitment, which was previously the tip
|
|
||||||
// of the commit chain.
|
|
||||||
c.RemoteCommitment = *newRemoteCommit
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// NextLocalHtlcIndex returns the next unallocated local htlc index. To ensure
|
|
||||||
// this always returns the next index that has been not been allocated, this
|
|
||||||
// will first try to examine any pending commitments, before falling back to the
|
|
||||||
// last locked-in local commitment.
|
|
||||||
func (c *OpenChannel) NextLocalHtlcIndex() (uint64, error) {
|
|
||||||
// First, load the most recent commit diff that we initiated for the
|
|
||||||
// remote party. If no pending commit is found, this is not treated as
|
|
||||||
// a critical error, since we can always fall back.
|
|
||||||
pendingRemoteCommit, err := c.RemoteCommitChainTip()
|
|
||||||
if err != nil && err != ErrNoPendingCommit {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// If a pending commit was found, its local htlc index will be at least
|
|
||||||
// as large as the one on our local commitment.
|
|
||||||
if pendingRemoteCommit != nil {
|
|
||||||
return pendingRemoteCommit.Commitment.LocalHtlcIndex, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Otherwise, fallback to using the local htlc index of our commitment.
|
|
||||||
return c.LocalCommitment.LocalHtlcIndex, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// LoadFwdPkgs scans the forwarding log for any packages that haven't been
|
|
||||||
// processed, and returns their deserialized log updates in map indexed by the
|
|
||||||
// remote commitment height at which the updates were locked in.
|
|
||||||
func (c *OpenChannel) LoadFwdPkgs() ([]*FwdPkg, error) {
|
|
||||||
c.RLock()
|
|
||||||
defer c.RUnlock()
|
|
||||||
|
|
||||||
var fwdPkgs []*FwdPkg
|
|
||||||
if err := c.Db.View(func(tx *bbolt.Tx) error {
|
|
||||||
var err error
|
|
||||||
fwdPkgs, err = c.Packager.LoadFwdPkgs(tx)
|
|
||||||
return err
|
|
||||||
}); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return fwdPkgs, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AckAddHtlcs updates the AckAddFilter containing any of the provided AddRefs
|
|
||||||
// indicating that a response to this Add has been committed to the remote party.
|
|
||||||
// Doing so will prevent these Add HTLCs from being reforwarded internally.
|
|
||||||
func (c *OpenChannel) AckAddHtlcs(addRefs ...AddRef) error {
|
|
||||||
c.Lock()
|
|
||||||
defer c.Unlock()
|
|
||||||
|
|
||||||
return c.Db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
return c.Packager.AckAddHtlcs(tx, addRefs...)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// AckSettleFails updates the SettleFailFilter containing any of the provided
|
|
||||||
// SettleFailRefs, indicating that the response has been delivered to the
|
|
||||||
// incoming link, corresponding to a particular AddRef. Doing so will prevent
|
|
||||||
// the responses from being retransmitted internally.
|
|
||||||
func (c *OpenChannel) AckSettleFails(settleFailRefs ...SettleFailRef) error {
|
|
||||||
c.Lock()
|
|
||||||
defer c.Unlock()
|
|
||||||
|
|
||||||
return c.Db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
return c.Packager.AckSettleFails(tx, settleFailRefs...)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetFwdFilter atomically sets the forwarding filter for the forwarding package
|
|
||||||
// identified by `height`.
|
|
||||||
func (c *OpenChannel) SetFwdFilter(height uint64, fwdFilter *PkgFilter) error {
|
|
||||||
c.Lock()
|
|
||||||
defer c.Unlock()
|
|
||||||
|
|
||||||
return c.Db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
return c.Packager.SetFwdFilter(tx, height, fwdFilter)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveFwdPkg atomically removes a forwarding package specified by the remote
|
|
||||||
// commitment height.
|
|
||||||
//
|
|
||||||
// NOTE: This method should only be called on packages marked FwdStateCompleted.
|
|
||||||
func (c *OpenChannel) RemoveFwdPkg(height uint64) error {
|
|
||||||
c.Lock()
|
|
||||||
defer c.Unlock()
|
|
||||||
|
|
||||||
return c.Db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
return c.Packager.RemovePkg(tx, height)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// RevocationLogTail returns the "tail", or the end of the current revocation
|
|
||||||
// log. This entry represents the last previous state for the remote node's
|
|
||||||
// commitment chain. The ChannelDelta returned by this method will always lag
|
|
||||||
// one state behind the most current (unrevoked) state of the remote node's
|
|
||||||
// commitment chain.
|
|
||||||
func (c *OpenChannel) RevocationLogTail() (*ChannelCommitment, error) {
|
|
||||||
c.RLock()
|
|
||||||
defer c.RUnlock()
|
|
||||||
|
|
||||||
// If we haven't created any state updates yet, then we'll exit early as
|
|
||||||
// there's nothing to be found on disk in the revocation bucket.
|
|
||||||
if c.RemoteCommitment.CommitHeight == 0 {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var commit ChannelCommitment
|
|
||||||
if err := c.Db.View(func(tx *bbolt.Tx) error {
|
|
||||||
chanBucket, err := fetchChanBucket(
|
|
||||||
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
logBucket := chanBucket.Bucket(revocationLogBucket)
|
|
||||||
if logBucket == nil {
|
|
||||||
return ErrNoPastDeltas
|
|
||||||
}
|
|
||||||
|
|
||||||
// Once we have the bucket that stores the revocation log from
|
|
||||||
// this channel, we'll jump to the _last_ key in bucket. As we
|
|
||||||
// store the update number on disk in a big-endian format,
|
|
||||||
// this will retrieve the latest entry.
|
|
||||||
cursor := logBucket.Cursor()
|
|
||||||
_, tailLogEntry := cursor.Last()
|
|
||||||
logEntryReader := bytes.NewReader(tailLogEntry)
|
|
||||||
|
|
||||||
// Once we have the entry, we'll decode it into the channel
|
|
||||||
// delta pointer we created above.
|
|
||||||
var dbErr error
|
|
||||||
commit, dbErr = deserializeChanCommit(logEntryReader)
|
|
||||||
if dbErr != nil {
|
|
||||||
return dbErr
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &commit, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CommitmentHeight returns the current commitment height. The commitment
|
|
||||||
// height represents the number of updates to the commitment state to date.
|
|
||||||
// This value is always monotonically increasing. This method is provided in
|
|
||||||
// order to allow multiple instances of a particular open channel to obtain a
|
|
||||||
// consistent view of the number of channel updates to date.
|
|
||||||
func (c *OpenChannel) CommitmentHeight() (uint64, error) {
|
|
||||||
c.RLock()
|
|
||||||
defer c.RUnlock()
|
|
||||||
|
|
||||||
var height uint64
|
|
||||||
err := c.Db.View(func(tx *bbolt.Tx) error {
|
|
||||||
// Get the bucket dedicated to storing the metadata for open
|
|
||||||
// channels.
|
|
||||||
chanBucket, err := fetchChanBucket(
|
|
||||||
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
commit, err := fetchChanCommitment(chanBucket, true)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
height = commit.CommitHeight
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return height, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// FindPreviousState scans through the append-only log in an attempt to recover
|
|
||||||
// the previous channel state indicated by the update number. This method is
|
|
||||||
// intended to be used for obtaining the relevant data needed to claim all
|
|
||||||
// funds rightfully spendable in the case of an on-chain broadcast of the
|
|
||||||
// commitment transaction.
|
|
||||||
func (c *OpenChannel) FindPreviousState(updateNum uint64) (*ChannelCommitment, error) {
|
|
||||||
c.RLock()
|
|
||||||
defer c.RUnlock()
|
|
||||||
|
|
||||||
var commit ChannelCommitment
|
|
||||||
err := c.Db.View(func(tx *bbolt.Tx) error {
|
|
||||||
chanBucket, err := fetchChanBucket(
|
|
||||||
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
logBucket := chanBucket.Bucket(revocationLogBucket)
|
|
||||||
if logBucket == nil {
|
|
||||||
return ErrNoPastDeltas
|
|
||||||
}
|
|
||||||
|
|
||||||
c, err := fetchChannelLogEntry(logBucket, updateNum)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
commit = c
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &commit, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClosureType is an enum like structure that details exactly _how_ a channel
|
// ClosureType is an enum like structure that details exactly _how_ a channel
|
||||||
// was closed. Three closure types are currently possible: none, cooperative,
|
// was closed. Three closure types are currently possible: none, cooperative,
|
||||||
// local force close, remote force close, and (remote) breach.
|
// local force close, remote force close, and (remote) breach.
|
||||||
type ClosureType uint8
|
type ClosureType uint8
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// CooperativeClose indicates that a channel has been closed
|
|
||||||
// cooperatively. This means that both channel peers were online and
|
|
||||||
// signed a new transaction paying out the settled balance of the
|
|
||||||
// contract.
|
|
||||||
CooperativeClose ClosureType = 0
|
|
||||||
|
|
||||||
// LocalForceClose indicates that we have unilaterally broadcast our
|
|
||||||
// current commitment state on-chain.
|
|
||||||
LocalForceClose ClosureType = 1
|
|
||||||
|
|
||||||
// RemoteForceClose indicates that the remote peer has unilaterally
|
// RemoteForceClose indicates that the remote peer has unilaterally
|
||||||
// broadcast their current commitment state on-chain.
|
// broadcast their current commitment state on-chain.
|
||||||
RemoteForceClose ClosureType = 4
|
RemoteForceClose ClosureType = 4
|
||||||
|
|
||||||
// BreachClose indicates that the remote peer attempted to broadcast a
|
|
||||||
// prior _revoked_ channel state.
|
|
||||||
BreachClose ClosureType = 2
|
|
||||||
|
|
||||||
// FundingCanceled indicates that the channel never was fully opened
|
|
||||||
// before it was marked as closed in the database. This can happen if
|
|
||||||
// we or the remote fail at some point during the opening workflow, or
|
|
||||||
// we timeout waiting for the funding transaction to be confirmed.
|
|
||||||
FundingCanceled ClosureType = 3
|
|
||||||
|
|
||||||
// Abandoned indicates that the channel state was removed without
|
|
||||||
// any further actions. This is intended to clean up unusable
|
|
||||||
// channels during development.
|
|
||||||
Abandoned ClosureType = 5
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ChannelCloseSummary contains the final state of a channel at the point it
|
// ChannelCloseSummary contains the final state of a channel at the point it
|
||||||
@ -2160,214 +591,6 @@ type ChannelCloseSummary struct {
|
|||||||
LastChanSyncMsg *lnwire.ChannelReestablish
|
LastChanSyncMsg *lnwire.ChannelReestablish
|
||||||
}
|
}
|
||||||
|
|
||||||
// CloseChannel closes a previously active Lightning channel. Closing a channel
|
|
||||||
// entails deleting all saved state within the database concerning this
|
|
||||||
// channel. This method also takes a struct that summarizes the state of the
|
|
||||||
// channel at closing, this compact representation will be the only component
|
|
||||||
// of a channel left over after a full closing.
|
|
||||||
func (c *OpenChannel) CloseChannel(summary *ChannelCloseSummary) error {
|
|
||||||
c.Lock()
|
|
||||||
defer c.Unlock()
|
|
||||||
|
|
||||||
return c.Db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
openChanBucket := tx.Bucket(openChannelBucket)
|
|
||||||
if openChanBucket == nil {
|
|
||||||
return ErrNoChanDBExists
|
|
||||||
}
|
|
||||||
|
|
||||||
nodePub := c.IdentityPub.SerializeCompressed()
|
|
||||||
nodeChanBucket := openChanBucket.Bucket(nodePub)
|
|
||||||
if nodeChanBucket == nil {
|
|
||||||
return ErrNoActiveChannels
|
|
||||||
}
|
|
||||||
|
|
||||||
chainBucket := nodeChanBucket.Bucket(c.ChainHash[:])
|
|
||||||
if chainBucket == nil {
|
|
||||||
return ErrNoActiveChannels
|
|
||||||
}
|
|
||||||
|
|
||||||
var chanPointBuf bytes.Buffer
|
|
||||||
err := writeOutpoint(&chanPointBuf, &c.FundingOutpoint)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
chanBucket := chainBucket.Bucket(chanPointBuf.Bytes())
|
|
||||||
if chanBucket == nil {
|
|
||||||
return ErrNoActiveChannels
|
|
||||||
}
|
|
||||||
|
|
||||||
// Before we delete the channel state, we'll read out the full
|
|
||||||
// details, as we'll also store portions of this information
|
|
||||||
// for record keeping.
|
|
||||||
chanState, err := fetchOpenChannel(
|
|
||||||
chanBucket, &c.FundingOutpoint,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now that the index to this channel has been deleted, purge
|
|
||||||
// the remaining channel metadata from the database.
|
|
||||||
err = deleteOpenChannel(chanBucket, chanPointBuf.Bytes())
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// With the base channel data deleted, attempt to delete the
|
|
||||||
// information stored within the revocation log.
|
|
||||||
logBucket := chanBucket.Bucket(revocationLogBucket)
|
|
||||||
if logBucket != nil {
|
|
||||||
err = chanBucket.DeleteBucket(revocationLogBucket)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = chainBucket.DeleteBucket(chanPointBuf.Bytes())
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Finally, create a summary of this channel in the closed
|
|
||||||
// channel bucket for this node.
|
|
||||||
return putChannelCloseSummary(
|
|
||||||
tx, chanPointBuf.Bytes(), summary, chanState,
|
|
||||||
)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChannelSnapshot is a frozen snapshot of the current channel state. A
|
|
||||||
// snapshot is detached from the original channel that generated it, providing
|
|
||||||
// read-only access to the current or prior state of an active channel.
|
|
||||||
//
|
|
||||||
// TODO(roasbeef): remove all together? pretty much just commitment
|
|
||||||
type ChannelSnapshot struct {
|
|
||||||
// RemoteIdentity is the identity public key of the remote node that we
|
|
||||||
// are maintaining the open channel with.
|
|
||||||
RemoteIdentity btcec.PublicKey
|
|
||||||
|
|
||||||
// ChanPoint is the outpoint that created the channel. This output is
|
|
||||||
// found within the funding transaction and uniquely identified the
|
|
||||||
// channel on the resident chain.
|
|
||||||
ChannelPoint wire.OutPoint
|
|
||||||
|
|
||||||
// ChainHash is the genesis hash of the chain that the channel resides
|
|
||||||
// within.
|
|
||||||
ChainHash chainhash.Hash
|
|
||||||
|
|
||||||
// Capacity is the total capacity of the channel.
|
|
||||||
Capacity btcutil.Amount
|
|
||||||
|
|
||||||
// TotalMSatSent is the total number of milli-satoshis we've sent
|
|
||||||
// within this channel.
|
|
||||||
TotalMSatSent lnwire.MilliSatoshi
|
|
||||||
|
|
||||||
// TotalMSatReceived is the total number of milli-satoshis we've
|
|
||||||
// received within this channel.
|
|
||||||
TotalMSatReceived lnwire.MilliSatoshi
|
|
||||||
|
|
||||||
// ChannelCommitment is the current up-to-date commitment for the
|
|
||||||
// target channel.
|
|
||||||
ChannelCommitment
|
|
||||||
}
|
|
||||||
|
|
||||||
// Snapshot returns a read-only snapshot of the current channel state. This
|
|
||||||
// snapshot includes information concerning the current settled balance within
|
|
||||||
// the channel, metadata detailing total flows, and any outstanding HTLCs.
|
|
||||||
func (c *OpenChannel) Snapshot() *ChannelSnapshot {
|
|
||||||
c.RLock()
|
|
||||||
defer c.RUnlock()
|
|
||||||
|
|
||||||
localCommit := c.LocalCommitment
|
|
||||||
snapshot := &ChannelSnapshot{
|
|
||||||
RemoteIdentity: *c.IdentityPub,
|
|
||||||
ChannelPoint: c.FundingOutpoint,
|
|
||||||
Capacity: c.Capacity,
|
|
||||||
TotalMSatSent: c.TotalMSatSent,
|
|
||||||
TotalMSatReceived: c.TotalMSatReceived,
|
|
||||||
ChainHash: c.ChainHash,
|
|
||||||
ChannelCommitment: ChannelCommitment{
|
|
||||||
LocalBalance: localCommit.LocalBalance,
|
|
||||||
RemoteBalance: localCommit.RemoteBalance,
|
|
||||||
CommitHeight: localCommit.CommitHeight,
|
|
||||||
CommitFee: localCommit.CommitFee,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Copy over the current set of HTLCs to ensure the caller can't mutate
|
|
||||||
// our internal state.
|
|
||||||
snapshot.Htlcs = make([]HTLC, len(localCommit.Htlcs))
|
|
||||||
for i, h := range localCommit.Htlcs {
|
|
||||||
snapshot.Htlcs[i] = h.Copy()
|
|
||||||
}
|
|
||||||
|
|
||||||
return snapshot
|
|
||||||
}
|
|
||||||
|
|
||||||
// LatestCommitments returns the two latest commitments for both the local and
|
|
||||||
// remote party. These commitments are read from disk to ensure that only the
|
|
||||||
// latest fully committed state is returned. The first commitment returned is
|
|
||||||
// the local commitment, and the second returned is the remote commitment.
|
|
||||||
func (c *OpenChannel) LatestCommitments() (*ChannelCommitment, *ChannelCommitment, error) {
|
|
||||||
err := c.Db.View(func(tx *bbolt.Tx) error {
|
|
||||||
chanBucket, err := fetchChanBucket(
|
|
||||||
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return fetchChanCommitments(chanBucket, c)
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &c.LocalCommitment, &c.RemoteCommitment, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoteRevocationStore returns the most up to date commitment version of the
|
|
||||||
// revocation storage tree for the remote party. This method can be used when
|
|
||||||
// acting on a possible contract breach to ensure, that the caller has the most
|
|
||||||
// up to date information required to deliver justice.
|
|
||||||
func (c *OpenChannel) RemoteRevocationStore() (shachain.Store, error) {
|
|
||||||
err := c.Db.View(func(tx *bbolt.Tx) error {
|
|
||||||
chanBucket, err := fetchChanBucket(
|
|
||||||
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return fetchChanRevocationState(chanBucket, c)
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return c.RevocationStore, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func putChannelCloseSummary(tx *bbolt.Tx, chanID []byte,
|
|
||||||
summary *ChannelCloseSummary, lastChanState *OpenChannel) error {
|
|
||||||
|
|
||||||
closedChanBucket, err := tx.CreateBucketIfNotExists(closedChannelBucket)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
summary.RemoteCurrentRevocation = lastChanState.RemoteCurrentRevocation
|
|
||||||
summary.RemoteNextRevocation = lastChanState.RemoteNextRevocation
|
|
||||||
summary.LocalChanConfig = lastChanState.LocalChanCfg
|
|
||||||
|
|
||||||
var b bytes.Buffer
|
|
||||||
if err := serializeChannelCloseSummary(&b, summary); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return closedChanBucket.Put(chanID, b.Bytes())
|
|
||||||
}
|
|
||||||
|
|
||||||
func serializeChannelCloseSummary(w io.Writer, cs *ChannelCloseSummary) error {
|
func serializeChannelCloseSummary(w io.Writer, cs *ChannelCloseSummary) error {
|
||||||
err := WriteElements(w,
|
err := WriteElements(w,
|
||||||
cs.ChanPoint, cs.ShortChanID, cs.ChainHash, cs.ClosingTXID,
|
cs.ChanPoint, cs.ShortChanID, cs.ChainHash, cs.ClosingTXID,
|
||||||
@ -2517,113 +740,6 @@ func writeChanConfig(b io.Writer, c *ChannelConfig) error {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func putChanInfo(chanBucket *bbolt.Bucket, channel *OpenChannel) error {
|
|
||||||
var w bytes.Buffer
|
|
||||||
if err := WriteElements(&w,
|
|
||||||
channel.ChanType, channel.ChainHash, channel.FundingOutpoint,
|
|
||||||
channel.ShortChannelID, channel.IsPending, channel.IsInitiator,
|
|
||||||
channel.chanStatus, channel.FundingBroadcastHeight,
|
|
||||||
channel.NumConfsRequired, channel.ChannelFlags,
|
|
||||||
channel.IdentityPub, channel.Capacity, channel.TotalMSatSent,
|
|
||||||
channel.TotalMSatReceived,
|
|
||||||
); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// For single funder channels that we initiated, write the funding txn.
|
|
||||||
if channel.ChanType.IsSingleFunder() && channel.IsInitiator &&
|
|
||||||
!channel.hasChanStatus(ChanStatusRestored) {
|
|
||||||
|
|
||||||
if err := WriteElement(&w, channel.FundingTxn); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := writeChanConfig(&w, &channel.LocalChanCfg); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := writeChanConfig(&w, &channel.RemoteChanCfg); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return chanBucket.Put(chanInfoKey, w.Bytes())
|
|
||||||
}
|
|
||||||
|
|
||||||
func serializeChanCommit(w io.Writer, c *ChannelCommitment) error {
|
|
||||||
if err := WriteElements(w,
|
|
||||||
c.CommitHeight, c.LocalLogIndex, c.LocalHtlcIndex,
|
|
||||||
c.RemoteLogIndex, c.RemoteHtlcIndex, c.LocalBalance,
|
|
||||||
c.RemoteBalance, c.CommitFee, c.FeePerKw, c.CommitTx,
|
|
||||||
c.CommitSig,
|
|
||||||
); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return SerializeHtlcs(w, c.Htlcs...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func putChanCommitment(chanBucket *bbolt.Bucket, c *ChannelCommitment,
|
|
||||||
local bool) error {
|
|
||||||
|
|
||||||
var commitKey []byte
|
|
||||||
if local {
|
|
||||||
commitKey = append(chanCommitmentKey, byte(0x00))
|
|
||||||
} else {
|
|
||||||
commitKey = append(chanCommitmentKey, byte(0x01))
|
|
||||||
}
|
|
||||||
|
|
||||||
var b bytes.Buffer
|
|
||||||
if err := serializeChanCommit(&b, c); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return chanBucket.Put(commitKey, b.Bytes())
|
|
||||||
}
|
|
||||||
|
|
||||||
func putChanCommitments(chanBucket *bbolt.Bucket, channel *OpenChannel) error {
|
|
||||||
// If this is a restored channel, then we don't have any commitments to
|
|
||||||
// write.
|
|
||||||
if channel.hasChanStatus(ChanStatusRestored) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
err := putChanCommitment(
|
|
||||||
chanBucket, &channel.LocalCommitment, true,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return putChanCommitment(
|
|
||||||
chanBucket, &channel.RemoteCommitment, false,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
func putChanRevocationState(chanBucket *bbolt.Bucket, channel *OpenChannel) error {
|
|
||||||
|
|
||||||
var b bytes.Buffer
|
|
||||||
err := WriteElements(
|
|
||||||
&b, channel.RemoteCurrentRevocation, channel.RevocationProducer,
|
|
||||||
channel.RevocationStore,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(roasbeef): don't keep producer on disk
|
|
||||||
|
|
||||||
// If the next revocation is present, which is only the case after the
|
|
||||||
// FundingLocked message has been sent, then we'll write it to disk.
|
|
||||||
if channel.RemoteNextRevocation != nil {
|
|
||||||
err = WriteElements(&b, channel.RemoteNextRevocation)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return chanBucket.Put(revocationStateKey, b.Bytes())
|
|
||||||
}
|
|
||||||
|
|
||||||
func readChanConfig(b io.Reader, c *ChannelConfig) error {
|
func readChanConfig(b io.Reader, c *ChannelConfig) error {
|
||||||
return ReadElements(b,
|
return ReadElements(b,
|
||||||
&c.DustLimit, &c.MaxPendingAmount, &c.ChanReserve,
|
&c.DustLimit, &c.MaxPendingAmount, &c.ChanReserve,
|
||||||
@ -2633,185 +749,3 @@ func readChanConfig(b io.Reader, c *ChannelConfig) error {
|
|||||||
&c.HtlcBasePoint,
|
&c.HtlcBasePoint,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func fetchChanInfo(chanBucket *bbolt.Bucket, channel *OpenChannel) error {
|
|
||||||
infoBytes := chanBucket.Get(chanInfoKey)
|
|
||||||
if infoBytes == nil {
|
|
||||||
return ErrNoChanInfoFound
|
|
||||||
}
|
|
||||||
r := bytes.NewReader(infoBytes)
|
|
||||||
|
|
||||||
if err := ReadElements(r,
|
|
||||||
&channel.ChanType, &channel.ChainHash, &channel.FundingOutpoint,
|
|
||||||
&channel.ShortChannelID, &channel.IsPending, &channel.IsInitiator,
|
|
||||||
&channel.chanStatus, &channel.FundingBroadcastHeight,
|
|
||||||
&channel.NumConfsRequired, &channel.ChannelFlags,
|
|
||||||
&channel.IdentityPub, &channel.Capacity, &channel.TotalMSatSent,
|
|
||||||
&channel.TotalMSatReceived,
|
|
||||||
); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// For single funder channels that we initiated, read the funding txn.
|
|
||||||
if channel.ChanType.IsSingleFunder() && channel.IsInitiator &&
|
|
||||||
!channel.hasChanStatus(ChanStatusRestored) {
|
|
||||||
|
|
||||||
if err := ReadElement(r, &channel.FundingTxn); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := readChanConfig(r, &channel.LocalChanCfg); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := readChanConfig(r, &channel.RemoteChanCfg); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
channel.Packager = NewChannelPackager(channel.ShortChannelID)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func deserializeChanCommit(r io.Reader) (ChannelCommitment, error) {
|
|
||||||
var c ChannelCommitment
|
|
||||||
|
|
||||||
err := ReadElements(r,
|
|
||||||
&c.CommitHeight, &c.LocalLogIndex, &c.LocalHtlcIndex, &c.RemoteLogIndex,
|
|
||||||
&c.RemoteHtlcIndex, &c.LocalBalance, &c.RemoteBalance,
|
|
||||||
&c.CommitFee, &c.FeePerKw, &c.CommitTx, &c.CommitSig,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return c, err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Htlcs, err = DeserializeHtlcs(r)
|
|
||||||
if err != nil {
|
|
||||||
return c, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return c, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func fetchChanCommitment(chanBucket *bbolt.Bucket, local bool) (ChannelCommitment, error) {
|
|
||||||
var commitKey []byte
|
|
||||||
if local {
|
|
||||||
commitKey = append(chanCommitmentKey, byte(0x00))
|
|
||||||
} else {
|
|
||||||
commitKey = append(chanCommitmentKey, byte(0x01))
|
|
||||||
}
|
|
||||||
|
|
||||||
commitBytes := chanBucket.Get(commitKey)
|
|
||||||
if commitBytes == nil {
|
|
||||||
return ChannelCommitment{}, ErrNoCommitmentsFound
|
|
||||||
}
|
|
||||||
|
|
||||||
r := bytes.NewReader(commitBytes)
|
|
||||||
return deserializeChanCommit(r)
|
|
||||||
}
|
|
||||||
|
|
||||||
func fetchChanCommitments(chanBucket *bbolt.Bucket, channel *OpenChannel) error {
|
|
||||||
var err error
|
|
||||||
|
|
||||||
// If this is a restored channel, then we don't have any commitments to
|
|
||||||
// read.
|
|
||||||
if channel.hasChanStatus(ChanStatusRestored) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
channel.LocalCommitment, err = fetchChanCommitment(chanBucket, true)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
channel.RemoteCommitment, err = fetchChanCommitment(chanBucket, false)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func fetchChanRevocationState(chanBucket *bbolt.Bucket, channel *OpenChannel) error {
|
|
||||||
revBytes := chanBucket.Get(revocationStateKey)
|
|
||||||
if revBytes == nil {
|
|
||||||
return ErrNoRevocationsFound
|
|
||||||
}
|
|
||||||
r := bytes.NewReader(revBytes)
|
|
||||||
|
|
||||||
err := ReadElements(
|
|
||||||
r, &channel.RemoteCurrentRevocation, &channel.RevocationProducer,
|
|
||||||
&channel.RevocationStore,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// If there aren't any bytes left in the buffer, then we don't yet have
|
|
||||||
// the next remote revocation, so we can exit early here.
|
|
||||||
if r.Len() == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Otherwise we'll read the next revocation for the remote party which
|
|
||||||
// is always the last item within the buffer.
|
|
||||||
return ReadElements(r, &channel.RemoteNextRevocation)
|
|
||||||
}
|
|
||||||
|
|
||||||
func deleteOpenChannel(chanBucket *bbolt.Bucket, chanPointBytes []byte) error {
|
|
||||||
|
|
||||||
if err := chanBucket.Delete(chanInfoKey); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err := chanBucket.Delete(append(chanCommitmentKey, byte(0x00)))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = chanBucket.Delete(append(chanCommitmentKey, byte(0x01)))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := chanBucket.Delete(revocationStateKey); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if diff := chanBucket.Get(commitDiffKey); diff != nil {
|
|
||||||
return chanBucket.Delete(commitDiffKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// makeLogKey converts a uint64 into an 8 byte array.
|
|
||||||
func makeLogKey(updateNum uint64) [8]byte {
|
|
||||||
var key [8]byte
|
|
||||||
byteOrder.PutUint64(key[:], updateNum)
|
|
||||||
return key
|
|
||||||
}
|
|
||||||
|
|
||||||
func appendChannelLogEntry(log *bbolt.Bucket,
|
|
||||||
commit *ChannelCommitment) error {
|
|
||||||
|
|
||||||
var b bytes.Buffer
|
|
||||||
if err := serializeChanCommit(&b, commit); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
logEntrykey := makeLogKey(commit.CommitHeight)
|
|
||||||
return log.Put(logEntrykey[:], b.Bytes())
|
|
||||||
}
|
|
||||||
|
|
||||||
func fetchChannelLogEntry(log *bbolt.Bucket,
|
|
||||||
updateNum uint64) (ChannelCommitment, error) {
|
|
||||||
|
|
||||||
logEntrykey := makeLogKey(updateNum)
|
|
||||||
commitBytes := log.Get(logEntrykey[:])
|
|
||||||
if commitBytes == nil {
|
|
||||||
return ChannelCommitment{}, fmt.Errorf("log entry not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
commitReader := bytes.NewReader(commitBytes)
|
|
||||||
return deserializeChanCommit(commitReader)
|
|
||||||
}
|
|
||||||
|
@ -1,50 +0,0 @@
|
|||||||
package migration_01_to_11
|
|
||||||
|
|
||||||
// channelCache is an in-memory cache used to improve the performance of
|
|
||||||
// ChanUpdatesInHorizon. It caches the chan info and edge policies for a
|
|
||||||
// particular channel.
|
|
||||||
type channelCache struct {
|
|
||||||
n int
|
|
||||||
channels map[uint64]ChannelEdge
|
|
||||||
}
|
|
||||||
|
|
||||||
// newChannelCache creates a new channelCache with maximum capacity of n
|
|
||||||
// channels.
|
|
||||||
func newChannelCache(n int) *channelCache {
|
|
||||||
return &channelCache{
|
|
||||||
n: n,
|
|
||||||
channels: make(map[uint64]ChannelEdge),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// get returns the channel from the cache, if it exists.
|
|
||||||
func (c *channelCache) get(chanid uint64) (ChannelEdge, bool) {
|
|
||||||
channel, ok := c.channels[chanid]
|
|
||||||
return channel, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
// insert adds the entry to the channel cache. If an entry for chanid already
|
|
||||||
// exists, it will be replaced with the new entry. If the entry doesn't exist,
|
|
||||||
// it will be inserted to the cache, performing a random eviction if the cache
|
|
||||||
// is at capacity.
|
|
||||||
func (c *channelCache) insert(chanid uint64, channel ChannelEdge) {
|
|
||||||
// If entry exists, replace it.
|
|
||||||
if _, ok := c.channels[chanid]; ok {
|
|
||||||
c.channels[chanid] = channel
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Otherwise, evict an entry at random and insert.
|
|
||||||
if len(c.channels) == c.n {
|
|
||||||
for id := range c.channels {
|
|
||||||
delete(c.channels, id)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
c.channels[chanid] = channel
|
|
||||||
}
|
|
||||||
|
|
||||||
// remove deletes an edge for chanid from the cache, if it exists.
|
|
||||||
func (c *channelCache) remove(chanid uint64) {
|
|
||||||
delete(c.channels, chanid)
|
|
||||||
}
|
|
@ -1,105 +0,0 @@
|
|||||||
package migration_01_to_11
|
|
||||||
|
|
||||||
import (
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TestChannelCache checks the behavior of the channelCache with respect to
|
|
||||||
// insertion, eviction, and removal of cache entries.
|
|
||||||
func TestChannelCache(t *testing.T) {
|
|
||||||
const cacheSize = 100
|
|
||||||
|
|
||||||
// Create a new channel cache with the configured max size.
|
|
||||||
c := newChannelCache(cacheSize)
|
|
||||||
|
|
||||||
// As a sanity check, assert that querying the empty cache does not
|
|
||||||
// return an entry.
|
|
||||||
_, ok := c.get(0)
|
|
||||||
if ok {
|
|
||||||
t.Fatalf("channel cache should be empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now, fill up the cache entirely.
|
|
||||||
for i := uint64(0); i < cacheSize; i++ {
|
|
||||||
c.insert(i, channelForInt(i))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Assert that the cache has all of the entries just inserted, since no
|
|
||||||
// eviction should occur until we try to surpass the max size.
|
|
||||||
assertHasChanEntries(t, c, 0, cacheSize)
|
|
||||||
|
|
||||||
// Now, insert a new element that causes the cache to evict an element.
|
|
||||||
c.insert(cacheSize, channelForInt(cacheSize))
|
|
||||||
|
|
||||||
// Assert that the cache has this last entry, as the cache should evict
|
|
||||||
// some prior element and not the newly inserted one.
|
|
||||||
assertHasChanEntries(t, c, cacheSize, cacheSize)
|
|
||||||
|
|
||||||
// Iterate over all inserted elements and construct a set of the evicted
|
|
||||||
// elements.
|
|
||||||
evicted := make(map[uint64]struct{})
|
|
||||||
for i := uint64(0); i < cacheSize+1; i++ {
|
|
||||||
_, ok := c.get(i)
|
|
||||||
if !ok {
|
|
||||||
evicted[i] = struct{}{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Assert that exactly one element has been evicted.
|
|
||||||
numEvicted := len(evicted)
|
|
||||||
if numEvicted != 1 {
|
|
||||||
t.Fatalf("expected one evicted entry, got: %d", numEvicted)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove the highest item which initially caused the eviction and
|
|
||||||
// reinsert the element that was evicted prior.
|
|
||||||
c.remove(cacheSize)
|
|
||||||
for i := range evicted {
|
|
||||||
c.insert(i, channelForInt(i))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Since the removal created an extra slot, the last insertion should
|
|
||||||
// not have caused an eviction and the entries for all channels in the
|
|
||||||
// original set that filled the cache should be present.
|
|
||||||
assertHasChanEntries(t, c, 0, cacheSize)
|
|
||||||
|
|
||||||
// Finally, reinsert the existing set back into the cache and test that
|
|
||||||
// the cache still has all the entries. If the randomized eviction were
|
|
||||||
// happening on inserts for existing cache items, we expect this to fail
|
|
||||||
// with high probability.
|
|
||||||
for i := uint64(0); i < cacheSize; i++ {
|
|
||||||
c.insert(i, channelForInt(i))
|
|
||||||
}
|
|
||||||
assertHasChanEntries(t, c, 0, cacheSize)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// assertHasEntries queries the edge cache for all channels in the range [start,
|
|
||||||
// end), asserting that they exist and their value matches the entry produced by
|
|
||||||
// entryForInt.
|
|
||||||
func assertHasChanEntries(t *testing.T, c *channelCache, start, end uint64) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
for i := start; i < end; i++ {
|
|
||||||
entry, ok := c.get(i)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("channel cache should contain chan %d", i)
|
|
||||||
}
|
|
||||||
|
|
||||||
expEntry := channelForInt(i)
|
|
||||||
if !reflect.DeepEqual(entry, expEntry) {
|
|
||||||
t.Fatalf("entry mismatch, want: %v, got: %v",
|
|
||||||
expEntry, entry)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// channelForInt generates a unique ChannelEdge given an integer.
|
|
||||||
func channelForInt(i uint64) ChannelEdge {
|
|
||||||
return ChannelEdge{
|
|
||||||
Info: &ChannelEdgeInfo{
|
|
||||||
ChannelID: i,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
@ -4,18 +4,13 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
|
||||||
"runtime"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/btcsuite/btcd/btcec"
|
"github.com/btcsuite/btcd/btcec"
|
||||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||||
"github.com/btcsuite/btcd/wire"
|
"github.com/btcsuite/btcd/wire"
|
||||||
"github.com/btcsuite/btcutil"
|
"github.com/btcsuite/btcutil"
|
||||||
_ "github.com/btcsuite/btcwallet/walletdb/bdb"
|
_ "github.com/btcsuite/btcwallet/walletdb/bdb"
|
||||||
"github.com/davecgh/go-spew/spew"
|
|
||||||
"github.com/lightningnetwork/lnd/keychain"
|
"github.com/lightningnetwork/lnd/keychain"
|
||||||
"github.com/lightningnetwork/lnd/lnwire"
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
"github.com/lightningnetwork/lnd/shachain"
|
"github.com/lightningnetwork/lnd/shachain"
|
||||||
@ -66,8 +61,6 @@ var (
|
|||||||
LockTime: 5,
|
LockTime: 5,
|
||||||
}
|
}
|
||||||
privKey, pubKey = btcec.PrivKeyFromBytes(btcec.S256(), key[:])
|
privKey, pubKey = btcec.PrivKeyFromBytes(btcec.S256(), key[:])
|
||||||
|
|
||||||
wireSig, _ = lnwire.NewSigFromSignature(testSig)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// makeTestDB creates a new instance of the ChannelDB for testing purposes. A
|
// makeTestDB creates a new instance of the ChannelDB for testing purposes. A
|
||||||
@ -223,819 +216,6 @@ func createTestChannelState(cdb *DB) (*OpenChannel, error) {
|
|||||||
RevocationProducer: producer,
|
RevocationProducer: producer,
|
||||||
RevocationStore: store,
|
RevocationStore: store,
|
||||||
Db: cdb,
|
Db: cdb,
|
||||||
Packager: NewChannelPackager(chanID),
|
|
||||||
FundingTxn: testTx,
|
FundingTxn: testTx,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOpenChannelPutGetDelete(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
cdb, cleanUp, err := makeTestDB()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test database: %v", err)
|
|
||||||
}
|
|
||||||
defer cleanUp()
|
|
||||||
|
|
||||||
// Create the test channel state, then add an additional fake HTLC
|
|
||||||
// before syncing to disk.
|
|
||||||
state, err := createTestChannelState(cdb)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create channel state: %v", err)
|
|
||||||
}
|
|
||||||
state.LocalCommitment.Htlcs = []HTLC{
|
|
||||||
{
|
|
||||||
Signature: testSig.Serialize(),
|
|
||||||
Incoming: true,
|
|
||||||
Amt: 10,
|
|
||||||
RHash: key,
|
|
||||||
RefundTimeout: 1,
|
|
||||||
OnionBlob: []byte("onionblob"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
state.RemoteCommitment.Htlcs = []HTLC{
|
|
||||||
{
|
|
||||||
Signature: testSig.Serialize(),
|
|
||||||
Incoming: false,
|
|
||||||
Amt: 10,
|
|
||||||
RHash: key,
|
|
||||||
RefundTimeout: 1,
|
|
||||||
OnionBlob: []byte("onionblob"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
addr := &net.TCPAddr{
|
|
||||||
IP: net.ParseIP("127.0.0.1"),
|
|
||||||
Port: 18556,
|
|
||||||
}
|
|
||||||
if err := state.SyncPending(addr, 101); err != nil {
|
|
||||||
t.Fatalf("unable to save and serialize channel state: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
openChannels, err := cdb.FetchOpenChannels(state.IdentityPub)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to fetch open channel: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
newState := openChannels[0]
|
|
||||||
|
|
||||||
// The decoded channel state should be identical to what we stored
|
|
||||||
// above.
|
|
||||||
if !reflect.DeepEqual(state, newState) {
|
|
||||||
t.Fatalf("channel state doesn't match:: %v vs %v",
|
|
||||||
spew.Sdump(state), spew.Sdump(newState))
|
|
||||||
}
|
|
||||||
|
|
||||||
// We'll also test that the channel is properly able to hot swap the
|
|
||||||
// next revocation for the state machine. This tests the initial
|
|
||||||
// post-funding revocation exchange.
|
|
||||||
nextRevKey, err := btcec.NewPrivateKey(btcec.S256())
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create new private key: %v", err)
|
|
||||||
}
|
|
||||||
if err := state.InsertNextRevocation(nextRevKey.PubKey()); err != nil {
|
|
||||||
t.Fatalf("unable to update revocation: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
openChannels, err = cdb.FetchOpenChannels(state.IdentityPub)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to fetch open channel: %v", err)
|
|
||||||
}
|
|
||||||
updatedChan := openChannels[0]
|
|
||||||
|
|
||||||
// Ensure that the revocation was set properly.
|
|
||||||
if !nextRevKey.PubKey().IsEqual(updatedChan.RemoteNextRevocation) {
|
|
||||||
t.Fatalf("next revocation wasn't updated")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Finally to wrap up the test, delete the state of the channel within
|
|
||||||
// the database. This involves "closing" the channel which removes all
|
|
||||||
// written state, and creates a small "summary" elsewhere within the
|
|
||||||
// database.
|
|
||||||
closeSummary := &ChannelCloseSummary{
|
|
||||||
ChanPoint: state.FundingOutpoint,
|
|
||||||
RemotePub: state.IdentityPub,
|
|
||||||
SettledBalance: btcutil.Amount(500),
|
|
||||||
TimeLockedBalance: btcutil.Amount(10000),
|
|
||||||
IsPending: false,
|
|
||||||
CloseType: CooperativeClose,
|
|
||||||
}
|
|
||||||
if err := state.CloseChannel(closeSummary); err != nil {
|
|
||||||
t.Fatalf("unable to close channel: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// As the channel is now closed, attempting to fetch all open channels
|
|
||||||
// for our fake node ID should return an empty slice.
|
|
||||||
openChans, err := cdb.FetchOpenChannels(state.IdentityPub)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to fetch open channels: %v", err)
|
|
||||||
}
|
|
||||||
if len(openChans) != 0 {
|
|
||||||
t.Fatalf("all channels not deleted, found %v", len(openChans))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Additionally, attempting to fetch all the open channels globally
|
|
||||||
// should yield no results.
|
|
||||||
openChans, err = cdb.FetchAllChannels()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("unable to fetch all open chans")
|
|
||||||
}
|
|
||||||
if len(openChans) != 0 {
|
|
||||||
t.Fatalf("all channels not deleted, found %v", len(openChans))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func assertCommitmentEqual(t *testing.T, a, b *ChannelCommitment) {
|
|
||||||
if !reflect.DeepEqual(a, b) {
|
|
||||||
_, _, line, _ := runtime.Caller(1)
|
|
||||||
t.Fatalf("line %v: commitments don't match: %v vs %v",
|
|
||||||
line, spew.Sdump(a), spew.Sdump(b))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestChannelStateTransition(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
cdb, cleanUp, err := makeTestDB()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test database: %v", err)
|
|
||||||
}
|
|
||||||
defer cleanUp()
|
|
||||||
|
|
||||||
// First create a minimal channel, then perform a full sync in order to
|
|
||||||
// persist the data.
|
|
||||||
channel, err := createTestChannelState(cdb)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create channel state: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
addr := &net.TCPAddr{
|
|
||||||
IP: net.ParseIP("127.0.0.1"),
|
|
||||||
Port: 18556,
|
|
||||||
}
|
|
||||||
if err := channel.SyncPending(addr, 101); err != nil {
|
|
||||||
t.Fatalf("unable to save and serialize channel state: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add some HTLCs which were added during this new state transition.
|
|
||||||
// Half of the HTLCs are incoming, while the other half are outgoing.
|
|
||||||
var (
|
|
||||||
htlcs []HTLC
|
|
||||||
htlcAmt lnwire.MilliSatoshi
|
|
||||||
)
|
|
||||||
for i := uint32(0); i < 10; i++ {
|
|
||||||
var incoming bool
|
|
||||||
if i > 5 {
|
|
||||||
incoming = true
|
|
||||||
}
|
|
||||||
htlc := HTLC{
|
|
||||||
Signature: testSig.Serialize(),
|
|
||||||
Incoming: incoming,
|
|
||||||
Amt: 10,
|
|
||||||
RHash: key,
|
|
||||||
RefundTimeout: i,
|
|
||||||
OutputIndex: int32(i * 3),
|
|
||||||
LogIndex: uint64(i * 2),
|
|
||||||
HtlcIndex: uint64(i),
|
|
||||||
}
|
|
||||||
htlc.OnionBlob = make([]byte, 10)
|
|
||||||
copy(htlc.OnionBlob[:], bytes.Repeat([]byte{2}, 10))
|
|
||||||
htlcs = append(htlcs, htlc)
|
|
||||||
htlcAmt += htlc.Amt
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a new channel delta which includes the above HTLCs, some
|
|
||||||
// balance updates, and an increment of the current commitment height.
|
|
||||||
// Additionally, modify the signature and commitment transaction.
|
|
||||||
newSequence := uint32(129498)
|
|
||||||
newSig := bytes.Repeat([]byte{3}, 71)
|
|
||||||
newTx := channel.LocalCommitment.CommitTx.Copy()
|
|
||||||
newTx.TxIn[0].Sequence = newSequence
|
|
||||||
commitment := ChannelCommitment{
|
|
||||||
CommitHeight: 1,
|
|
||||||
LocalLogIndex: 2,
|
|
||||||
LocalHtlcIndex: 1,
|
|
||||||
RemoteLogIndex: 2,
|
|
||||||
RemoteHtlcIndex: 1,
|
|
||||||
LocalBalance: lnwire.MilliSatoshi(1e8),
|
|
||||||
RemoteBalance: lnwire.MilliSatoshi(1e8),
|
|
||||||
CommitFee: 55,
|
|
||||||
FeePerKw: 99,
|
|
||||||
CommitTx: newTx,
|
|
||||||
CommitSig: newSig,
|
|
||||||
Htlcs: htlcs,
|
|
||||||
}
|
|
||||||
|
|
||||||
// First update the local node's broadcastable state and also add a
|
|
||||||
// CommitDiff remote node's as well in order to simulate a proper state
|
|
||||||
// transition.
|
|
||||||
if err := channel.UpdateCommitment(&commitment); err != nil {
|
|
||||||
t.Fatalf("unable to update commitment: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// The balances, new update, the HTLCs and the changes to the fake
|
|
||||||
// commitment transaction along with the modified signature should all
|
|
||||||
// have been updated.
|
|
||||||
updatedChannel, err := cdb.FetchOpenChannels(channel.IdentityPub)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to fetch updated channel: %v", err)
|
|
||||||
}
|
|
||||||
assertCommitmentEqual(t, &commitment, &updatedChannel[0].LocalCommitment)
|
|
||||||
numDiskUpdates, err := updatedChannel[0].CommitmentHeight()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to read commitment height from disk: %v", err)
|
|
||||||
}
|
|
||||||
if numDiskUpdates != uint64(commitment.CommitHeight) {
|
|
||||||
t.Fatalf("num disk updates doesn't match: %v vs %v",
|
|
||||||
numDiskUpdates, commitment.CommitHeight)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Attempting to query for a commitment diff should return
|
|
||||||
// ErrNoPendingCommit as we haven't yet created a new state for them.
|
|
||||||
_, err = channel.RemoteCommitChainTip()
|
|
||||||
if err != ErrNoPendingCommit {
|
|
||||||
t.Fatalf("expected ErrNoPendingCommit, instead got %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// To simulate us extending a new state to the remote party, we'll also
|
|
||||||
// create a new commit diff for them.
|
|
||||||
remoteCommit := commitment
|
|
||||||
remoteCommit.LocalBalance = lnwire.MilliSatoshi(2e8)
|
|
||||||
remoteCommit.RemoteBalance = lnwire.MilliSatoshi(3e8)
|
|
||||||
remoteCommit.CommitHeight = 1
|
|
||||||
commitDiff := &CommitDiff{
|
|
||||||
Commitment: remoteCommit,
|
|
||||||
CommitSig: &lnwire.CommitSig{
|
|
||||||
ChanID: lnwire.ChannelID(key),
|
|
||||||
CommitSig: wireSig,
|
|
||||||
HtlcSigs: []lnwire.Sig{
|
|
||||||
wireSig,
|
|
||||||
wireSig,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
LogUpdates: []LogUpdate{
|
|
||||||
{
|
|
||||||
LogIndex: 1,
|
|
||||||
UpdateMsg: &lnwire.UpdateAddHTLC{
|
|
||||||
ID: 1,
|
|
||||||
Amount: lnwire.NewMSatFromSatoshis(100),
|
|
||||||
Expiry: 25,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
LogIndex: 2,
|
|
||||||
UpdateMsg: &lnwire.UpdateAddHTLC{
|
|
||||||
ID: 2,
|
|
||||||
Amount: lnwire.NewMSatFromSatoshis(200),
|
|
||||||
Expiry: 50,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
OpenedCircuitKeys: []CircuitKey{},
|
|
||||||
ClosedCircuitKeys: []CircuitKey{},
|
|
||||||
}
|
|
||||||
copy(commitDiff.LogUpdates[0].UpdateMsg.(*lnwire.UpdateAddHTLC).PaymentHash[:],
|
|
||||||
bytes.Repeat([]byte{1}, 32))
|
|
||||||
copy(commitDiff.LogUpdates[1].UpdateMsg.(*lnwire.UpdateAddHTLC).PaymentHash[:],
|
|
||||||
bytes.Repeat([]byte{2}, 32))
|
|
||||||
if err := channel.AppendRemoteCommitChain(commitDiff); err != nil {
|
|
||||||
t.Fatalf("unable to add to commit chain: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// The commitment tip should now match the commitment that we just
|
|
||||||
// inserted.
|
|
||||||
diskCommitDiff, err := channel.RemoteCommitChainTip()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to fetch commit diff: %v", err)
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(commitDiff, diskCommitDiff) {
|
|
||||||
t.Fatalf("commit diffs don't match: %v vs %v", spew.Sdump(remoteCommit),
|
|
||||||
spew.Sdump(diskCommitDiff))
|
|
||||||
}
|
|
||||||
|
|
||||||
// We'll save the old remote commitment as this will be added to the
|
|
||||||
// revocation log shortly.
|
|
||||||
oldRemoteCommit := channel.RemoteCommitment
|
|
||||||
|
|
||||||
// Next, write to the log which tracks the necessary revocation state
|
|
||||||
// needed to rectify any fishy behavior by the remote party. Modify the
|
|
||||||
// current uncollapsed revocation state to simulate a state transition
|
|
||||||
// by the remote party.
|
|
||||||
channel.RemoteCurrentRevocation = channel.RemoteNextRevocation
|
|
||||||
newPriv, err := btcec.NewPrivateKey(btcec.S256())
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to generate key: %v", err)
|
|
||||||
}
|
|
||||||
channel.RemoteNextRevocation = newPriv.PubKey()
|
|
||||||
|
|
||||||
fwdPkg := NewFwdPkg(channel.ShortChanID(), oldRemoteCommit.CommitHeight,
|
|
||||||
diskCommitDiff.LogUpdates, nil)
|
|
||||||
|
|
||||||
err = channel.AdvanceCommitChainTail(fwdPkg)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to append to revocation log: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// At this point, the remote commit chain should be nil, and the posted
|
|
||||||
// remote commitment should match the one we added as a diff above.
|
|
||||||
if _, err := channel.RemoteCommitChainTip(); err != ErrNoPendingCommit {
|
|
||||||
t.Fatalf("expected ErrNoPendingCommit, instead got %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// We should be able to fetch the channel delta created above by its
|
|
||||||
// update number with all the state properly reconstructed.
|
|
||||||
diskPrevCommit, err := channel.FindPreviousState(
|
|
||||||
oldRemoteCommit.CommitHeight,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to fetch past delta: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// The two deltas (the original vs the on-disk version) should
|
|
||||||
// identical, and all HTLC data should properly be retained.
|
|
||||||
assertCommitmentEqual(t, &oldRemoteCommit, diskPrevCommit)
|
|
||||||
|
|
||||||
// The state number recovered from the tail of the revocation log
|
|
||||||
// should be identical to this current state.
|
|
||||||
logTail, err := channel.RevocationLogTail()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to retrieve log: %v", err)
|
|
||||||
}
|
|
||||||
if logTail.CommitHeight != oldRemoteCommit.CommitHeight {
|
|
||||||
t.Fatal("update number doesn't match")
|
|
||||||
}
|
|
||||||
|
|
||||||
oldRemoteCommit = channel.RemoteCommitment
|
|
||||||
|
|
||||||
// Next modify the posted diff commitment slightly, then create a new
|
|
||||||
// commitment diff and advance the tail.
|
|
||||||
commitDiff.Commitment.CommitHeight = 2
|
|
||||||
commitDiff.Commitment.LocalBalance -= htlcAmt
|
|
||||||
commitDiff.Commitment.RemoteBalance += htlcAmt
|
|
||||||
commitDiff.LogUpdates = []LogUpdate{}
|
|
||||||
if err := channel.AppendRemoteCommitChain(commitDiff); err != nil {
|
|
||||||
t.Fatalf("unable to add to commit chain: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
fwdPkg = NewFwdPkg(channel.ShortChanID(), oldRemoteCommit.CommitHeight, nil, nil)
|
|
||||||
|
|
||||||
err = channel.AdvanceCommitChainTail(fwdPkg)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to append to revocation log: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Once again, fetch the state and ensure it has been properly updated.
|
|
||||||
prevCommit, err := channel.FindPreviousState(oldRemoteCommit.CommitHeight)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to fetch past delta: %v", err)
|
|
||||||
}
|
|
||||||
assertCommitmentEqual(t, &oldRemoteCommit, prevCommit)
|
|
||||||
|
|
||||||
// Once again, state number recovered from the tail of the revocation
|
|
||||||
// log should be identical to this current state.
|
|
||||||
logTail, err = channel.RevocationLogTail()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to retrieve log: %v", err)
|
|
||||||
}
|
|
||||||
if logTail.CommitHeight != oldRemoteCommit.CommitHeight {
|
|
||||||
t.Fatal("update number doesn't match")
|
|
||||||
}
|
|
||||||
|
|
||||||
// The revocation state stored on-disk should now also be identical.
|
|
||||||
updatedChannel, err = cdb.FetchOpenChannels(channel.IdentityPub)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to fetch updated channel: %v", err)
|
|
||||||
}
|
|
||||||
if !channel.RemoteCurrentRevocation.IsEqual(updatedChannel[0].RemoteCurrentRevocation) {
|
|
||||||
t.Fatalf("revocation state was not synced")
|
|
||||||
}
|
|
||||||
if !channel.RemoteNextRevocation.IsEqual(updatedChannel[0].RemoteNextRevocation) {
|
|
||||||
t.Fatalf("revocation state was not synced")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now attempt to delete the channel from the database.
|
|
||||||
closeSummary := &ChannelCloseSummary{
|
|
||||||
ChanPoint: channel.FundingOutpoint,
|
|
||||||
RemotePub: channel.IdentityPub,
|
|
||||||
SettledBalance: btcutil.Amount(500),
|
|
||||||
TimeLockedBalance: btcutil.Amount(10000),
|
|
||||||
IsPending: false,
|
|
||||||
CloseType: RemoteForceClose,
|
|
||||||
}
|
|
||||||
if err := updatedChannel[0].CloseChannel(closeSummary); err != nil {
|
|
||||||
t.Fatalf("unable to delete updated channel: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we attempt to fetch the target channel again, it shouldn't be
|
|
||||||
// found.
|
|
||||||
channels, err := cdb.FetchOpenChannels(channel.IdentityPub)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to fetch updated channels: %v", err)
|
|
||||||
}
|
|
||||||
if len(channels) != 0 {
|
|
||||||
t.Fatalf("%v channels, found, but none should be",
|
|
||||||
len(channels))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Attempting to find previous states on the channel should fail as the
|
|
||||||
// revocation log has been deleted.
|
|
||||||
_, err = updatedChannel[0].FindPreviousState(oldRemoteCommit.CommitHeight)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("revocation log search should have failed")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFetchPendingChannels(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
cdb, cleanUp, err := makeTestDB()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test database: %v", err)
|
|
||||||
}
|
|
||||||
defer cleanUp()
|
|
||||||
|
|
||||||
// Create first test channel state
|
|
||||||
state, err := createTestChannelState(cdb)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create channel state: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
addr := &net.TCPAddr{
|
|
||||||
IP: net.ParseIP("127.0.0.1"),
|
|
||||||
Port: 18555,
|
|
||||||
}
|
|
||||||
|
|
||||||
const broadcastHeight = 99
|
|
||||||
if err := state.SyncPending(addr, broadcastHeight); err != nil {
|
|
||||||
t.Fatalf("unable to save and serialize channel state: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
pendingChannels, err := cdb.FetchPendingChannels()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to list pending channels: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(pendingChannels) != 1 {
|
|
||||||
t.Fatalf("incorrect number of pending channels: expecting %v,"+
|
|
||||||
"got %v", 1, len(pendingChannels))
|
|
||||||
}
|
|
||||||
|
|
||||||
// The broadcast height of the pending channel should have been set
|
|
||||||
// properly.
|
|
||||||
if pendingChannels[0].FundingBroadcastHeight != broadcastHeight {
|
|
||||||
t.Fatalf("broadcast height mismatch: expected %v, got %v",
|
|
||||||
pendingChannels[0].FundingBroadcastHeight,
|
|
||||||
broadcastHeight)
|
|
||||||
}
|
|
||||||
|
|
||||||
chanOpenLoc := lnwire.ShortChannelID{
|
|
||||||
BlockHeight: 5,
|
|
||||||
TxIndex: 10,
|
|
||||||
TxPosition: 15,
|
|
||||||
}
|
|
||||||
err = pendingChannels[0].MarkAsOpen(chanOpenLoc)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to mark channel as open: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if pendingChannels[0].IsPending {
|
|
||||||
t.Fatalf("channel marked open should no longer be pending")
|
|
||||||
}
|
|
||||||
|
|
||||||
if pendingChannels[0].ShortChanID() != chanOpenLoc {
|
|
||||||
t.Fatalf("channel opening height not updated: expected %v, "+
|
|
||||||
"got %v", spew.Sdump(pendingChannels[0].ShortChanID()),
|
|
||||||
chanOpenLoc)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Next, we'll re-fetch the channel to ensure that the open height was
|
|
||||||
// properly set.
|
|
||||||
openChans, err := cdb.FetchAllChannels()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to fetch channels: %v", err)
|
|
||||||
}
|
|
||||||
if openChans[0].ShortChanID() != chanOpenLoc {
|
|
||||||
t.Fatalf("channel opening heights don't match: expected %v, "+
|
|
||||||
"got %v", spew.Sdump(openChans[0].ShortChanID()),
|
|
||||||
chanOpenLoc)
|
|
||||||
}
|
|
||||||
if openChans[0].FundingBroadcastHeight != broadcastHeight {
|
|
||||||
t.Fatalf("broadcast height mismatch: expected %v, got %v",
|
|
||||||
openChans[0].FundingBroadcastHeight,
|
|
||||||
broadcastHeight)
|
|
||||||
}
|
|
||||||
|
|
||||||
pendingChannels, err = cdb.FetchPendingChannels()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to list pending channels: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(pendingChannels) != 0 {
|
|
||||||
t.Fatalf("incorrect number of pending channels: expecting %v,"+
|
|
||||||
"got %v", 0, len(pendingChannels))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFetchClosedChannels(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
cdb, cleanUp, err := makeTestDB()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test database: %v", err)
|
|
||||||
}
|
|
||||||
defer cleanUp()
|
|
||||||
|
|
||||||
// First create a test channel, that we'll be closing within this pull
|
|
||||||
// request.
|
|
||||||
state, err := createTestChannelState(cdb)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create channel state: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Next sync the channel to disk, marking it as being in a pending open
|
|
||||||
// state.
|
|
||||||
addr := &net.TCPAddr{
|
|
||||||
IP: net.ParseIP("127.0.0.1"),
|
|
||||||
Port: 18555,
|
|
||||||
}
|
|
||||||
const broadcastHeight = 99
|
|
||||||
if err := state.SyncPending(addr, broadcastHeight); err != nil {
|
|
||||||
t.Fatalf("unable to save and serialize channel state: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Next, simulate the confirmation of the channel by marking it as
|
|
||||||
// pending within the database.
|
|
||||||
chanOpenLoc := lnwire.ShortChannelID{
|
|
||||||
BlockHeight: 5,
|
|
||||||
TxIndex: 10,
|
|
||||||
TxPosition: 15,
|
|
||||||
}
|
|
||||||
err = state.MarkAsOpen(chanOpenLoc)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to mark channel as open: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Next, close the channel by including a close channel summary in the
|
|
||||||
// database.
|
|
||||||
summary := &ChannelCloseSummary{
|
|
||||||
ChanPoint: state.FundingOutpoint,
|
|
||||||
ClosingTXID: rev,
|
|
||||||
RemotePub: state.IdentityPub,
|
|
||||||
Capacity: state.Capacity,
|
|
||||||
SettledBalance: state.LocalCommitment.LocalBalance.ToSatoshis(),
|
|
||||||
TimeLockedBalance: state.RemoteCommitment.LocalBalance.ToSatoshis() + 10000,
|
|
||||||
CloseType: RemoteForceClose,
|
|
||||||
IsPending: true,
|
|
||||||
LocalChanConfig: state.LocalChanCfg,
|
|
||||||
}
|
|
||||||
if err := state.CloseChannel(summary); err != nil {
|
|
||||||
t.Fatalf("unable to close channel: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Query the database to ensure that the channel has now been properly
|
|
||||||
// closed. We should get the same result whether querying for pending
|
|
||||||
// channels only, or not.
|
|
||||||
pendingClosed, err := cdb.FetchClosedChannels(true)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed fetching closed channels: %v", err)
|
|
||||||
}
|
|
||||||
if len(pendingClosed) != 1 {
|
|
||||||
t.Fatalf("incorrect number of pending closed channels: expecting %v,"+
|
|
||||||
"got %v", 1, len(pendingClosed))
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(summary, pendingClosed[0]) {
|
|
||||||
t.Fatalf("database summaries don't match: expected %v got %v",
|
|
||||||
spew.Sdump(summary), spew.Sdump(pendingClosed[0]))
|
|
||||||
}
|
|
||||||
closed, err := cdb.FetchClosedChannels(false)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed fetching all closed channels: %v", err)
|
|
||||||
}
|
|
||||||
if len(closed) != 1 {
|
|
||||||
t.Fatalf("incorrect number of closed channels: expecting %v, "+
|
|
||||||
"got %v", 1, len(closed))
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(summary, closed[0]) {
|
|
||||||
t.Fatalf("database summaries don't match: expected %v got %v",
|
|
||||||
spew.Sdump(summary), spew.Sdump(closed[0]))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Mark the channel as fully closed.
|
|
||||||
err = cdb.MarkChanFullyClosed(&state.FundingOutpoint)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed fully closing channel: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// The channel should no longer be considered pending, but should still
|
|
||||||
// be retrieved when fetching all the closed channels.
|
|
||||||
closed, err = cdb.FetchClosedChannels(false)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed fetching closed channels: %v", err)
|
|
||||||
}
|
|
||||||
if len(closed) != 1 {
|
|
||||||
t.Fatalf("incorrect number of closed channels: expecting %v, "+
|
|
||||||
"got %v", 1, len(closed))
|
|
||||||
}
|
|
||||||
pendingClose, err := cdb.FetchClosedChannels(true)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed fetching channels pending close: %v", err)
|
|
||||||
}
|
|
||||||
if len(pendingClose) != 0 {
|
|
||||||
t.Fatalf("incorrect number of closed channels: expecting %v, "+
|
|
||||||
"got %v", 0, len(closed))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestFetchWaitingCloseChannels ensures that the correct channels that are
|
|
||||||
// waiting to be closed are returned.
|
|
||||||
func TestFetchWaitingCloseChannels(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
const numChannels = 2
|
|
||||||
const broadcastHeight = 99
|
|
||||||
addr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 18555}
|
|
||||||
|
|
||||||
// We'll start by creating two channels within our test database. One of
|
|
||||||
// them will have their funding transaction confirmed on-chain, while
|
|
||||||
// the other one will remain unconfirmed.
|
|
||||||
db, cleanUp, err := makeTestDB()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test database: %v", err)
|
|
||||||
}
|
|
||||||
defer cleanUp()
|
|
||||||
|
|
||||||
channels := make([]*OpenChannel, numChannels)
|
|
||||||
for i := 0; i < numChannels; i++ {
|
|
||||||
channel, err := createTestChannelState(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create channel: %v", err)
|
|
||||||
}
|
|
||||||
err = channel.SyncPending(addr, broadcastHeight)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to sync channel: %v", err)
|
|
||||||
}
|
|
||||||
channels[i] = channel
|
|
||||||
}
|
|
||||||
|
|
||||||
// We'll only confirm the first one.
|
|
||||||
channelConf := lnwire.ShortChannelID{
|
|
||||||
BlockHeight: broadcastHeight + 1,
|
|
||||||
TxIndex: 10,
|
|
||||||
TxPosition: 15,
|
|
||||||
}
|
|
||||||
if err := channels[0].MarkAsOpen(channelConf); err != nil {
|
|
||||||
t.Fatalf("unable to mark channel as open: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Then, we'll mark the channels as if their commitments were broadcast.
|
|
||||||
// This would happen in the event of a force close and should make the
|
|
||||||
// channels enter a state of waiting close.
|
|
||||||
for _, channel := range channels {
|
|
||||||
closeTx := wire.NewMsgTx(2)
|
|
||||||
closeTx.AddTxIn(
|
|
||||||
&wire.TxIn{
|
|
||||||
PreviousOutPoint: channel.FundingOutpoint,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
if err := channel.MarkCommitmentBroadcasted(closeTx); err != nil {
|
|
||||||
t.Fatalf("unable to mark commitment broadcast: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now, we'll fetch all the channels waiting to be closed from the
|
|
||||||
// database. We should expect to see both channels above, even if any of
|
|
||||||
// them haven't had their funding transaction confirm on-chain.
|
|
||||||
waitingCloseChannels, err := db.FetchWaitingCloseChannels()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to fetch all waiting close channels: %v", err)
|
|
||||||
}
|
|
||||||
if len(waitingCloseChannels) != 2 {
|
|
||||||
t.Fatalf("expected %d channels waiting to be closed, got %d", 2,
|
|
||||||
len(waitingCloseChannels))
|
|
||||||
}
|
|
||||||
expectedChannels := make(map[wire.OutPoint]struct{})
|
|
||||||
for _, channel := range channels {
|
|
||||||
expectedChannels[channel.FundingOutpoint] = struct{}{}
|
|
||||||
}
|
|
||||||
for _, channel := range waitingCloseChannels {
|
|
||||||
if _, ok := expectedChannels[channel.FundingOutpoint]; !ok {
|
|
||||||
t.Fatalf("expected channel %v to be waiting close",
|
|
||||||
channel.FundingOutpoint)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Finally, make sure we can retrieve the closing tx for the
|
|
||||||
// channel.
|
|
||||||
closeTx, err := channel.BroadcastedCommitment()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Unable to retrieve commitment: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if closeTx.TxIn[0].PreviousOutPoint != channel.FundingOutpoint {
|
|
||||||
t.Fatalf("expected outpoint %v, got %v",
|
|
||||||
channel.FundingOutpoint,
|
|
||||||
closeTx.TxIn[0].PreviousOutPoint)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestRefreshShortChanID asserts that RefreshShortChanID updates the in-memory
|
|
||||||
// short channel ID of another OpenChannel to reflect a preceding call to
|
|
||||||
// MarkOpen on a different OpenChannel.
|
|
||||||
func TestRefreshShortChanID(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
cdb, cleanUp, err := makeTestDB()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test database: %v", err)
|
|
||||||
}
|
|
||||||
defer cleanUp()
|
|
||||||
|
|
||||||
// First create a test channel.
|
|
||||||
state, err := createTestChannelState(cdb)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create channel state: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
addr := &net.TCPAddr{
|
|
||||||
IP: net.ParseIP("127.0.0.1"),
|
|
||||||
Port: 18555,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Mark the channel as pending within the channeldb.
|
|
||||||
const broadcastHeight = 99
|
|
||||||
if err := state.SyncPending(addr, broadcastHeight); err != nil {
|
|
||||||
t.Fatalf("unable to save and serialize channel state: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Next, locate the pending channel with the database.
|
|
||||||
pendingChannels, err := cdb.FetchPendingChannels()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to load pending channels; %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var pendingChannel *OpenChannel
|
|
||||||
for _, channel := range pendingChannels {
|
|
||||||
if channel.FundingOutpoint == state.FundingOutpoint {
|
|
||||||
pendingChannel = channel
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if pendingChannel == nil {
|
|
||||||
t.Fatalf("unable to find pending channel with funding "+
|
|
||||||
"outpoint=%v: %v", state.FundingOutpoint, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Next, simulate the confirmation of the channel by marking it as
|
|
||||||
// pending within the database.
|
|
||||||
chanOpenLoc := lnwire.ShortChannelID{
|
|
||||||
BlockHeight: 105,
|
|
||||||
TxIndex: 10,
|
|
||||||
TxPosition: 15,
|
|
||||||
}
|
|
||||||
|
|
||||||
err = state.MarkAsOpen(chanOpenLoc)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to mark channel open: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// The short_chan_id of the receiver to MarkAsOpen should reflect the
|
|
||||||
// open location, but the other pending channel should remain unchanged.
|
|
||||||
if state.ShortChanID() == pendingChannel.ShortChanID() {
|
|
||||||
t.Fatalf("pending channel short_chan_ID should not have been " +
|
|
||||||
"updated before refreshing short_chan_id")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now that the receiver's short channel id has been updated, check to
|
|
||||||
// ensure that the channel packager's source has been updated as well.
|
|
||||||
// This ensures that the packager will read and write to buckets
|
|
||||||
// corresponding to the new short chan id, instead of the prior.
|
|
||||||
if state.Packager.(*ChannelPackager).source != chanOpenLoc {
|
|
||||||
t.Fatalf("channel packager source was not updated: want %v, "+
|
|
||||||
"got %v", chanOpenLoc,
|
|
||||||
state.Packager.(*ChannelPackager).source)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now, refresh the short channel ID of the pending channel.
|
|
||||||
err = pendingChannel.RefreshShortChanID()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to refresh short_chan_id: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// This should result in both OpenChannel's now having the same
|
|
||||||
// ShortChanID.
|
|
||||||
if state.ShortChanID() != pendingChannel.ShortChanID() {
|
|
||||||
t.Fatalf("expected pending channel short_chan_id to be "+
|
|
||||||
"refreshed: want %v, got %v", state.ShortChanID(),
|
|
||||||
pendingChannel.ShortChanID())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check to ensure that the _other_ OpenChannel channel packager's
|
|
||||||
// source has also been updated after the refresh. This ensures that the
|
|
||||||
// other packagers will read and write to buckets corresponding to the
|
|
||||||
// updated short chan id.
|
|
||||||
if pendingChannel.Packager.(*ChannelPackager).source != chanOpenLoc {
|
|
||||||
t.Fatalf("channel packager source was not updated: want %v, "+
|
|
||||||
"got %v", chanOpenLoc,
|
|
||||||
pendingChannel.Packager.(*ChannelPackager).source)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -48,12 +48,6 @@ type UnknownElementType struct {
|
|||||||
element interface{}
|
element interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewUnknownElementType creates a new UnknownElementType error from the passed
|
|
||||||
// method name and element.
|
|
||||||
func NewUnknownElementType(method string, el interface{}) UnknownElementType {
|
|
||||||
return UnknownElementType{method: method, element: el}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Error returns the name of the method that encountered the error, as well as
|
// Error returns the name of the method that encountered the error, as well as
|
||||||
// the type that was unsupported.
|
// the type that was unsupported.
|
||||||
func (e UnknownElementType) Error() string {
|
func (e UnknownElementType) Error() string {
|
||||||
|
@ -4,16 +4,11 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/btcsuite/btcd/btcec"
|
|
||||||
"github.com/btcsuite/btcd/wire"
|
|
||||||
"github.com/coreos/bbolt"
|
"github.com/coreos/bbolt"
|
||||||
"github.com/go-errors/errors"
|
|
||||||
"github.com/lightningnetwork/lnd/lnwire"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -87,57 +82,6 @@ func Open(dbPath string, modifiers ...OptionModifier) (*DB, error) {
|
|||||||
return chanDB, nil
|
return chanDB, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Path returns the file path to the channel database.
|
|
||||||
func (d *DB) Path() string {
|
|
||||||
return d.dbPath
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wipe completely deletes all saved state within all used buckets within the
|
|
||||||
// database. The deletion is done in a single transaction, therefore this
|
|
||||||
// operation is fully atomic.
|
|
||||||
func (d *DB) Wipe() error {
|
|
||||||
return d.Update(func(tx *bbolt.Tx) error {
|
|
||||||
err := tx.DeleteBucket(openChannelBucket)
|
|
||||||
if err != nil && err != bbolt.ErrBucketNotFound {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = tx.DeleteBucket(closedChannelBucket)
|
|
||||||
if err != nil && err != bbolt.ErrBucketNotFound {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = tx.DeleteBucket(invoiceBucket)
|
|
||||||
if err != nil && err != bbolt.ErrBucketNotFound {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = tx.DeleteBucket(nodeInfoBucket)
|
|
||||||
if err != nil && err != bbolt.ErrBucketNotFound {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = tx.DeleteBucket(nodeBucket)
|
|
||||||
if err != nil && err != bbolt.ErrBucketNotFound {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = tx.DeleteBucket(edgeBucket)
|
|
||||||
if err != nil && err != bbolt.ErrBucketNotFound {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = tx.DeleteBucket(edgeIndexBucket)
|
|
||||||
if err != nil && err != bbolt.ErrBucketNotFound {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = tx.DeleteBucket(graphMetaBucket)
|
|
||||||
if err != nil && err != bbolt.ErrBucketNotFound {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// createChannelDB creates and initializes a fresh version of channeldb. In
|
// createChannelDB creates and initializes a fresh version of channeldb. In
|
||||||
// the case that the target path has not yet been created or doesn't yet exist,
|
// the case that the target path has not yet been created or doesn't yet exist,
|
||||||
// then the path is created. Additionally, all required top-level buckets used
|
// then the path is created. Additionally, all required top-level buckets used
|
||||||
@ -163,14 +107,6 @@ func createChannelDB(dbPath string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := tx.CreateBucket(forwardingLogBucket); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := tx.CreateBucket(fwdPackagesKey); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := tx.CreateBucket(invoiceBucket); err != nil {
|
if _, err := tx.CreateBucket(invoiceBucket); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -179,10 +115,6 @@ func createChannelDB(dbPath string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := tx.CreateBucket(nodeInfoBucket); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
nodes, err := tx.CreateBucket(nodeBucket)
|
nodes, err := tx.CreateBucket(nodeBucket)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -249,359 +181,6 @@ func fileExists(path string) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// FetchOpenChannels starts a new database transaction and returns all stored
|
|
||||||
// currently active/open channels associated with the target nodeID. In the case
|
|
||||||
// that no active channels are known to have been created with this node, then a
|
|
||||||
// zero-length slice is returned.
|
|
||||||
func (d *DB) FetchOpenChannels(nodeID *btcec.PublicKey) ([]*OpenChannel, error) {
|
|
||||||
var channels []*OpenChannel
|
|
||||||
err := d.View(func(tx *bbolt.Tx) error {
|
|
||||||
var err error
|
|
||||||
channels, err = d.fetchOpenChannels(tx, nodeID)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
|
|
||||||
return channels, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// fetchOpenChannels uses and existing database transaction and returns all
|
|
||||||
// stored currently active/open channels associated with the target nodeID. In
|
|
||||||
// the case that no active channels are known to have been created with this
|
|
||||||
// node, then a zero-length slice is returned.
|
|
||||||
func (d *DB) fetchOpenChannels(tx *bbolt.Tx,
|
|
||||||
nodeID *btcec.PublicKey) ([]*OpenChannel, error) {
|
|
||||||
|
|
||||||
// Get the bucket dedicated to storing the metadata for open channels.
|
|
||||||
openChanBucket := tx.Bucket(openChannelBucket)
|
|
||||||
if openChanBucket == nil {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Within this top level bucket, fetch the bucket dedicated to storing
|
|
||||||
// open channel data specific to the remote node.
|
|
||||||
pub := nodeID.SerializeCompressed()
|
|
||||||
nodeChanBucket := openChanBucket.Bucket(pub)
|
|
||||||
if nodeChanBucket == nil {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Next, we'll need to go down an additional layer in order to retrieve
|
|
||||||
// the channels for each chain the node knows of.
|
|
||||||
var channels []*OpenChannel
|
|
||||||
err := nodeChanBucket.ForEach(func(chainHash, v []byte) error {
|
|
||||||
// If there's a value, it's not a bucket so ignore it.
|
|
||||||
if v != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we've found a valid chainhash bucket, then we'll retrieve
|
|
||||||
// that so we can extract all the channels.
|
|
||||||
chainBucket := nodeChanBucket.Bucket(chainHash)
|
|
||||||
if chainBucket == nil {
|
|
||||||
return fmt.Errorf("unable to read bucket for chain=%x",
|
|
||||||
chainHash[:])
|
|
||||||
}
|
|
||||||
|
|
||||||
// Finally, we both of the necessary buckets retrieved, fetch
|
|
||||||
// all the active channels related to this node.
|
|
||||||
nodeChannels, err := d.fetchNodeChannels(chainBucket)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to read channel for "+
|
|
||||||
"chain_hash=%x, node_key=%x: %v",
|
|
||||||
chainHash[:], pub, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
channels = append(channels, nodeChannels...)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
return channels, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// fetchNodeChannels retrieves all active channels from the target chainBucket
|
|
||||||
// which is under a node's dedicated channel bucket. This function is typically
|
|
||||||
// used to fetch all the active channels related to a particular node.
|
|
||||||
func (d *DB) fetchNodeChannels(chainBucket *bbolt.Bucket) ([]*OpenChannel, error) {
|
|
||||||
|
|
||||||
var channels []*OpenChannel
|
|
||||||
|
|
||||||
// A node may have channels on several chains, so for each known chain,
|
|
||||||
// we'll extract all the channels.
|
|
||||||
err := chainBucket.ForEach(func(chanPoint, v []byte) error {
|
|
||||||
// If there's a value, it's not a bucket so ignore it.
|
|
||||||
if v != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Once we've found a valid channel bucket, we'll extract it
|
|
||||||
// from the node's chain bucket.
|
|
||||||
chanBucket := chainBucket.Bucket(chanPoint)
|
|
||||||
|
|
||||||
var outPoint wire.OutPoint
|
|
||||||
err := readOutpoint(bytes.NewReader(chanPoint), &outPoint)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
oChannel, err := fetchOpenChannel(chanBucket, &outPoint)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to read channel data for "+
|
|
||||||
"chan_point=%v: %v", outPoint, err)
|
|
||||||
}
|
|
||||||
oChannel.Db = d
|
|
||||||
|
|
||||||
channels = append(channels, oChannel)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return channels, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// FetchChannel attempts to locate a channel specified by the passed channel
|
|
||||||
// point. If the channel cannot be found, then an error will be returned.
|
|
||||||
func (d *DB) FetchChannel(chanPoint wire.OutPoint) (*OpenChannel, error) {
|
|
||||||
var (
|
|
||||||
targetChan *OpenChannel
|
|
||||||
targetChanPoint bytes.Buffer
|
|
||||||
)
|
|
||||||
|
|
||||||
if err := writeOutpoint(&targetChanPoint, &chanPoint); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// chanScan will traverse the following bucket structure:
|
|
||||||
// * nodePub => chainHash => chanPoint
|
|
||||||
//
|
|
||||||
// At each level we go one further, ensuring that we're traversing the
|
|
||||||
// proper key (that's actually a bucket). By only reading the bucket
|
|
||||||
// structure and skipping fully decoding each channel, we save a good
|
|
||||||
// bit of CPU as we don't need to do things like decompress public
|
|
||||||
// keys.
|
|
||||||
chanScan := func(tx *bbolt.Tx) error {
|
|
||||||
// Get the bucket dedicated to storing the metadata for open
|
|
||||||
// channels.
|
|
||||||
openChanBucket := tx.Bucket(openChannelBucket)
|
|
||||||
if openChanBucket == nil {
|
|
||||||
return ErrNoActiveChannels
|
|
||||||
}
|
|
||||||
|
|
||||||
// Within the node channel bucket, are the set of node pubkeys
|
|
||||||
// we have channels with, we don't know the entire set, so
|
|
||||||
// we'll check them all.
|
|
||||||
return openChanBucket.ForEach(func(nodePub, v []byte) error {
|
|
||||||
// Ensure that this is a key the same size as a pubkey,
|
|
||||||
// and also that it leads directly to a bucket.
|
|
||||||
if len(nodePub) != 33 || v != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
nodeChanBucket := openChanBucket.Bucket(nodePub)
|
|
||||||
if nodeChanBucket == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// The next layer down is all the chains that this node
|
|
||||||
// has channels on with us.
|
|
||||||
return nodeChanBucket.ForEach(func(chainHash, v []byte) error {
|
|
||||||
// If there's a value, it's not a bucket so
|
|
||||||
// ignore it.
|
|
||||||
if v != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
chainBucket := nodeChanBucket.Bucket(chainHash)
|
|
||||||
if chainBucket == nil {
|
|
||||||
return fmt.Errorf("unable to read "+
|
|
||||||
"bucket for chain=%x", chainHash[:])
|
|
||||||
}
|
|
||||||
|
|
||||||
// Finally we reach the leaf bucket that stores
|
|
||||||
// all the chanPoints for this node.
|
|
||||||
chanBucket := chainBucket.Bucket(
|
|
||||||
targetChanPoint.Bytes(),
|
|
||||||
)
|
|
||||||
if chanBucket == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
channel, err := fetchOpenChannel(
|
|
||||||
chanBucket, &chanPoint,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
targetChan = channel
|
|
||||||
targetChan.Db = d
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
err := d.View(chanScan)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if targetChan != nil {
|
|
||||||
return targetChan, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we can't find the channel, then we return with an error, as we
|
|
||||||
// have nothing to backup.
|
|
||||||
return nil, ErrChannelNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
// FetchAllChannels attempts to retrieve all open channels currently stored
|
|
||||||
// within the database, including pending open, fully open and channels waiting
|
|
||||||
// for a closing transaction to confirm.
|
|
||||||
func (d *DB) FetchAllChannels() ([]*OpenChannel, error) {
|
|
||||||
var channels []*OpenChannel
|
|
||||||
|
|
||||||
// TODO(halseth): fetch all in one db tx.
|
|
||||||
openChannels, err := d.FetchAllOpenChannels()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
channels = append(channels, openChannels...)
|
|
||||||
|
|
||||||
pendingChannels, err := d.FetchPendingChannels()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
channels = append(channels, pendingChannels...)
|
|
||||||
|
|
||||||
waitingClose, err := d.FetchWaitingCloseChannels()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
channels = append(channels, waitingClose...)
|
|
||||||
|
|
||||||
return channels, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// FetchAllOpenChannels will return all channels that have the funding
|
|
||||||
// transaction confirmed, and is not waiting for a closing transaction to be
|
|
||||||
// confirmed.
|
|
||||||
func (d *DB) FetchAllOpenChannels() ([]*OpenChannel, error) {
|
|
||||||
return fetchChannels(d, false, false)
|
|
||||||
}
|
|
||||||
|
|
||||||
// FetchPendingChannels will return channels that have completed the process of
|
|
||||||
// generating and broadcasting funding transactions, but whose funding
|
|
||||||
// transactions have yet to be confirmed on the blockchain.
|
|
||||||
func (d *DB) FetchPendingChannels() ([]*OpenChannel, error) {
|
|
||||||
return fetchChannels(d, true, false)
|
|
||||||
}
|
|
||||||
|
|
||||||
// FetchWaitingCloseChannels will return all channels that have been opened,
|
|
||||||
// but are now waiting for a closing transaction to be confirmed.
|
|
||||||
//
|
|
||||||
// NOTE: This includes channels that are also pending to be opened.
|
|
||||||
func (d *DB) FetchWaitingCloseChannels() ([]*OpenChannel, error) {
|
|
||||||
waitingClose, err := fetchChannels(d, false, true)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
pendingWaitingClose, err := fetchChannels(d, true, true)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return append(waitingClose, pendingWaitingClose...), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// fetchChannels attempts to retrieve channels currently stored in the
|
|
||||||
// database. The pending parameter determines whether only pending channels
|
|
||||||
// will be returned, or only open channels will be returned. The waitingClose
|
|
||||||
// parameter determines whether only channels waiting for a closing transaction
|
|
||||||
// to be confirmed should be returned. If no active channels exist within the
|
|
||||||
// network, then ErrNoActiveChannels is returned.
|
|
||||||
func fetchChannels(d *DB, pending, waitingClose bool) ([]*OpenChannel, error) {
|
|
||||||
var channels []*OpenChannel
|
|
||||||
|
|
||||||
err := d.View(func(tx *bbolt.Tx) error {
|
|
||||||
// Get the bucket dedicated to storing the metadata for open
|
|
||||||
// channels.
|
|
||||||
openChanBucket := tx.Bucket(openChannelBucket)
|
|
||||||
if openChanBucket == nil {
|
|
||||||
return ErrNoActiveChannels
|
|
||||||
}
|
|
||||||
|
|
||||||
// Next, fetch the bucket dedicated to storing metadata related
|
|
||||||
// to all nodes. All keys within this bucket are the serialized
|
|
||||||
// public keys of all our direct counterparties.
|
|
||||||
nodeMetaBucket := tx.Bucket(nodeInfoBucket)
|
|
||||||
if nodeMetaBucket == nil {
|
|
||||||
return fmt.Errorf("node bucket not created")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Finally for each node public key in the bucket, fetch all
|
|
||||||
// the channels related to this particular node.
|
|
||||||
return nodeMetaBucket.ForEach(func(k, v []byte) error {
|
|
||||||
nodeChanBucket := openChanBucket.Bucket(k)
|
|
||||||
if nodeChanBucket == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nodeChanBucket.ForEach(func(chainHash, v []byte) error {
|
|
||||||
// If there's a value, it's not a bucket so
|
|
||||||
// ignore it.
|
|
||||||
if v != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we've found a valid chainhash bucket,
|
|
||||||
// then we'll retrieve that so we can extract
|
|
||||||
// all the channels.
|
|
||||||
chainBucket := nodeChanBucket.Bucket(chainHash)
|
|
||||||
if chainBucket == nil {
|
|
||||||
return fmt.Errorf("unable to read "+
|
|
||||||
"bucket for chain=%x", chainHash[:])
|
|
||||||
}
|
|
||||||
|
|
||||||
nodeChans, err := d.fetchNodeChannels(chainBucket)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to read "+
|
|
||||||
"channel for chain_hash=%x, "+
|
|
||||||
"node_key=%x: %v", chainHash[:], k, err)
|
|
||||||
}
|
|
||||||
for _, channel := range nodeChans {
|
|
||||||
if channel.IsPending != pending {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the channel is in any other state
|
|
||||||
// than Default, then it means it is
|
|
||||||
// waiting to be closed.
|
|
||||||
channelWaitingClose :=
|
|
||||||
channel.ChanStatus() != ChanStatusDefault
|
|
||||||
|
|
||||||
// Only include it if we requested
|
|
||||||
// channels with the same waitingClose
|
|
||||||
// status.
|
|
||||||
if channelWaitingClose != waitingClose {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
channels = append(channels, channel)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
})
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return channels, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// FetchClosedChannels attempts to fetch all closed channels from the database.
|
// FetchClosedChannels attempts to fetch all closed channels from the database.
|
||||||
// The pendingOnly bool toggles if channels that aren't yet fully closed should
|
// The pendingOnly bool toggles if channels that aren't yet fully closed should
|
||||||
// be returned in the response or not. When a channel was cooperatively closed,
|
// be returned in the response or not. When a channel was cooperatively closed,
|
||||||
@ -641,371 +220,6 @@ func (d *DB) FetchClosedChannels(pendingOnly bool) ([]*ChannelCloseSummary, erro
|
|||||||
return chanSummaries, nil
|
return chanSummaries, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ErrClosedChannelNotFound signals that a closed channel could not be found in
|
|
||||||
// the channeldb.
|
|
||||||
var ErrClosedChannelNotFound = errors.New("unable to find closed channel summary")
|
|
||||||
|
|
||||||
// FetchClosedChannel queries for a channel close summary using the channel
|
|
||||||
// point of the channel in question.
|
|
||||||
func (d *DB) FetchClosedChannel(chanID *wire.OutPoint) (*ChannelCloseSummary, error) {
|
|
||||||
var chanSummary *ChannelCloseSummary
|
|
||||||
if err := d.View(func(tx *bbolt.Tx) error {
|
|
||||||
closeBucket := tx.Bucket(closedChannelBucket)
|
|
||||||
if closeBucket == nil {
|
|
||||||
return ErrClosedChannelNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
var b bytes.Buffer
|
|
||||||
var err error
|
|
||||||
if err = writeOutpoint(&b, chanID); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
summaryBytes := closeBucket.Get(b.Bytes())
|
|
||||||
if summaryBytes == nil {
|
|
||||||
return ErrClosedChannelNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
summaryReader := bytes.NewReader(summaryBytes)
|
|
||||||
chanSummary, err = deserializeCloseChannelSummary(summaryReader)
|
|
||||||
|
|
||||||
return err
|
|
||||||
}); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return chanSummary, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// FetchClosedChannelForID queries for a channel close summary using the
|
|
||||||
// channel ID of the channel in question.
|
|
||||||
func (d *DB) FetchClosedChannelForID(cid lnwire.ChannelID) (
|
|
||||||
*ChannelCloseSummary, error) {
|
|
||||||
|
|
||||||
var chanSummary *ChannelCloseSummary
|
|
||||||
if err := d.View(func(tx *bbolt.Tx) error {
|
|
||||||
closeBucket := tx.Bucket(closedChannelBucket)
|
|
||||||
if closeBucket == nil {
|
|
||||||
return ErrClosedChannelNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
// The first 30 bytes of the channel ID and outpoint will be
|
|
||||||
// equal.
|
|
||||||
cursor := closeBucket.Cursor()
|
|
||||||
op, c := cursor.Seek(cid[:30])
|
|
||||||
|
|
||||||
// We scan over all possible candidates for this channel ID.
|
|
||||||
for ; op != nil && bytes.Compare(cid[:30], op[:30]) <= 0; op, c = cursor.Next() {
|
|
||||||
var outPoint wire.OutPoint
|
|
||||||
err := readOutpoint(bytes.NewReader(op), &outPoint)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the found outpoint does not correspond to this
|
|
||||||
// channel ID, we continue.
|
|
||||||
if !cid.IsChanPoint(&outPoint) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Deserialize the close summary and return.
|
|
||||||
r := bytes.NewReader(c)
|
|
||||||
chanSummary, err = deserializeCloseChannelSummary(r)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return ErrClosedChannelNotFound
|
|
||||||
}); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return chanSummary, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// MarkChanFullyClosed marks a channel as fully closed within the database. A
|
|
||||||
// channel should be marked as fully closed if the channel was initially
|
|
||||||
// cooperatively closed and it's reached a single confirmation, or after all
|
|
||||||
// the pending funds in a channel that has been forcibly closed have been
|
|
||||||
// swept.
|
|
||||||
func (d *DB) MarkChanFullyClosed(chanPoint *wire.OutPoint) error {
|
|
||||||
return d.Update(func(tx *bbolt.Tx) error {
|
|
||||||
var b bytes.Buffer
|
|
||||||
if err := writeOutpoint(&b, chanPoint); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
chanID := b.Bytes()
|
|
||||||
|
|
||||||
closedChanBucket, err := tx.CreateBucketIfNotExists(
|
|
||||||
closedChannelBucket,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
chanSummaryBytes := closedChanBucket.Get(chanID)
|
|
||||||
if chanSummaryBytes == nil {
|
|
||||||
return fmt.Errorf("no closed channel for "+
|
|
||||||
"chan_point=%v found", chanPoint)
|
|
||||||
}
|
|
||||||
|
|
||||||
chanSummaryReader := bytes.NewReader(chanSummaryBytes)
|
|
||||||
chanSummary, err := deserializeCloseChannelSummary(
|
|
||||||
chanSummaryReader,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
chanSummary.IsPending = false
|
|
||||||
|
|
||||||
var newSummary bytes.Buffer
|
|
||||||
err = serializeChannelCloseSummary(&newSummary, chanSummary)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = closedChanBucket.Put(chanID, newSummary.Bytes())
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now that the channel is closed, we'll check if we have any
|
|
||||||
// other open channels with this peer. If we don't we'll
|
|
||||||
// garbage collect it to ensure we don't establish persistent
|
|
||||||
// connections to peers without open channels.
|
|
||||||
return d.pruneLinkNode(tx, chanSummary.RemotePub)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// pruneLinkNode determines whether we should garbage collect a link node from
|
|
||||||
// the database due to no longer having any open channels with it. If there are
|
|
||||||
// any left, then this acts as a no-op.
|
|
||||||
func (d *DB) pruneLinkNode(tx *bbolt.Tx, remotePub *btcec.PublicKey) error {
|
|
||||||
openChannels, err := d.fetchOpenChannels(tx, remotePub)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to fetch open channels for peer %x: "+
|
|
||||||
"%v", remotePub.SerializeCompressed(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(openChannels) > 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("Pruning link node %x with zero open channels from database",
|
|
||||||
remotePub.SerializeCompressed())
|
|
||||||
|
|
||||||
return d.deleteLinkNode(tx, remotePub)
|
|
||||||
}
|
|
||||||
|
|
||||||
// PruneLinkNodes attempts to prune all link nodes found within the databse with
|
|
||||||
// whom we no longer have any open channels with.
|
|
||||||
func (d *DB) PruneLinkNodes() error {
|
|
||||||
return d.Update(func(tx *bbolt.Tx) error {
|
|
||||||
linkNodes, err := d.fetchAllLinkNodes(tx)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, linkNode := range linkNodes {
|
|
||||||
err := d.pruneLinkNode(tx, linkNode.IdentityPub)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChannelShell is a shell of a channel that is meant to be used for channel
|
|
||||||
// recovery purposes. It contains a minimal OpenChannel instance along with
|
|
||||||
// addresses for that target node.
|
|
||||||
type ChannelShell struct {
|
|
||||||
// NodeAddrs the set of addresses that this node has known to be
|
|
||||||
// reachable at in the past.
|
|
||||||
NodeAddrs []net.Addr
|
|
||||||
|
|
||||||
// Chan is a shell of an OpenChannel, it contains only the items
|
|
||||||
// required to restore the channel on disk.
|
|
||||||
Chan *OpenChannel
|
|
||||||
}
|
|
||||||
|
|
||||||
// RestoreChannelShells is a method that allows the caller to reconstruct the
|
|
||||||
// state of an OpenChannel from the ChannelShell. We'll attempt to write the
|
|
||||||
// new channel to disk, create a LinkNode instance with the passed node
|
|
||||||
// addresses, and finally create an edge within the graph for the channel as
|
|
||||||
// well. This method is idempotent, so repeated calls with the same set of
|
|
||||||
// channel shells won't modify the database after the initial call.
|
|
||||||
func (d *DB) RestoreChannelShells(channelShells ...*ChannelShell) error {
|
|
||||||
chanGraph := d.ChannelGraph()
|
|
||||||
|
|
||||||
// TODO(conner): find way to do this w/o accessing internal members?
|
|
||||||
chanGraph.cacheMu.Lock()
|
|
||||||
defer chanGraph.cacheMu.Unlock()
|
|
||||||
|
|
||||||
var chansRestored []uint64
|
|
||||||
err := d.Update(func(tx *bbolt.Tx) error {
|
|
||||||
for _, channelShell := range channelShells {
|
|
||||||
channel := channelShell.Chan
|
|
||||||
|
|
||||||
// When we make a channel, we mark that the channel has
|
|
||||||
// been restored, this will signal to other sub-systems
|
|
||||||
// to not attempt to use the channel as if it was a
|
|
||||||
// regular one.
|
|
||||||
channel.chanStatus |= ChanStatusRestored
|
|
||||||
|
|
||||||
// First, we'll attempt to create a new open channel
|
|
||||||
// and link node for this channel. If the channel
|
|
||||||
// already exists, then in order to ensure this method
|
|
||||||
// is idempotent, we'll continue to the next step.
|
|
||||||
channel.Db = d
|
|
||||||
err := syncNewChannel(
|
|
||||||
tx, channel, channelShell.NodeAddrs,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Next, we'll create an active edge in the graph
|
|
||||||
// database for this channel in order to restore our
|
|
||||||
// partial view of the network.
|
|
||||||
//
|
|
||||||
// TODO(roasbeef): if we restore *after* the channel
|
|
||||||
// has been closed on chain, then need to inform the
|
|
||||||
// router that it should try and prune these values as
|
|
||||||
// we can detect them
|
|
||||||
edgeInfo := ChannelEdgeInfo{
|
|
||||||
ChannelID: channel.ShortChannelID.ToUint64(),
|
|
||||||
ChainHash: channel.ChainHash,
|
|
||||||
ChannelPoint: channel.FundingOutpoint,
|
|
||||||
Capacity: channel.Capacity,
|
|
||||||
}
|
|
||||||
|
|
||||||
nodes := tx.Bucket(nodeBucket)
|
|
||||||
if nodes == nil {
|
|
||||||
return ErrGraphNotFound
|
|
||||||
}
|
|
||||||
selfNode, err := chanGraph.sourceNode(nodes)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Depending on which pub key is smaller, we'll assign
|
|
||||||
// our roles as "node1" and "node2".
|
|
||||||
chanPeer := channel.IdentityPub.SerializeCompressed()
|
|
||||||
selfIsSmaller := bytes.Compare(
|
|
||||||
selfNode.PubKeyBytes[:], chanPeer,
|
|
||||||
) == -1
|
|
||||||
if selfIsSmaller {
|
|
||||||
copy(edgeInfo.NodeKey1Bytes[:], selfNode.PubKeyBytes[:])
|
|
||||||
copy(edgeInfo.NodeKey2Bytes[:], chanPeer)
|
|
||||||
} else {
|
|
||||||
copy(edgeInfo.NodeKey1Bytes[:], chanPeer)
|
|
||||||
copy(edgeInfo.NodeKey2Bytes[:], selfNode.PubKeyBytes[:])
|
|
||||||
}
|
|
||||||
|
|
||||||
// With the edge info shell constructed, we'll now add
|
|
||||||
// it to the graph.
|
|
||||||
err = chanGraph.addChannelEdge(tx, &edgeInfo)
|
|
||||||
if err != nil && err != ErrEdgeAlreadyExist {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Similarly, we'll construct a channel edge shell and
|
|
||||||
// add that itself to the graph.
|
|
||||||
chanEdge := ChannelEdgePolicy{
|
|
||||||
ChannelID: edgeInfo.ChannelID,
|
|
||||||
LastUpdate: time.Now(),
|
|
||||||
}
|
|
||||||
|
|
||||||
// If their pubkey is larger, then we'll flip the
|
|
||||||
// direction bit to indicate that us, the "second" node
|
|
||||||
// is updating their policy.
|
|
||||||
if !selfIsSmaller {
|
|
||||||
chanEdge.ChannelFlags |= lnwire.ChanUpdateDirection
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = updateEdgePolicy(tx, &chanEdge)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
chansRestored = append(chansRestored, edgeInfo.ChannelID)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, chanid := range chansRestored {
|
|
||||||
chanGraph.rejectCache.remove(chanid)
|
|
||||||
chanGraph.chanCache.remove(chanid)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddrsForNode consults the graph and channel database for all addresses known
|
|
||||||
// to the passed node public key.
|
|
||||||
func (d *DB) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, error) {
|
|
||||||
var (
|
|
||||||
linkNode *LinkNode
|
|
||||||
graphNode LightningNode
|
|
||||||
)
|
|
||||||
|
|
||||||
dbErr := d.View(func(tx *bbolt.Tx) error {
|
|
||||||
var err error
|
|
||||||
|
|
||||||
linkNode, err = fetchLinkNode(tx, nodePub)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// We'll also query the graph for this peer to see if they have
|
|
||||||
// any addresses that we don't currently have stored within the
|
|
||||||
// link node database.
|
|
||||||
nodes := tx.Bucket(nodeBucket)
|
|
||||||
if nodes == nil {
|
|
||||||
return ErrGraphNotFound
|
|
||||||
}
|
|
||||||
compressedPubKey := nodePub.SerializeCompressed()
|
|
||||||
graphNode, err = fetchLightningNode(nodes, compressedPubKey)
|
|
||||||
if err != nil && err != ErrGraphNodeNotFound {
|
|
||||||
// If the node isn't found, then that's OK, as we still
|
|
||||||
// have the link node data.
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if dbErr != nil {
|
|
||||||
return nil, dbErr
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now that we have both sources of addrs for this node, we'll use a
|
|
||||||
// map to de-duplicate any addresses between the two sources, and
|
|
||||||
// produce a final list of the combined addrs.
|
|
||||||
addrs := make(map[string]net.Addr)
|
|
||||||
for _, addr := range linkNode.Addresses {
|
|
||||||
addrs[addr.String()] = addr
|
|
||||||
}
|
|
||||||
for _, addr := range graphNode.Addresses {
|
|
||||||
addrs[addr.String()] = addr
|
|
||||||
}
|
|
||||||
dedupedAddrs := make([]net.Addr, 0, len(addrs))
|
|
||||||
for _, addr := range addrs {
|
|
||||||
dedupedAddrs = append(dedupedAddrs, addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
return dedupedAddrs, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// syncVersions function is used for safe db version synchronization. It
|
// syncVersions function is used for safe db version synchronization. It
|
||||||
// applies migration functions to the current database and recovers the
|
// applies migration functions to the current database and recovers the
|
||||||
// previous state of db if at least one error/panic appeared during migration.
|
// previous state of db if at least one error/panic appeared during migration.
|
||||||
|
@ -1,471 +0,0 @@
|
|||||||
package migration_01_to_11
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io/ioutil"
|
|
||||||
"math"
|
|
||||||
"math/rand"
|
|
||||||
"net"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/btcsuite/btcd/btcec"
|
|
||||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
|
||||||
"github.com/btcsuite/btcd/wire"
|
|
||||||
"github.com/btcsuite/btcutil"
|
|
||||||
"github.com/davecgh/go-spew/spew"
|
|
||||||
"github.com/lightningnetwork/lnd/keychain"
|
|
||||||
"github.com/lightningnetwork/lnd/lnwire"
|
|
||||||
"github.com/lightningnetwork/lnd/shachain"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestOpenWithCreate(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
// First, create a temporary directory to be used for the duration of
|
|
||||||
// this test.
|
|
||||||
tempDirName, err := ioutil.TempDir("", "channeldb")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create temp dir: %v", err)
|
|
||||||
}
|
|
||||||
defer os.RemoveAll(tempDirName)
|
|
||||||
|
|
||||||
// Next, open thereby creating channeldb for the first time.
|
|
||||||
dbPath := filepath.Join(tempDirName, "cdb")
|
|
||||||
cdb, err := Open(dbPath)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create channeldb: %v", err)
|
|
||||||
}
|
|
||||||
if err := cdb.Close(); err != nil {
|
|
||||||
t.Fatalf("unable to close channeldb: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// The path should have been successfully created.
|
|
||||||
if !fileExists(dbPath) {
|
|
||||||
t.Fatalf("channeldb failed to create data directory")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestWipe tests that the database wipe operation completes successfully
|
|
||||||
// and that the buckets are deleted. It also checks that attempts to fetch
|
|
||||||
// information while the buckets are not set return the correct errors.
|
|
||||||
func TestWipe(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
// First, create a temporary directory to be used for the duration of
|
|
||||||
// this test.
|
|
||||||
tempDirName, err := ioutil.TempDir("", "channeldb")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create temp dir: %v", err)
|
|
||||||
}
|
|
||||||
defer os.RemoveAll(tempDirName)
|
|
||||||
|
|
||||||
// Next, open thereby creating channeldb for the first time.
|
|
||||||
dbPath := filepath.Join(tempDirName, "cdb")
|
|
||||||
cdb, err := Open(dbPath)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create channeldb: %v", err)
|
|
||||||
}
|
|
||||||
defer cdb.Close()
|
|
||||||
|
|
||||||
if err := cdb.Wipe(); err != nil {
|
|
||||||
t.Fatalf("unable to wipe channeldb: %v", err)
|
|
||||||
}
|
|
||||||
// Check correct errors are returned
|
|
||||||
_, err = cdb.FetchAllOpenChannels()
|
|
||||||
if err != ErrNoActiveChannels {
|
|
||||||
t.Fatalf("fetching open channels: expected '%v' instead got '%v'",
|
|
||||||
ErrNoActiveChannels, err)
|
|
||||||
}
|
|
||||||
_, err = cdb.FetchClosedChannels(false)
|
|
||||||
if err != ErrNoClosedChannels {
|
|
||||||
t.Fatalf("fetching closed channels: expected '%v' instead got '%v'",
|
|
||||||
ErrNoClosedChannels, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestFetchClosedChannelForID tests that we are able to properly retrieve a
|
|
||||||
// ChannelCloseSummary from the DB given a ChannelID.
|
|
||||||
func TestFetchClosedChannelForID(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
const numChans = 101
|
|
||||||
|
|
||||||
cdb, cleanUp, err := makeTestDB()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test database: %v", err)
|
|
||||||
}
|
|
||||||
defer cleanUp()
|
|
||||||
|
|
||||||
// Create the test channel state, that we will mutate the index of the
|
|
||||||
// funding point.
|
|
||||||
state, err := createTestChannelState(cdb)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create channel state: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now run through the number of channels, and modify the outpoint index
|
|
||||||
// to create new channel IDs.
|
|
||||||
for i := uint32(0); i < numChans; i++ {
|
|
||||||
// Save the open channel to disk.
|
|
||||||
state.FundingOutpoint.Index = i
|
|
||||||
|
|
||||||
addr := &net.TCPAddr{
|
|
||||||
IP: net.ParseIP("127.0.0.1"),
|
|
||||||
Port: 18556,
|
|
||||||
}
|
|
||||||
if err := state.SyncPending(addr, 101); err != nil {
|
|
||||||
t.Fatalf("unable to save and serialize channel "+
|
|
||||||
"state: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close the channel. To make sure we retrieve the correct
|
|
||||||
// summary later, we make them differ in the SettledBalance.
|
|
||||||
closeSummary := &ChannelCloseSummary{
|
|
||||||
ChanPoint: state.FundingOutpoint,
|
|
||||||
RemotePub: state.IdentityPub,
|
|
||||||
SettledBalance: btcutil.Amount(500 + i),
|
|
||||||
}
|
|
||||||
if err := state.CloseChannel(closeSummary); err != nil {
|
|
||||||
t.Fatalf("unable to close channel: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now run though them all again and make sure we are able to retrieve
|
|
||||||
// summaries from the DB.
|
|
||||||
for i := uint32(0); i < numChans; i++ {
|
|
||||||
state.FundingOutpoint.Index = i
|
|
||||||
|
|
||||||
// We calculate the ChannelID and use it to fetch the summary.
|
|
||||||
cid := lnwire.NewChanIDFromOutPoint(&state.FundingOutpoint)
|
|
||||||
fetchedSummary, err := cdb.FetchClosedChannelForID(cid)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to fetch close summary: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make sure we retrieved the correct one by checking the
|
|
||||||
// SettledBalance.
|
|
||||||
if fetchedSummary.SettledBalance != btcutil.Amount(500+i) {
|
|
||||||
t.Fatalf("summaries don't match: expected %v got %v",
|
|
||||||
btcutil.Amount(500+i),
|
|
||||||
fetchedSummary.SettledBalance)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// As a final test we make sure that we get ErrClosedChannelNotFound
|
|
||||||
// for a ChannelID we didn't add to the DB.
|
|
||||||
state.FundingOutpoint.Index++
|
|
||||||
cid := lnwire.NewChanIDFromOutPoint(&state.FundingOutpoint)
|
|
||||||
_, err = cdb.FetchClosedChannelForID(cid)
|
|
||||||
if err != ErrClosedChannelNotFound {
|
|
||||||
t.Fatalf("expected ErrClosedChannelNotFound, instead got: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestAddrsForNode tests the we're able to properly obtain all the addresses
|
|
||||||
// for a target node.
|
|
||||||
func TestAddrsForNode(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
cdb, cleanUp, err := makeTestDB()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test database: %v", err)
|
|
||||||
}
|
|
||||||
defer cleanUp()
|
|
||||||
|
|
||||||
graph := cdb.ChannelGraph()
|
|
||||||
|
|
||||||
// We'll make a test vertex to insert into the database, as the source
|
|
||||||
// node, but this node will only have half the number of addresses it
|
|
||||||
// usually does.
|
|
||||||
testNode, err := createTestVertex(cdb)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create test node: %v", err)
|
|
||||||
}
|
|
||||||
testNode.Addresses = []net.Addr{testAddr}
|
|
||||||
if err := graph.SetSourceNode(testNode); err != nil {
|
|
||||||
t.Fatalf("unable to set source node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Next, we'll make a link node with the same pubkey, but with an
|
|
||||||
// additional address.
|
|
||||||
nodePub, err := testNode.PubKey()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to recv node pub: %v", err)
|
|
||||||
}
|
|
||||||
linkNode := cdb.NewLinkNode(
|
|
||||||
wire.MainNet, nodePub, anotherAddr,
|
|
||||||
)
|
|
||||||
if err := linkNode.Sync(); err != nil {
|
|
||||||
t.Fatalf("unable to sync link node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now that we've created a link node, as well as a vertex for the
|
|
||||||
// node, we'll query for all its addresses.
|
|
||||||
nodeAddrs, err := cdb.AddrsForNode(nodePub)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to obtain node addrs: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
expectedAddrs := make(map[string]struct{})
|
|
||||||
expectedAddrs[testAddr.String()] = struct{}{}
|
|
||||||
expectedAddrs[anotherAddr.String()] = struct{}{}
|
|
||||||
|
|
||||||
// Finally, ensure that all the expected addresses are found.
|
|
||||||
if len(nodeAddrs) != len(expectedAddrs) {
|
|
||||||
t.Fatalf("expected %v addrs, got %v",
|
|
||||||
len(expectedAddrs), len(nodeAddrs))
|
|
||||||
}
|
|
||||||
for _, addr := range nodeAddrs {
|
|
||||||
if _, ok := expectedAddrs[addr.String()]; !ok {
|
|
||||||
t.Fatalf("unexpected addr: %v", addr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestFetchChannel tests that we're able to fetch an arbitrary channel from
|
|
||||||
// disk.
|
|
||||||
func TestFetchChannel(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
cdb, cleanUp, err := makeTestDB()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test database: %v", err)
|
|
||||||
}
|
|
||||||
defer cleanUp()
|
|
||||||
|
|
||||||
// Create the test channel state that we'll sync to the database
|
|
||||||
// shortly.
|
|
||||||
channelState, err := createTestChannelState(cdb)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create channel state: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Mark the channel as pending, then immediately mark it as open to it
|
|
||||||
// can be fully visible.
|
|
||||||
addr := &net.TCPAddr{
|
|
||||||
IP: net.ParseIP("127.0.0.1"),
|
|
||||||
Port: 18555,
|
|
||||||
}
|
|
||||||
if err := channelState.SyncPending(addr, 9); err != nil {
|
|
||||||
t.Fatalf("unable to save and serialize channel state: %v", err)
|
|
||||||
}
|
|
||||||
err = channelState.MarkAsOpen(lnwire.NewShortChanIDFromInt(99))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to mark channel open: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Next, attempt to fetch the channel by its chan point.
|
|
||||||
dbChannel, err := cdb.FetchChannel(channelState.FundingOutpoint)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to fetch channel: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// The decoded channel state should be identical to what we stored
|
|
||||||
// above.
|
|
||||||
if !reflect.DeepEqual(channelState, dbChannel) {
|
|
||||||
t.Fatalf("channel state doesn't match:: %v vs %v",
|
|
||||||
spew.Sdump(channelState), spew.Sdump(dbChannel))
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we attempt to query for a non-exist ante channel, then we should
|
|
||||||
// get an error.
|
|
||||||
channelState2, err := createTestChannelState(cdb)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create channel state: %v", err)
|
|
||||||
}
|
|
||||||
channelState2.FundingOutpoint.Index ^= 1
|
|
||||||
|
|
||||||
_, err = cdb.FetchChannel(channelState2.FundingOutpoint)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatalf("expected query to fail")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func genRandomChannelShell() (*ChannelShell, error) {
|
|
||||||
var testPriv [32]byte
|
|
||||||
if _, err := rand.Read(testPriv[:]); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
_, pub := btcec.PrivKeyFromBytes(btcec.S256(), testPriv[:])
|
|
||||||
|
|
||||||
var chanPoint wire.OutPoint
|
|
||||||
if _, err := rand.Read(chanPoint.Hash[:]); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
pub.Curve = nil
|
|
||||||
|
|
||||||
chanPoint.Index = uint32(rand.Intn(math.MaxUint16))
|
|
||||||
|
|
||||||
chanStatus := ChanStatusDefault | ChanStatusRestored
|
|
||||||
|
|
||||||
var shaChainPriv [32]byte
|
|
||||||
if _, err := rand.Read(testPriv[:]); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
revRoot, err := chainhash.NewHash(shaChainPriv[:])
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
shaChainProducer := shachain.NewRevocationProducer(*revRoot)
|
|
||||||
|
|
||||||
return &ChannelShell{
|
|
||||||
NodeAddrs: []net.Addr{&net.TCPAddr{
|
|
||||||
IP: net.ParseIP("127.0.0.1"),
|
|
||||||
Port: 18555,
|
|
||||||
}},
|
|
||||||
Chan: &OpenChannel{
|
|
||||||
chanStatus: chanStatus,
|
|
||||||
ChainHash: rev,
|
|
||||||
FundingOutpoint: chanPoint,
|
|
||||||
ShortChannelID: lnwire.NewShortChanIDFromInt(
|
|
||||||
uint64(rand.Int63()),
|
|
||||||
),
|
|
||||||
IdentityPub: pub,
|
|
||||||
LocalChanCfg: ChannelConfig{
|
|
||||||
ChannelConstraints: ChannelConstraints{
|
|
||||||
CsvDelay: uint16(rand.Int63()),
|
|
||||||
},
|
|
||||||
PaymentBasePoint: keychain.KeyDescriptor{
|
|
||||||
KeyLocator: keychain.KeyLocator{
|
|
||||||
Family: keychain.KeyFamily(rand.Int63()),
|
|
||||||
Index: uint32(rand.Int63()),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
RemoteCurrentRevocation: pub,
|
|
||||||
IsPending: false,
|
|
||||||
RevocationStore: shachain.NewRevocationStore(),
|
|
||||||
RevocationProducer: shaChainProducer,
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestRestoreChannelShells tests that we're able to insert a partially channel
|
|
||||||
// populated to disk. This is useful for channel recovery purposes. We should
|
|
||||||
// find the new channel shell on disk, and also the db should be populated with
|
|
||||||
// an edge for that channel.
|
|
||||||
func TestRestoreChannelShells(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
cdb, cleanUp, err := makeTestDB()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test database: %v", err)
|
|
||||||
}
|
|
||||||
defer cleanUp()
|
|
||||||
|
|
||||||
// First, we'll make our channel shell, it will only have the minimal
|
|
||||||
// amount of information required for us to initiate the data loss
|
|
||||||
// protection feature.
|
|
||||||
channelShell, err := genRandomChannelShell()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to gen channel shell: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
graph := cdb.ChannelGraph()
|
|
||||||
|
|
||||||
// Before we can restore the channel, we'll need to make a source node
|
|
||||||
// in the graph as the channel edge we create will need to have a
|
|
||||||
// origin.
|
|
||||||
testNode, err := createTestVertex(cdb)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create test node: %v", err)
|
|
||||||
}
|
|
||||||
if err := graph.SetSourceNode(testNode); err != nil {
|
|
||||||
t.Fatalf("unable to set source node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// With the channel shell constructed, we'll now insert it into the
|
|
||||||
// database with the restoration method.
|
|
||||||
if err := cdb.RestoreChannelShells(channelShell); err != nil {
|
|
||||||
t.Fatalf("unable to restore channel shell: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now that the channel has been inserted, we'll attempt to query for
|
|
||||||
// it to ensure we can properly locate it via various means.
|
|
||||||
//
|
|
||||||
// First, we'll attempt to query for all channels that we have with the
|
|
||||||
// node public key that was restored.
|
|
||||||
nodeChans, err := cdb.FetchOpenChannels(channelShell.Chan.IdentityPub)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable find channel: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// We should now find a single channel from the database.
|
|
||||||
if len(nodeChans) != 1 {
|
|
||||||
t.Fatalf("unable to find restored channel by node "+
|
|
||||||
"pubkey: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure that it isn't possible to modify the commitment state machine
|
|
||||||
// of this restored channel.
|
|
||||||
channel := nodeChans[0]
|
|
||||||
err = channel.UpdateCommitment(nil)
|
|
||||||
if err != ErrNoRestoredChannelMutation {
|
|
||||||
t.Fatalf("able to mutate restored channel")
|
|
||||||
}
|
|
||||||
err = channel.AppendRemoteCommitChain(nil)
|
|
||||||
if err != ErrNoRestoredChannelMutation {
|
|
||||||
t.Fatalf("able to mutate restored channel")
|
|
||||||
}
|
|
||||||
err = channel.AdvanceCommitChainTail(nil)
|
|
||||||
if err != ErrNoRestoredChannelMutation {
|
|
||||||
t.Fatalf("able to mutate restored channel")
|
|
||||||
}
|
|
||||||
|
|
||||||
// That single channel should have the proper channel point, and also
|
|
||||||
// the expected set of flags to indicate that it was a restored
|
|
||||||
// channel.
|
|
||||||
if nodeChans[0].FundingOutpoint != channelShell.Chan.FundingOutpoint {
|
|
||||||
t.Fatalf("wrong funding outpoint: expected %v, got %v",
|
|
||||||
nodeChans[0].FundingOutpoint,
|
|
||||||
channelShell.Chan.FundingOutpoint)
|
|
||||||
}
|
|
||||||
if !nodeChans[0].HasChanStatus(ChanStatusRestored) {
|
|
||||||
t.Fatalf("node has wrong status flags: %v",
|
|
||||||
nodeChans[0].chanStatus)
|
|
||||||
}
|
|
||||||
|
|
||||||
// We should also be able to find the channel if we query for it
|
|
||||||
// directly.
|
|
||||||
_, err = cdb.FetchChannel(channelShell.Chan.FundingOutpoint)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to fetch channel: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// We should also be able to find the link node that was inserted by
|
|
||||||
// its public key.
|
|
||||||
linkNode, err := cdb.FetchLinkNode(channelShell.Chan.IdentityPub)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to fetch link node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// The node should have the same address, as specified in the channel
|
|
||||||
// shell.
|
|
||||||
if reflect.DeepEqual(linkNode.Addresses, channelShell.NodeAddrs) {
|
|
||||||
t.Fatalf("addr mismach: expected %v, got %v",
|
|
||||||
linkNode.Addresses, channelShell.NodeAddrs)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Finally, we'll ensure that the edge for the channel was properly
|
|
||||||
// inserted.
|
|
||||||
chanInfos, err := graph.FetchChanInfos(
|
|
||||||
[]uint64{channelShell.Chan.ShortChannelID.ToUint64()},
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to find edges: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(chanInfos) != 1 {
|
|
||||||
t.Fatalf("wrong amount of chan infos: expected %v got %v",
|
|
||||||
len(chanInfos), 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// We should only find a single edge.
|
|
||||||
if chanInfos[0].Policy1 != nil && chanInfos[0].Policy2 != nil {
|
|
||||||
t.Fatalf("only a single edge should be inserted: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
@ -1 +0,0 @@
|
|||||||
package migration_01_to_11
|
|
@ -1,55 +1,23 @@
|
|||||||
package migration_01_to_11
|
package migration_01_to_11
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// ErrNoChanDBExists is returned when a channel bucket hasn't been
|
|
||||||
// created.
|
|
||||||
ErrNoChanDBExists = fmt.Errorf("channel db has not yet been created")
|
|
||||||
|
|
||||||
// ErrDBReversion is returned when detecting an attempt to revert to a
|
// ErrDBReversion is returned when detecting an attempt to revert to a
|
||||||
// prior database version.
|
// prior database version.
|
||||||
ErrDBReversion = fmt.Errorf("channel db cannot revert to prior version")
|
ErrDBReversion = fmt.Errorf("channel db cannot revert to prior version")
|
||||||
|
|
||||||
// ErrLinkNodesNotFound is returned when node info bucket hasn't been
|
|
||||||
// created.
|
|
||||||
ErrLinkNodesNotFound = fmt.Errorf("no link nodes exist")
|
|
||||||
|
|
||||||
// ErrNoActiveChannels is returned when there is no active (open)
|
|
||||||
// channels within the database.
|
|
||||||
ErrNoActiveChannels = fmt.Errorf("no active channels exist")
|
|
||||||
|
|
||||||
// ErrNoPastDeltas is returned when the channel delta bucket hasn't been
|
|
||||||
// created.
|
|
||||||
ErrNoPastDeltas = fmt.Errorf("channel has no recorded deltas")
|
|
||||||
|
|
||||||
// ErrInvoiceNotFound is returned when a targeted invoice can't be
|
|
||||||
// found.
|
|
||||||
ErrInvoiceNotFound = fmt.Errorf("unable to locate invoice")
|
|
||||||
|
|
||||||
// ErrNoInvoicesCreated is returned when we don't have invoices in
|
// ErrNoInvoicesCreated is returned when we don't have invoices in
|
||||||
// our database to return.
|
// our database to return.
|
||||||
ErrNoInvoicesCreated = fmt.Errorf("there are no existing invoices")
|
ErrNoInvoicesCreated = fmt.Errorf("there are no existing invoices")
|
||||||
|
|
||||||
// ErrDuplicateInvoice is returned when an invoice with the target
|
|
||||||
// payment hash already exists.
|
|
||||||
ErrDuplicateInvoice = fmt.Errorf("invoice with payment hash already exists")
|
|
||||||
|
|
||||||
// ErrNoPaymentsCreated is returned when bucket of payments hasn't been
|
// ErrNoPaymentsCreated is returned when bucket of payments hasn't been
|
||||||
// created.
|
// created.
|
||||||
ErrNoPaymentsCreated = fmt.Errorf("there are no existing payments")
|
ErrNoPaymentsCreated = fmt.Errorf("there are no existing payments")
|
||||||
|
|
||||||
// ErrNodeNotFound is returned when node bucket exists, but node with
|
|
||||||
// specific identity can't be found.
|
|
||||||
ErrNodeNotFound = fmt.Errorf("link node with target identity not found")
|
|
||||||
|
|
||||||
// ErrChannelNotFound is returned when we attempt to locate a channel
|
|
||||||
// for a specific chain, but it is not found.
|
|
||||||
ErrChannelNotFound = fmt.Errorf("channel not found")
|
|
||||||
|
|
||||||
// ErrMetaNotFound is returned when meta bucket hasn't been
|
// ErrMetaNotFound is returned when meta bucket hasn't been
|
||||||
// created.
|
// created.
|
||||||
ErrMetaNotFound = fmt.Errorf("unable to locate meta information")
|
ErrMetaNotFound = fmt.Errorf("unable to locate meta information")
|
||||||
@ -58,22 +26,11 @@ var (
|
|||||||
// graph doesn't exist.
|
// graph doesn't exist.
|
||||||
ErrGraphNotFound = fmt.Errorf("graph bucket not initialized")
|
ErrGraphNotFound = fmt.Errorf("graph bucket not initialized")
|
||||||
|
|
||||||
// ErrGraphNeverPruned is returned when graph was never pruned.
|
|
||||||
ErrGraphNeverPruned = fmt.Errorf("graph never pruned")
|
|
||||||
|
|
||||||
// ErrSourceNodeNotSet is returned if the source node of the graph
|
// ErrSourceNodeNotSet is returned if the source node of the graph
|
||||||
// hasn't been added The source node is the center node within a
|
// hasn't been added The source node is the center node within a
|
||||||
// star-graph.
|
// star-graph.
|
||||||
ErrSourceNodeNotSet = fmt.Errorf("source node does not exist")
|
ErrSourceNodeNotSet = fmt.Errorf("source node does not exist")
|
||||||
|
|
||||||
// ErrGraphNodesNotFound is returned in case none of the nodes has
|
|
||||||
// been added in graph node bucket.
|
|
||||||
ErrGraphNodesNotFound = fmt.Errorf("no graph nodes exist")
|
|
||||||
|
|
||||||
// ErrGraphNoEdgesFound is returned in case of none of the channel/edges
|
|
||||||
// has been added in graph edge bucket.
|
|
||||||
ErrGraphNoEdgesFound = fmt.Errorf("no graph edges exist")
|
|
||||||
|
|
||||||
// ErrGraphNodeNotFound is returned when we're unable to find the target
|
// ErrGraphNodeNotFound is returned when we're unable to find the target
|
||||||
// node.
|
// node.
|
||||||
ErrGraphNodeNotFound = fmt.Errorf("unable to find node")
|
ErrGraphNodeNotFound = fmt.Errorf("unable to find node")
|
||||||
@ -82,17 +39,6 @@ var (
|
|||||||
// can't be found.
|
// can't be found.
|
||||||
ErrEdgeNotFound = fmt.Errorf("edge not found")
|
ErrEdgeNotFound = fmt.Errorf("edge not found")
|
||||||
|
|
||||||
// ErrZombieEdge is an error returned when we attempt to look up an edge
|
|
||||||
// but it is marked as a zombie within the zombie index.
|
|
||||||
ErrZombieEdge = errors.New("edge marked as zombie")
|
|
||||||
|
|
||||||
// ErrEdgeAlreadyExist is returned when edge with specific
|
|
||||||
// channel id can't be added because it already exist.
|
|
||||||
ErrEdgeAlreadyExist = fmt.Errorf("edge already exist")
|
|
||||||
|
|
||||||
// ErrNodeAliasNotFound is returned when alias for node can't be found.
|
|
||||||
ErrNodeAliasNotFound = fmt.Errorf("alias for node not found")
|
|
||||||
|
|
||||||
// ErrUnknownAddressType is returned when a node's addressType is not
|
// ErrUnknownAddressType is returned when a node's addressType is not
|
||||||
// an expected value.
|
// an expected value.
|
||||||
ErrUnknownAddressType = fmt.Errorf("address type cannot be resolved")
|
ErrUnknownAddressType = fmt.Errorf("address type cannot be resolved")
|
||||||
@ -101,20 +47,11 @@ var (
|
|||||||
// channels it has closed, but it hasn't yet closed any channels.
|
// channels it has closed, but it hasn't yet closed any channels.
|
||||||
ErrNoClosedChannels = fmt.Errorf("no channel have been closed yet")
|
ErrNoClosedChannels = fmt.Errorf("no channel have been closed yet")
|
||||||
|
|
||||||
// ErrNoForwardingEvents is returned in the case that a query fails due
|
|
||||||
// to the log not having any recorded events.
|
|
||||||
ErrNoForwardingEvents = fmt.Errorf("no recorded forwarding events")
|
|
||||||
|
|
||||||
// ErrEdgePolicyOptionalFieldNotFound is an error returned if a channel
|
// ErrEdgePolicyOptionalFieldNotFound is an error returned if a channel
|
||||||
// policy field is not found in the db even though its message flags
|
// policy field is not found in the db even though its message flags
|
||||||
// indicate it should be.
|
// indicate it should be.
|
||||||
ErrEdgePolicyOptionalFieldNotFound = fmt.Errorf("optional field not " +
|
ErrEdgePolicyOptionalFieldNotFound = fmt.Errorf("optional field not " +
|
||||||
"present")
|
"present")
|
||||||
|
|
||||||
// ErrChanAlreadyExists is return when the caller attempts to create a
|
|
||||||
// channel with a channel point that is already present in the
|
|
||||||
// database.
|
|
||||||
ErrChanAlreadyExists = fmt.Errorf("channel already exists")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrTooManyExtraOpaqueBytes creates an error which should be returned if the
|
// ErrTooManyExtraOpaqueBytes creates an error which should be returned if the
|
||||||
|
@ -1 +0,0 @@
|
|||||||
package migration_01_to_11
|
|
@ -1,274 +0,0 @@
|
|||||||
package migration_01_to_11
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"io"
|
|
||||||
"sort"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/coreos/bbolt"
|
|
||||||
"github.com/lightningnetwork/lnd/lnwire"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
// forwardingLogBucket is the bucket that we'll use to store the
|
|
||||||
// forwarding log. The forwarding log contains a time series database
|
|
||||||
// of the forwarding history of a lightning daemon. Each key within the
|
|
||||||
// bucket is a timestamp (in nano seconds since the unix epoch), and
|
|
||||||
// the value a slice of a forwarding event for that timestamp.
|
|
||||||
forwardingLogBucket = []byte("circuit-fwd-log")
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// forwardingEventSize is the size of a forwarding event. The breakdown
|
|
||||||
// is as follows:
|
|
||||||
//
|
|
||||||
// * 8 byte incoming chan ID || 8 byte outgoing chan ID || 8 byte value in
|
|
||||||
// || 8 byte value out
|
|
||||||
//
|
|
||||||
// From the value in and value out, callers can easily compute the
|
|
||||||
// total fee extract from a forwarding event.
|
|
||||||
forwardingEventSize = 32
|
|
||||||
|
|
||||||
// MaxResponseEvents is the max number of forwarding events that will
|
|
||||||
// be returned by a single query response. This size was selected to
|
|
||||||
// safely remain under gRPC's 4MiB message size response limit. As each
|
|
||||||
// full forwarding event (including the timestamp) is 40 bytes, we can
|
|
||||||
// safely return 50k entries in a single response.
|
|
||||||
MaxResponseEvents = 50000
|
|
||||||
)
|
|
||||||
|
|
||||||
// ForwardingLog returns an instance of the ForwardingLog object backed by the
|
|
||||||
// target database instance.
|
|
||||||
func (d *DB) ForwardingLog() *ForwardingLog {
|
|
||||||
return &ForwardingLog{
|
|
||||||
db: d,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ForwardingLog is a time series database that logs the fulfilment of payment
|
|
||||||
// circuits by a lightning network daemon. The log contains a series of
|
|
||||||
// forwarding events which map a timestamp to a forwarding event. A forwarding
|
|
||||||
// event describes which channels were used to create+settle a circuit, and the
|
|
||||||
// amount involved. Subtracting the outgoing amount from the incoming amount
|
|
||||||
// reveals the fee charged for the forwarding service.
|
|
||||||
type ForwardingLog struct {
|
|
||||||
db *DB
|
|
||||||
}
|
|
||||||
|
|
||||||
// ForwardingEvent is an event in the forwarding log's time series. Each
|
|
||||||
// forwarding event logs the creation and tear-down of a payment circuit. A
|
|
||||||
// circuit is created once an incoming HTLC has been fully forwarded, and
|
|
||||||
// destroyed once the payment has been settled.
|
|
||||||
type ForwardingEvent struct {
|
|
||||||
// Timestamp is the settlement time of this payment circuit.
|
|
||||||
Timestamp time.Time
|
|
||||||
|
|
||||||
// IncomingChanID is the incoming channel ID of the payment circuit.
|
|
||||||
IncomingChanID lnwire.ShortChannelID
|
|
||||||
|
|
||||||
// OutgoingChanID is the outgoing channel ID of the payment circuit.
|
|
||||||
OutgoingChanID lnwire.ShortChannelID
|
|
||||||
|
|
||||||
// AmtIn is the amount of the incoming HTLC. Subtracting this from the
|
|
||||||
// outgoing amount gives the total fees of this payment circuit.
|
|
||||||
AmtIn lnwire.MilliSatoshi
|
|
||||||
|
|
||||||
// AmtOut is the amount of the outgoing HTLC. Subtracting the incoming
|
|
||||||
// amount from this gives the total fees for this payment circuit.
|
|
||||||
AmtOut lnwire.MilliSatoshi
|
|
||||||
}
|
|
||||||
|
|
||||||
// encodeForwardingEvent writes out the target forwarding event to the passed
|
|
||||||
// io.Writer, using the expected DB format. Note that the timestamp isn't
|
|
||||||
// serialized as this will be the key value within the bucket.
|
|
||||||
func encodeForwardingEvent(w io.Writer, f *ForwardingEvent) error {
|
|
||||||
return WriteElements(
|
|
||||||
w, f.IncomingChanID, f.OutgoingChanID, f.AmtIn, f.AmtOut,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// decodeForwardingEvent attempts to decode the raw bytes of a serialized
|
|
||||||
// forwarding event into the target ForwardingEvent. Note that the timestamp
|
|
||||||
// won't be decoded, as the caller is expected to set this due to the bucket
|
|
||||||
// structure of the forwarding log.
|
|
||||||
func decodeForwardingEvent(r io.Reader, f *ForwardingEvent) error {
|
|
||||||
return ReadElements(
|
|
||||||
r, &f.IncomingChanID, &f.OutgoingChanID, &f.AmtIn, &f.AmtOut,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddForwardingEvents adds a series of forwarding events to the database.
|
|
||||||
// Before inserting, the set of events will be sorted according to their
|
|
||||||
// timestamp. This ensures that all writes to disk are sequential.
|
|
||||||
func (f *ForwardingLog) AddForwardingEvents(events []ForwardingEvent) error {
|
|
||||||
// Before we create the database transaction, we'll ensure that the set
|
|
||||||
// of forwarding events are properly sorted according to their
|
|
||||||
// timestamp.
|
|
||||||
sort.Slice(events, func(i, j int) bool {
|
|
||||||
return events[i].Timestamp.Before(events[j].Timestamp)
|
|
||||||
})
|
|
||||||
|
|
||||||
var timestamp [8]byte
|
|
||||||
|
|
||||||
return f.db.Batch(func(tx *bbolt.Tx) error {
|
|
||||||
// First, we'll fetch the bucket that stores our time series
|
|
||||||
// log.
|
|
||||||
logBucket, err := tx.CreateBucketIfNotExists(
|
|
||||||
forwardingLogBucket,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// With the bucket obtained, we can now begin to write out the
|
|
||||||
// series of events.
|
|
||||||
for _, event := range events {
|
|
||||||
var eventBytes [forwardingEventSize]byte
|
|
||||||
eventBuf := bytes.NewBuffer(eventBytes[0:0:forwardingEventSize])
|
|
||||||
|
|
||||||
// First, we'll serialize this timestamp into our
|
|
||||||
// timestamp buffer.
|
|
||||||
byteOrder.PutUint64(
|
|
||||||
timestamp[:], uint64(event.Timestamp.UnixNano()),
|
|
||||||
)
|
|
||||||
|
|
||||||
// With the key encoded, we'll then encode the event
|
|
||||||
// into our buffer, then write it out to disk.
|
|
||||||
err := encodeForwardingEvent(eventBuf, &event)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = logBucket.Put(timestamp[:], eventBuf.Bytes())
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// ForwardingEventQuery represents a query to the forwarding log payment
|
|
||||||
// circuit time series database. The query allows a caller to retrieve all
|
|
||||||
// records for a particular time slice, offset in that time slice, limiting the
|
|
||||||
// total number of responses returned.
|
|
||||||
type ForwardingEventQuery struct {
|
|
||||||
// StartTime is the start time of the time slice.
|
|
||||||
StartTime time.Time
|
|
||||||
|
|
||||||
// EndTime is the end time of the time slice.
|
|
||||||
EndTime time.Time
|
|
||||||
|
|
||||||
// IndexOffset is the offset within the time slice to start at. This
|
|
||||||
// can be used to start the response at a particular record.
|
|
||||||
IndexOffset uint32
|
|
||||||
|
|
||||||
// NumMaxEvents is the max number of events to return.
|
|
||||||
NumMaxEvents uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
// ForwardingLogTimeSlice is the response to a forwarding query. It includes
|
|
||||||
// the original query, the set events that match the query, and an integer
|
|
||||||
// which represents the offset index of the last item in the set of retuned
|
|
||||||
// events. This integer allows callers to resume their query using this offset
|
|
||||||
// in the event that the query's response exceeds the max number of returnable
|
|
||||||
// events.
|
|
||||||
type ForwardingLogTimeSlice struct {
|
|
||||||
ForwardingEventQuery
|
|
||||||
|
|
||||||
// ForwardingEvents is the set of events in our time series that answer
|
|
||||||
// the query embedded above.
|
|
||||||
ForwardingEvents []ForwardingEvent
|
|
||||||
|
|
||||||
// LastIndexOffset is the index of the last element in the set of
|
|
||||||
// returned ForwardingEvents above. Callers can use this to resume
|
|
||||||
// their query in the event that the time slice has too many events to
|
|
||||||
// fit into a single response.
|
|
||||||
LastIndexOffset uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
// Query allows a caller to query the forwarding event time series for a
|
|
||||||
// particular time slice. The caller can control the precise time as well as
|
|
||||||
// the number of events to be returned.
|
|
||||||
//
|
|
||||||
// TODO(roasbeef): rename?
|
|
||||||
func (f *ForwardingLog) Query(q ForwardingEventQuery) (ForwardingLogTimeSlice, error) {
|
|
||||||
resp := ForwardingLogTimeSlice{
|
|
||||||
ForwardingEventQuery: q,
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the user provided an index offset, then we'll not know how many
|
|
||||||
// records we need to skip. We'll also keep track of the record offset
|
|
||||||
// as that's part of the final return value.
|
|
||||||
recordsToSkip := q.IndexOffset
|
|
||||||
recordOffset := q.IndexOffset
|
|
||||||
|
|
||||||
err := f.db.View(func(tx *bbolt.Tx) error {
|
|
||||||
// If the bucket wasn't found, then there aren't any events to
|
|
||||||
// be returned.
|
|
||||||
logBucket := tx.Bucket(forwardingLogBucket)
|
|
||||||
if logBucket == nil {
|
|
||||||
return ErrNoForwardingEvents
|
|
||||||
}
|
|
||||||
|
|
||||||
// We'll be using a cursor to seek into the database, so we'll
|
|
||||||
// populate byte slices that represent the start of the key
|
|
||||||
// space we're interested in, and the end.
|
|
||||||
var startTime, endTime [8]byte
|
|
||||||
byteOrder.PutUint64(startTime[:], uint64(q.StartTime.UnixNano()))
|
|
||||||
byteOrder.PutUint64(endTime[:], uint64(q.EndTime.UnixNano()))
|
|
||||||
|
|
||||||
// If we know that a set of log events exists, then we'll begin
|
|
||||||
// our seek through the log in order to satisfy the query.
|
|
||||||
// We'll continue until either we reach the end of the range,
|
|
||||||
// or reach our max number of events.
|
|
||||||
logCursor := logBucket.Cursor()
|
|
||||||
timestamp, events := logCursor.Seek(startTime[:])
|
|
||||||
for ; timestamp != nil && bytes.Compare(timestamp, endTime[:]) <= 0; timestamp, events = logCursor.Next() {
|
|
||||||
// If our current return payload exceeds the max number
|
|
||||||
// of events, then we'll exit now.
|
|
||||||
if uint32(len(resp.ForwardingEvents)) >= q.NumMaxEvents {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we're not yet past the user defined offset, then
|
|
||||||
// we'll continue to seek forward.
|
|
||||||
if recordsToSkip > 0 {
|
|
||||||
recordsToSkip--
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
currentTime := time.Unix(
|
|
||||||
0, int64(byteOrder.Uint64(timestamp)),
|
|
||||||
)
|
|
||||||
|
|
||||||
// At this point, we've skipped enough records to start
|
|
||||||
// to collate our query. For each record, we'll
|
|
||||||
// increment the final record offset so the querier can
|
|
||||||
// utilize pagination to seek further.
|
|
||||||
readBuf := bytes.NewReader(events)
|
|
||||||
for readBuf.Len() != 0 {
|
|
||||||
var event ForwardingEvent
|
|
||||||
err := decodeForwardingEvent(readBuf, &event)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
event.Timestamp = currentTime
|
|
||||||
resp.ForwardingEvents = append(resp.ForwardingEvents, event)
|
|
||||||
|
|
||||||
recordOffset++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil && err != ErrNoForwardingEvents {
|
|
||||||
return ForwardingLogTimeSlice{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
resp.LastIndexOffset = recordOffset
|
|
||||||
|
|
||||||
return resp, nil
|
|
||||||
}
|
|
@ -1,265 +0,0 @@
|
|||||||
package migration_01_to_11
|
|
||||||
|
|
||||||
import (
|
|
||||||
"math/rand"
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/davecgh/go-spew/spew"
|
|
||||||
"github.com/lightningnetwork/lnd/lnwire"
|
|
||||||
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TestForwardingLogBasicStorageAndQuery tests that we're able to store and
|
|
||||||
// then query for items that have previously been added to the event log.
|
|
||||||
func TestForwardingLogBasicStorageAndQuery(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
// First, we'll set up a test database, and use that to instantiate the
|
|
||||||
// forwarding event log that we'll be using for the duration of the
|
|
||||||
// test.
|
|
||||||
db, cleanUp, err := makeTestDB()
|
|
||||||
defer cleanUp()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test db: %v", err)
|
|
||||||
}
|
|
||||||
log := ForwardingLog{
|
|
||||||
db: db,
|
|
||||||
}
|
|
||||||
|
|
||||||
initialTime := time.Unix(1234, 0)
|
|
||||||
timestamp := time.Unix(1234, 0)
|
|
||||||
|
|
||||||
// We'll create 100 random events, which each event being spaced 10
|
|
||||||
// minutes after the prior event.
|
|
||||||
numEvents := 100
|
|
||||||
events := make([]ForwardingEvent, numEvents)
|
|
||||||
for i := 0; i < numEvents; i++ {
|
|
||||||
events[i] = ForwardingEvent{
|
|
||||||
Timestamp: timestamp,
|
|
||||||
IncomingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())),
|
|
||||||
OutgoingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())),
|
|
||||||
AmtIn: lnwire.MilliSatoshi(rand.Int63()),
|
|
||||||
AmtOut: lnwire.MilliSatoshi(rand.Int63()),
|
|
||||||
}
|
|
||||||
|
|
||||||
timestamp = timestamp.Add(time.Minute * 10)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now that all of our set of events constructed, we'll add them to the
|
|
||||||
// database in a batch manner.
|
|
||||||
if err := log.AddForwardingEvents(events); err != nil {
|
|
||||||
t.Fatalf("unable to add events: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// With our events added we'll now construct a basic query to retrieve
|
|
||||||
// all of the events.
|
|
||||||
eventQuery := ForwardingEventQuery{
|
|
||||||
StartTime: initialTime,
|
|
||||||
EndTime: timestamp,
|
|
||||||
IndexOffset: 0,
|
|
||||||
NumMaxEvents: 1000,
|
|
||||||
}
|
|
||||||
timeSlice, err := log.Query(eventQuery)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to query for events: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// The set of returned events should match identically, as they should
|
|
||||||
// be returned in sorted order.
|
|
||||||
if !reflect.DeepEqual(events, timeSlice.ForwardingEvents) {
|
|
||||||
t.Fatalf("event mismatch: expected %v vs %v",
|
|
||||||
spew.Sdump(events), spew.Sdump(timeSlice.ForwardingEvents))
|
|
||||||
}
|
|
||||||
|
|
||||||
// The offset index of the final entry should be numEvents, so the
|
|
||||||
// number of total events we've written.
|
|
||||||
if timeSlice.LastIndexOffset != uint32(numEvents) {
|
|
||||||
t.Fatalf("wrong final offset: expected %v, got %v",
|
|
||||||
timeSlice.LastIndexOffset, numEvents)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestForwardingLogQueryOptions tests that the query offset works properly. So
|
|
||||||
// if we add a series of events, then we should be able to seek within the
|
|
||||||
// timeslice accordingly. This exercises the index offset and num max event
|
|
||||||
// field in the query, and also the last index offset field int he response.
|
|
||||||
func TestForwardingLogQueryOptions(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
// First, we'll set up a test database, and use that to instantiate the
|
|
||||||
// forwarding event log that we'll be using for the duration of the
|
|
||||||
// test.
|
|
||||||
db, cleanUp, err := makeTestDB()
|
|
||||||
defer cleanUp()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test db: %v", err)
|
|
||||||
}
|
|
||||||
log := ForwardingLog{
|
|
||||||
db: db,
|
|
||||||
}
|
|
||||||
|
|
||||||
initialTime := time.Unix(1234, 0)
|
|
||||||
endTime := time.Unix(1234, 0)
|
|
||||||
|
|
||||||
// We'll create 20 random events, which each event being spaced 10
|
|
||||||
// minutes after the prior event.
|
|
||||||
numEvents := 20
|
|
||||||
events := make([]ForwardingEvent, numEvents)
|
|
||||||
for i := 0; i < numEvents; i++ {
|
|
||||||
events[i] = ForwardingEvent{
|
|
||||||
Timestamp: endTime,
|
|
||||||
IncomingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())),
|
|
||||||
OutgoingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())),
|
|
||||||
AmtIn: lnwire.MilliSatoshi(rand.Int63()),
|
|
||||||
AmtOut: lnwire.MilliSatoshi(rand.Int63()),
|
|
||||||
}
|
|
||||||
|
|
||||||
endTime = endTime.Add(time.Minute * 10)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now that all of our set of events constructed, we'll add them to the
|
|
||||||
// database in a batch manner.
|
|
||||||
if err := log.AddForwardingEvents(events); err != nil {
|
|
||||||
t.Fatalf("unable to add events: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// With all of our events added, we should be able to query for the
|
|
||||||
// first 10 events using the max event query field.
|
|
||||||
eventQuery := ForwardingEventQuery{
|
|
||||||
StartTime: initialTime,
|
|
||||||
EndTime: endTime,
|
|
||||||
IndexOffset: 0,
|
|
||||||
NumMaxEvents: 10,
|
|
||||||
}
|
|
||||||
timeSlice, err := log.Query(eventQuery)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to query for events: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// We should get exactly 10 events back.
|
|
||||||
if len(timeSlice.ForwardingEvents) != 10 {
|
|
||||||
t.Fatalf("wrong number of events: expected %v, got %v", 10,
|
|
||||||
len(timeSlice.ForwardingEvents))
|
|
||||||
}
|
|
||||||
|
|
||||||
// The set of events returned should be the first 10 events that we
|
|
||||||
// added.
|
|
||||||
if !reflect.DeepEqual(events[:10], timeSlice.ForwardingEvents) {
|
|
||||||
t.Fatalf("wrong response: expected %v, got %v",
|
|
||||||
spew.Sdump(events[:10]),
|
|
||||||
spew.Sdump(timeSlice.ForwardingEvents))
|
|
||||||
}
|
|
||||||
|
|
||||||
// The final offset should be the exact number of events returned.
|
|
||||||
if timeSlice.LastIndexOffset != 10 {
|
|
||||||
t.Fatalf("wrong index offset: expected %v, got %v", 10,
|
|
||||||
timeSlice.LastIndexOffset)
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we use the final offset to query again, then we should get 10
|
|
||||||
// more events, that are the last 10 events we wrote.
|
|
||||||
eventQuery.IndexOffset = 10
|
|
||||||
timeSlice, err = log.Query(eventQuery)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to query for events: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// We should get exactly 10 events back once again.
|
|
||||||
if len(timeSlice.ForwardingEvents) != 10 {
|
|
||||||
t.Fatalf("wrong number of events: expected %v, got %v", 10,
|
|
||||||
len(timeSlice.ForwardingEvents))
|
|
||||||
}
|
|
||||||
|
|
||||||
// The events that we got back should be the last 10 events that we
|
|
||||||
// wrote out.
|
|
||||||
if !reflect.DeepEqual(events[10:], timeSlice.ForwardingEvents) {
|
|
||||||
t.Fatalf("wrong response: expected %v, got %v",
|
|
||||||
spew.Sdump(events[10:]),
|
|
||||||
spew.Sdump(timeSlice.ForwardingEvents))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Finally, the last index offset should be 20, or the number of
|
|
||||||
// records we've written out.
|
|
||||||
if timeSlice.LastIndexOffset != 20 {
|
|
||||||
t.Fatalf("wrong index offset: expected %v, got %v", 20,
|
|
||||||
timeSlice.LastIndexOffset)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestForwardingLogQueryLimit tests that we're able to properly limit the
|
|
||||||
// number of events that are returned as part of a query.
|
|
||||||
func TestForwardingLogQueryLimit(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
// First, we'll set up a test database, and use that to instantiate the
|
|
||||||
// forwarding event log that we'll be using for the duration of the
|
|
||||||
// test.
|
|
||||||
db, cleanUp, err := makeTestDB()
|
|
||||||
defer cleanUp()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test db: %v", err)
|
|
||||||
}
|
|
||||||
log := ForwardingLog{
|
|
||||||
db: db,
|
|
||||||
}
|
|
||||||
|
|
||||||
initialTime := time.Unix(1234, 0)
|
|
||||||
endTime := time.Unix(1234, 0)
|
|
||||||
|
|
||||||
// We'll create 200 random events, which each event being spaced 10
|
|
||||||
// minutes after the prior event.
|
|
||||||
numEvents := 200
|
|
||||||
events := make([]ForwardingEvent, numEvents)
|
|
||||||
for i := 0; i < numEvents; i++ {
|
|
||||||
events[i] = ForwardingEvent{
|
|
||||||
Timestamp: endTime,
|
|
||||||
IncomingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())),
|
|
||||||
OutgoingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())),
|
|
||||||
AmtIn: lnwire.MilliSatoshi(rand.Int63()),
|
|
||||||
AmtOut: lnwire.MilliSatoshi(rand.Int63()),
|
|
||||||
}
|
|
||||||
|
|
||||||
endTime = endTime.Add(time.Minute * 10)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now that all of our set of events constructed, we'll add them to the
|
|
||||||
// database in a batch manner.
|
|
||||||
if err := log.AddForwardingEvents(events); err != nil {
|
|
||||||
t.Fatalf("unable to add events: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Once the events have been written out, we'll issue a query over the
|
|
||||||
// entire range, but restrict the number of events to the first 100.
|
|
||||||
eventQuery := ForwardingEventQuery{
|
|
||||||
StartTime: initialTime,
|
|
||||||
EndTime: endTime,
|
|
||||||
IndexOffset: 0,
|
|
||||||
NumMaxEvents: 100,
|
|
||||||
}
|
|
||||||
timeSlice, err := log.Query(eventQuery)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to query for events: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// We should get exactly 100 events back.
|
|
||||||
if len(timeSlice.ForwardingEvents) != 100 {
|
|
||||||
t.Fatalf("wrong number of events: expected %v, got %v", 10,
|
|
||||||
len(timeSlice.ForwardingEvents))
|
|
||||||
}
|
|
||||||
|
|
||||||
// The set of events returned should be the first 100 events that we
|
|
||||||
// added.
|
|
||||||
if !reflect.DeepEqual(events[:100], timeSlice.ForwardingEvents) {
|
|
||||||
t.Fatalf("wrong response: expected %v, got %v",
|
|
||||||
spew.Sdump(events[:100]),
|
|
||||||
spew.Sdump(timeSlice.ForwardingEvents))
|
|
||||||
}
|
|
||||||
|
|
||||||
// The final offset should be the exact number of events returned.
|
|
||||||
if timeSlice.LastIndexOffset != 100 {
|
|
||||||
t.Fatalf("wrong index offset: expected %v, got %v", 100,
|
|
||||||
timeSlice.LastIndexOffset)
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,928 +0,0 @@
|
|||||||
package migration_01_to_11
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
|
|
||||||
"github.com/coreos/bbolt"
|
|
||||||
"github.com/lightningnetwork/lnd/lnwire"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ErrCorruptedFwdPkg signals that the on-disk structure of the forwarding
|
|
||||||
// package has potentially been mangled.
|
|
||||||
var ErrCorruptedFwdPkg = errors.New("fwding package db has been corrupted")
|
|
||||||
|
|
||||||
// FwdState is an enum used to describe the lifecycle of a FwdPkg.
|
|
||||||
type FwdState byte
|
|
||||||
|
|
||||||
const (
|
|
||||||
// FwdStateLockedIn is the starting state for all forwarding packages.
|
|
||||||
// Packages in this state have not yet committed to the exact set of
|
|
||||||
// Adds to forward to the switch.
|
|
||||||
FwdStateLockedIn FwdState = iota
|
|
||||||
|
|
||||||
// FwdStateProcessed marks the state in which all Adds have been
|
|
||||||
// locally processed and the forwarding decision to the switch has been
|
|
||||||
// persisted.
|
|
||||||
FwdStateProcessed
|
|
||||||
|
|
||||||
// FwdStateCompleted signals that all Adds have been acked, and that all
|
|
||||||
// settles and fails have been delivered to their sources. Packages in
|
|
||||||
// this state can be removed permanently.
|
|
||||||
FwdStateCompleted
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
// fwdPackagesKey is the root-level bucket that all forwarding packages
|
|
||||||
// are written. This bucket is further subdivided based on the short
|
|
||||||
// channel ID of each channel.
|
|
||||||
fwdPackagesKey = []byte("fwd-packages")
|
|
||||||
|
|
||||||
// addBucketKey is the bucket to which all Add log updates are written.
|
|
||||||
addBucketKey = []byte("add-updates")
|
|
||||||
|
|
||||||
// failSettleBucketKey is the bucket to which all Settle/Fail log
|
|
||||||
// updates are written.
|
|
||||||
failSettleBucketKey = []byte("fail-settle-updates")
|
|
||||||
|
|
||||||
// fwdFilterKey is a key used to write the set of Adds that passed
|
|
||||||
// validation and are to be forwarded to the switch.
|
|
||||||
// NOTE: The presence of this key within a forwarding package indicates
|
|
||||||
// that the package has reached FwdStateProcessed.
|
|
||||||
fwdFilterKey = []byte("fwd-filter-key")
|
|
||||||
|
|
||||||
// ackFilterKey is a key used to access the PkgFilter indicating which
|
|
||||||
// Adds have received a Settle/Fail. This response may come from a
|
|
||||||
// number of sources, including: exitHop settle/fails, switch failures,
|
|
||||||
// chain arbiter interjections, as well as settle/fails from the
|
|
||||||
// next hop in the route.
|
|
||||||
ackFilterKey = []byte("ack-filter-key")
|
|
||||||
|
|
||||||
// settleFailFilterKey is a key used to access the PkgFilter indicating
|
|
||||||
// which Settles/Fails in have been received and processed by the link
|
|
||||||
// that originally received the Add.
|
|
||||||
settleFailFilterKey = []byte("settle-fail-filter-key")
|
|
||||||
)
|
|
||||||
|
|
||||||
// PkgFilter is used to compactly represent a particular subset of the Adds in a
|
|
||||||
// forwarding package. Each filter is represented as a simple, statically-sized
|
|
||||||
// bitvector, where the elements are intended to be the indices of the Adds as
|
|
||||||
// they are written in the FwdPkg.
|
|
||||||
type PkgFilter struct {
|
|
||||||
count uint16
|
|
||||||
filter []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewPkgFilter initializes an empty PkgFilter supporting `count` elements.
|
|
||||||
func NewPkgFilter(count uint16) *PkgFilter {
|
|
||||||
// We add 7 to ensure that the integer division yields properly rounded
|
|
||||||
// values.
|
|
||||||
filterLen := (count + 7) / 8
|
|
||||||
|
|
||||||
return &PkgFilter{
|
|
||||||
count: count,
|
|
||||||
filter: make([]byte, filterLen),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Count returns the number of elements represented by this PkgFilter.
|
|
||||||
func (f *PkgFilter) Count() uint16 {
|
|
||||||
return f.count
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set marks the `i`-th element as included by this filter.
|
|
||||||
// NOTE: It is assumed that i is always less than count.
|
|
||||||
func (f *PkgFilter) Set(i uint16) {
|
|
||||||
byt := i / 8
|
|
||||||
bit := i % 8
|
|
||||||
|
|
||||||
// Set the i-th bit in the filter.
|
|
||||||
// TODO(conner): ignore if > count to prevent panic?
|
|
||||||
f.filter[byt] |= byte(1 << (7 - bit))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Contains queries the filter for membership of index `i`.
|
|
||||||
// NOTE: It is assumed that i is always less than count.
|
|
||||||
func (f *PkgFilter) Contains(i uint16) bool {
|
|
||||||
byt := i / 8
|
|
||||||
bit := i % 8
|
|
||||||
|
|
||||||
// Read the i-th bit in the filter.
|
|
||||||
// TODO(conner): ignore if > count to prevent panic?
|
|
||||||
return f.filter[byt]&(1<<(7-bit)) != 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// Equal checks two PkgFilters for equality.
|
|
||||||
func (f *PkgFilter) Equal(f2 *PkgFilter) bool {
|
|
||||||
if f == f2 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if f.count != f2.count {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return bytes.Equal(f.filter, f2.filter)
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsFull returns true if every element in the filter has been Set, and false
|
|
||||||
// otherwise.
|
|
||||||
func (f *PkgFilter) IsFull() bool {
|
|
||||||
// Batch validate bytes that are fully used.
|
|
||||||
for i := uint16(0); i < f.count/8; i++ {
|
|
||||||
if f.filter[i] != 0xFF {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the count is not a multiple of 8, check that the filter contains
|
|
||||||
// all remaining bits.
|
|
||||||
rem := f.count % 8
|
|
||||||
for idx := f.count - rem; idx < f.count; idx++ {
|
|
||||||
if !f.Contains(idx) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Size returns number of bytes produced when the PkgFilter is serialized.
|
|
||||||
func (f *PkgFilter) Size() uint16 {
|
|
||||||
// 2 bytes for uint16 `count`, then round up number of bytes required to
|
|
||||||
// represent `count` bits.
|
|
||||||
return 2 + (f.count+7)/8
|
|
||||||
}
|
|
||||||
|
|
||||||
// Encode writes the filter to the provided io.Writer.
|
|
||||||
func (f *PkgFilter) Encode(w io.Writer) error {
|
|
||||||
if err := binary.Write(w, binary.BigEndian, f.count); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := w.Write(f.filter)
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Decode reads the filter from the provided io.Reader.
|
|
||||||
func (f *PkgFilter) Decode(r io.Reader) error {
|
|
||||||
if err := binary.Read(r, binary.BigEndian, &f.count); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
f.filter = make([]byte, f.Size()-2)
|
|
||||||
_, err := io.ReadFull(r, f.filter)
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// FwdPkg records all adds, settles, and fails that were locked in as a result
|
|
||||||
// of the remote peer sending us a revocation. Each package is identified by
|
|
||||||
// the short chanid and remote commitment height corresponding to the revocation
|
|
||||||
// that locked in the HTLCs. For everything except a locally initiated payment,
|
|
||||||
// settles and fails in a forwarding package must have a corresponding Add in
|
|
||||||
// another package, and can be removed individually once the source link has
|
|
||||||
// received the fail/settle.
|
|
||||||
//
|
|
||||||
// Adds cannot be removed, as we need to present the same batch of Adds to
|
|
||||||
// properly handle replay protection. Instead, we use a PkgFilter to mark that
|
|
||||||
// we have finished processing a particular Add. A FwdPkg should only be deleted
|
|
||||||
// after the AckFilter is full and all settles and fails have been persistently
|
|
||||||
// removed.
|
|
||||||
type FwdPkg struct {
|
|
||||||
// Source identifies the channel that wrote this forwarding package.
|
|
||||||
Source lnwire.ShortChannelID
|
|
||||||
|
|
||||||
// Height is the height of the remote commitment chain that locked in
|
|
||||||
// this forwarding package.
|
|
||||||
Height uint64
|
|
||||||
|
|
||||||
// State signals the persistent condition of the package and directs how
|
|
||||||
// to reprocess the package in the event of failures.
|
|
||||||
State FwdState
|
|
||||||
|
|
||||||
// Adds contains all add messages which need to be processed and
|
|
||||||
// forwarded to the switch. Adds does not change over the life of a
|
|
||||||
// forwarding package.
|
|
||||||
Adds []LogUpdate
|
|
||||||
|
|
||||||
// FwdFilter is a filter containing the indices of all Adds that were
|
|
||||||
// forwarded to the switch.
|
|
||||||
FwdFilter *PkgFilter
|
|
||||||
|
|
||||||
// AckFilter is a filter containing the indices of all Adds for which
|
|
||||||
// the source has received a settle or fail and is reflected in the next
|
|
||||||
// commitment txn. A package should not be removed until IsFull()
|
|
||||||
// returns true.
|
|
||||||
AckFilter *PkgFilter
|
|
||||||
|
|
||||||
// SettleFails contains all settle and fail messages that should be
|
|
||||||
// forwarded to the switch.
|
|
||||||
SettleFails []LogUpdate
|
|
||||||
|
|
||||||
// SettleFailFilter is a filter containing the indices of all Settle or
|
|
||||||
// Fails originating in this package that have been received and locked
|
|
||||||
// into the incoming link's commitment state.
|
|
||||||
SettleFailFilter *PkgFilter
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewFwdPkg initializes a new forwarding package in FwdStateLockedIn. This
|
|
||||||
// should be used to create a package at the time we receive a revocation.
|
|
||||||
func NewFwdPkg(source lnwire.ShortChannelID, height uint64,
|
|
||||||
addUpdates, settleFailUpdates []LogUpdate) *FwdPkg {
|
|
||||||
|
|
||||||
nAddUpdates := uint16(len(addUpdates))
|
|
||||||
nSettleFailUpdates := uint16(len(settleFailUpdates))
|
|
||||||
|
|
||||||
return &FwdPkg{
|
|
||||||
Source: source,
|
|
||||||
Height: height,
|
|
||||||
State: FwdStateLockedIn,
|
|
||||||
Adds: addUpdates,
|
|
||||||
FwdFilter: NewPkgFilter(nAddUpdates),
|
|
||||||
AckFilter: NewPkgFilter(nAddUpdates),
|
|
||||||
SettleFails: settleFailUpdates,
|
|
||||||
SettleFailFilter: NewPkgFilter(nSettleFailUpdates),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ID returns an unique identifier for this package, used to ensure that sphinx
|
|
||||||
// replay processing of this batch is idempotent.
|
|
||||||
func (f *FwdPkg) ID() []byte {
|
|
||||||
var id = make([]byte, 16)
|
|
||||||
byteOrder.PutUint64(id[:8], f.Source.ToUint64())
|
|
||||||
byteOrder.PutUint64(id[8:], f.Height)
|
|
||||||
return id
|
|
||||||
}
|
|
||||||
|
|
||||||
// String returns a human-readable description of the forwarding package.
|
|
||||||
func (f *FwdPkg) String() string {
|
|
||||||
return fmt.Sprintf("%T(src=%v, height=%v, nadds=%v, nfailsettles=%v)",
|
|
||||||
f, f.Source, f.Height, len(f.Adds), len(f.SettleFails))
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddRef is used to identify a particular Add in a FwdPkg. The short channel ID
|
|
||||||
// is assumed to be that of the packager.
|
|
||||||
type AddRef struct {
|
|
||||||
// Height is the remote commitment height that locked in the Add.
|
|
||||||
Height uint64
|
|
||||||
|
|
||||||
// Index is the index of the Add within the fwd pkg's Adds.
|
|
||||||
//
|
|
||||||
// NOTE: This index is static over the lifetime of a forwarding package.
|
|
||||||
Index uint16
|
|
||||||
}
|
|
||||||
|
|
||||||
// Encode serializes the AddRef to the given io.Writer.
|
|
||||||
func (a *AddRef) Encode(w io.Writer) error {
|
|
||||||
if err := binary.Write(w, binary.BigEndian, a.Height); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return binary.Write(w, binary.BigEndian, a.Index)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Decode deserializes the AddRef from the given io.Reader.
|
|
||||||
func (a *AddRef) Decode(r io.Reader) error {
|
|
||||||
if err := binary.Read(r, binary.BigEndian, &a.Height); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return binary.Read(r, binary.BigEndian, &a.Index)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SettleFailRef is used to locate a Settle/Fail in another channel's FwdPkg. A
|
|
||||||
// channel does not remove its own Settle/Fail htlcs, so the source is provided
|
|
||||||
// to locate a db bucket belonging to another channel.
|
|
||||||
type SettleFailRef struct {
|
|
||||||
// Source identifies the outgoing link that locked in the settle or
|
|
||||||
// fail. This is then used by the *incoming* link to find the settle
|
|
||||||
// fail in another link's forwarding packages.
|
|
||||||
Source lnwire.ShortChannelID
|
|
||||||
|
|
||||||
// Height is the remote commitment height that locked in this
|
|
||||||
// Settle/Fail.
|
|
||||||
Height uint64
|
|
||||||
|
|
||||||
// Index is the index of the Add with the fwd pkg's SettleFails.
|
|
||||||
//
|
|
||||||
// NOTE: This index is static over the lifetime of a forwarding package.
|
|
||||||
Index uint16
|
|
||||||
}
|
|
||||||
|
|
||||||
// SettleFailAcker is a generic interface providing the ability to acknowledge
|
|
||||||
// settle/fail HTLCs stored in forwarding packages.
|
|
||||||
type SettleFailAcker interface {
|
|
||||||
// AckSettleFails atomically updates the settle-fail filters in *other*
|
|
||||||
// channels' forwarding packages.
|
|
||||||
AckSettleFails(tx *bbolt.Tx, settleFailRefs ...SettleFailRef) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// GlobalFwdPkgReader is an interface used to retrieve the forwarding packages
|
|
||||||
// of any active channel.
|
|
||||||
type GlobalFwdPkgReader interface {
|
|
||||||
// LoadChannelFwdPkgs loads all known forwarding packages for the given
|
|
||||||
// channel.
|
|
||||||
LoadChannelFwdPkgs(tx *bbolt.Tx,
|
|
||||||
source lnwire.ShortChannelID) ([]*FwdPkg, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// FwdOperator defines the interfaces for managing forwarding packages that are
|
|
||||||
// external to a particular channel. This interface is used by the switch to
|
|
||||||
// read forwarding packages from arbitrary channels, and acknowledge settles and
|
|
||||||
// fails for locally-sourced payments.
|
|
||||||
type FwdOperator interface {
|
|
||||||
// GlobalFwdPkgReader provides read access to all known forwarding
|
|
||||||
// packages
|
|
||||||
GlobalFwdPkgReader
|
|
||||||
|
|
||||||
// SettleFailAcker grants the ability to acknowledge settles or fails
|
|
||||||
// residing in arbitrary forwarding packages.
|
|
||||||
SettleFailAcker
|
|
||||||
}
|
|
||||||
|
|
||||||
// SwitchPackager is a concrete implementation of the FwdOperator interface.
|
|
||||||
// A SwitchPackager offers the ability to read any forwarding package, and ack
|
|
||||||
// arbitrary settle and fail HTLCs.
|
|
||||||
type SwitchPackager struct{}
|
|
||||||
|
|
||||||
// NewSwitchPackager instantiates a new SwitchPackager.
|
|
||||||
func NewSwitchPackager() *SwitchPackager {
|
|
||||||
return &SwitchPackager{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// AckSettleFails atomically updates the settle-fail filters in *other*
|
|
||||||
// channels' forwarding packages, to mark that the switch has received a settle
|
|
||||||
// or fail residing in the forwarding package of a link.
|
|
||||||
func (*SwitchPackager) AckSettleFails(tx *bbolt.Tx,
|
|
||||||
settleFailRefs ...SettleFailRef) error {
|
|
||||||
|
|
||||||
return ackSettleFails(tx, settleFailRefs)
|
|
||||||
}
|
|
||||||
|
|
||||||
// LoadChannelFwdPkgs loads all forwarding packages for a particular channel.
|
|
||||||
func (*SwitchPackager) LoadChannelFwdPkgs(tx *bbolt.Tx,
|
|
||||||
source lnwire.ShortChannelID) ([]*FwdPkg, error) {
|
|
||||||
|
|
||||||
return loadChannelFwdPkgs(tx, source)
|
|
||||||
}
|
|
||||||
|
|
||||||
// FwdPackager supports all operations required to modify fwd packages, such as
|
|
||||||
// creation, updates, reading, and removal. The interfaces are broken down in
|
|
||||||
// this way to support future delegation of the subinterfaces.
|
|
||||||
type FwdPackager interface {
|
|
||||||
// AddFwdPkg serializes and writes a FwdPkg for this channel at the
|
|
||||||
// remote commitment height included in the forwarding package.
|
|
||||||
AddFwdPkg(tx *bbolt.Tx, fwdPkg *FwdPkg) error
|
|
||||||
|
|
||||||
// SetFwdFilter looks up the forwarding package at the remote `height`
|
|
||||||
// and sets the `fwdFilter`, marking the Adds for which:
|
|
||||||
// 1) We are not the exit node
|
|
||||||
// 2) Passed all validation
|
|
||||||
// 3) Should be forwarded to the switch immediately after a failure
|
|
||||||
SetFwdFilter(tx *bbolt.Tx, height uint64, fwdFilter *PkgFilter) error
|
|
||||||
|
|
||||||
// AckAddHtlcs atomically updates the add filters in this channel's
|
|
||||||
// forwarding packages to mark the resolution of an Add that was
|
|
||||||
// received from the remote party.
|
|
||||||
AckAddHtlcs(tx *bbolt.Tx, addRefs ...AddRef) error
|
|
||||||
|
|
||||||
// SettleFailAcker allows a link to acknowledge settle/fail HTLCs
|
|
||||||
// belonging to other channels.
|
|
||||||
SettleFailAcker
|
|
||||||
|
|
||||||
// LoadFwdPkgs loads all known forwarding packages owned by this
|
|
||||||
// channel.
|
|
||||||
LoadFwdPkgs(tx *bbolt.Tx) ([]*FwdPkg, error)
|
|
||||||
|
|
||||||
// RemovePkg deletes a forwarding package owned by this channel at
|
|
||||||
// the provided remote `height`.
|
|
||||||
RemovePkg(tx *bbolt.Tx, height uint64) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChannelPackager is used by a channel to manage the lifecycle of its forwarding
|
|
||||||
// packages. The packager is tied to a particular source channel ID, allowing it
|
|
||||||
// to create and edit its own packages. Each packager also has the ability to
|
|
||||||
// remove fail/settle htlcs that correspond to an add contained in one of
|
|
||||||
// source's packages.
|
|
||||||
type ChannelPackager struct {
|
|
||||||
source lnwire.ShortChannelID
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewChannelPackager creates a new packager for a single channel.
|
|
||||||
func NewChannelPackager(source lnwire.ShortChannelID) *ChannelPackager {
|
|
||||||
return &ChannelPackager{
|
|
||||||
source: source,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddFwdPkg writes a newly locked in forwarding package to disk.
|
|
||||||
func (*ChannelPackager) AddFwdPkg(tx *bbolt.Tx, fwdPkg *FwdPkg) error {
|
|
||||||
fwdPkgBkt, err := tx.CreateBucketIfNotExists(fwdPackagesKey)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
source := makeLogKey(fwdPkg.Source.ToUint64())
|
|
||||||
sourceBkt, err := fwdPkgBkt.CreateBucketIfNotExists(source[:])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
heightKey := makeLogKey(fwdPkg.Height)
|
|
||||||
heightBkt, err := sourceBkt.CreateBucketIfNotExists(heightKey[:])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write ADD updates we received at this commit height.
|
|
||||||
addBkt, err := heightBkt.CreateBucketIfNotExists(addBucketKey)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write SETTLE/FAIL updates we received at this commit height.
|
|
||||||
failSettleBkt, err := heightBkt.CreateBucketIfNotExists(failSettleBucketKey)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range fwdPkg.Adds {
|
|
||||||
err = putLogUpdate(addBkt, uint16(i), &fwdPkg.Adds[i])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Persist the initialized pkg filter, which will be used to determine
|
|
||||||
// when we can remove this forwarding package from disk.
|
|
||||||
var ackFilterBuf bytes.Buffer
|
|
||||||
if err := fwdPkg.AckFilter.Encode(&ackFilterBuf); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := heightBkt.Put(ackFilterKey, ackFilterBuf.Bytes()); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range fwdPkg.SettleFails {
|
|
||||||
err = putLogUpdate(failSettleBkt, uint16(i), &fwdPkg.SettleFails[i])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var settleFailFilterBuf bytes.Buffer
|
|
||||||
err = fwdPkg.SettleFailFilter.Encode(&settleFailFilterBuf)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return heightBkt.Put(settleFailFilterKey, settleFailFilterBuf.Bytes())
|
|
||||||
}
|
|
||||||
|
|
||||||
// putLogUpdate writes an htlc to the provided `bkt`, using `index` as the key.
|
|
||||||
func putLogUpdate(bkt *bbolt.Bucket, idx uint16, htlc *LogUpdate) error {
|
|
||||||
var b bytes.Buffer
|
|
||||||
if err := htlc.Encode(&b); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return bkt.Put(uint16Key(idx), b.Bytes())
|
|
||||||
}
|
|
||||||
|
|
||||||
// LoadFwdPkgs scans the forwarding log for any packages that haven't been
|
|
||||||
// processed, and returns their deserialized log updates in a map indexed by the
|
|
||||||
// remote commitment height at which the updates were locked in.
|
|
||||||
func (p *ChannelPackager) LoadFwdPkgs(tx *bbolt.Tx) ([]*FwdPkg, error) {
|
|
||||||
return loadChannelFwdPkgs(tx, p.source)
|
|
||||||
}
|
|
||||||
|
|
||||||
// loadChannelFwdPkgs loads all forwarding packages owned by `source`.
|
|
||||||
func loadChannelFwdPkgs(tx *bbolt.Tx, source lnwire.ShortChannelID) ([]*FwdPkg, error) {
|
|
||||||
fwdPkgBkt := tx.Bucket(fwdPackagesKey)
|
|
||||||
if fwdPkgBkt == nil {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
sourceKey := makeLogKey(source.ToUint64())
|
|
||||||
sourceBkt := fwdPkgBkt.Bucket(sourceKey[:])
|
|
||||||
if sourceBkt == nil {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var heights []uint64
|
|
||||||
if err := sourceBkt.ForEach(func(k, _ []byte) error {
|
|
||||||
if len(k) != 8 {
|
|
||||||
return ErrCorruptedFwdPkg
|
|
||||||
}
|
|
||||||
|
|
||||||
heights = append(heights, byteOrder.Uint64(k))
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load the forwarding package for each retrieved height.
|
|
||||||
fwdPkgs := make([]*FwdPkg, 0, len(heights))
|
|
||||||
for _, height := range heights {
|
|
||||||
fwdPkg, err := loadFwdPkg(fwdPkgBkt, source, height)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
fwdPkgs = append(fwdPkgs, fwdPkg)
|
|
||||||
}
|
|
||||||
|
|
||||||
return fwdPkgs, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// loadFwPkg reads the packager's fwd pkg at a given height, and determines the
|
|
||||||
// appropriate FwdState.
|
|
||||||
func loadFwdPkg(fwdPkgBkt *bbolt.Bucket, source lnwire.ShortChannelID,
|
|
||||||
height uint64) (*FwdPkg, error) {
|
|
||||||
|
|
||||||
sourceKey := makeLogKey(source.ToUint64())
|
|
||||||
sourceBkt := fwdPkgBkt.Bucket(sourceKey[:])
|
|
||||||
if sourceBkt == nil {
|
|
||||||
return nil, ErrCorruptedFwdPkg
|
|
||||||
}
|
|
||||||
|
|
||||||
heightKey := makeLogKey(height)
|
|
||||||
heightBkt := sourceBkt.Bucket(heightKey[:])
|
|
||||||
if heightBkt == nil {
|
|
||||||
return nil, ErrCorruptedFwdPkg
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load ADDs from disk.
|
|
||||||
addBkt := heightBkt.Bucket(addBucketKey)
|
|
||||||
if addBkt == nil {
|
|
||||||
return nil, ErrCorruptedFwdPkg
|
|
||||||
}
|
|
||||||
|
|
||||||
adds, err := loadHtlcs(addBkt)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load ack filter from disk.
|
|
||||||
ackFilterBytes := heightBkt.Get(ackFilterKey)
|
|
||||||
if ackFilterBytes == nil {
|
|
||||||
return nil, ErrCorruptedFwdPkg
|
|
||||||
}
|
|
||||||
ackFilterReader := bytes.NewReader(ackFilterBytes)
|
|
||||||
|
|
||||||
ackFilter := &PkgFilter{}
|
|
||||||
if err := ackFilter.Decode(ackFilterReader); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load SETTLE/FAILs from disk.
|
|
||||||
failSettleBkt := heightBkt.Bucket(failSettleBucketKey)
|
|
||||||
if failSettleBkt == nil {
|
|
||||||
return nil, ErrCorruptedFwdPkg
|
|
||||||
}
|
|
||||||
|
|
||||||
failSettles, err := loadHtlcs(failSettleBkt)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load settle fail filter from disk.
|
|
||||||
settleFailFilterBytes := heightBkt.Get(settleFailFilterKey)
|
|
||||||
if settleFailFilterBytes == nil {
|
|
||||||
return nil, ErrCorruptedFwdPkg
|
|
||||||
}
|
|
||||||
settleFailFilterReader := bytes.NewReader(settleFailFilterBytes)
|
|
||||||
|
|
||||||
settleFailFilter := &PkgFilter{}
|
|
||||||
if err := settleFailFilter.Decode(settleFailFilterReader); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize the fwding package, which always starts in the
|
|
||||||
// FwdStateLockedIn. We can determine what state the package was left in
|
|
||||||
// by examining constraints on the information loaded from disk.
|
|
||||||
fwdPkg := &FwdPkg{
|
|
||||||
Source: source,
|
|
||||||
State: FwdStateLockedIn,
|
|
||||||
Height: height,
|
|
||||||
Adds: adds,
|
|
||||||
AckFilter: ackFilter,
|
|
||||||
SettleFails: failSettles,
|
|
||||||
SettleFailFilter: settleFailFilter,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check to see if we have written the set exported filter adds to
|
|
||||||
// disk. If we haven't, processing of this package was never started, or
|
|
||||||
// failed during the last attempt.
|
|
||||||
fwdFilterBytes := heightBkt.Get(fwdFilterKey)
|
|
||||||
if fwdFilterBytes == nil {
|
|
||||||
nAdds := uint16(len(adds))
|
|
||||||
fwdPkg.FwdFilter = NewPkgFilter(nAdds)
|
|
||||||
return fwdPkg, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
fwdFilterReader := bytes.NewReader(fwdFilterBytes)
|
|
||||||
fwdPkg.FwdFilter = &PkgFilter{}
|
|
||||||
if err := fwdPkg.FwdFilter.Decode(fwdFilterReader); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Otherwise, a complete round of processing was completed, and we
|
|
||||||
// advance the package to FwdStateProcessed.
|
|
||||||
fwdPkg.State = FwdStateProcessed
|
|
||||||
|
|
||||||
// If every add, settle, and fail has been fully acknowledged, we can
|
|
||||||
// safely set the package's state to FwdStateCompleted, signalling that
|
|
||||||
// it can be garbage collected.
|
|
||||||
if fwdPkg.AckFilter.IsFull() && fwdPkg.SettleFailFilter.IsFull() {
|
|
||||||
fwdPkg.State = FwdStateCompleted
|
|
||||||
}
|
|
||||||
|
|
||||||
return fwdPkg, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// loadHtlcs retrieves all serialized htlcs in a bucket, returning
|
|
||||||
// them in order of the indexes they were written under.
|
|
||||||
func loadHtlcs(bkt *bbolt.Bucket) ([]LogUpdate, error) {
|
|
||||||
var htlcs []LogUpdate
|
|
||||||
if err := bkt.ForEach(func(_, v []byte) error {
|
|
||||||
var htlc LogUpdate
|
|
||||||
if err := htlc.Decode(bytes.NewReader(v)); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
htlcs = append(htlcs, htlc)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return htlcs, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetFwdFilter writes the set of indexes corresponding to Adds at the
|
|
||||||
// `height` that are to be forwarded to the switch. Calling this method causes
|
|
||||||
// the forwarding package at `height` to be in FwdStateProcessed. We write this
|
|
||||||
// forwarding decision so that we always arrive at the same behavior for HTLCs
|
|
||||||
// leaving this channel. After a restart, we skip validation of these Adds,
|
|
||||||
// since they are assumed to have already been validated, and make the switch or
|
|
||||||
// outgoing link responsible for handling replays.
|
|
||||||
func (p *ChannelPackager) SetFwdFilter(tx *bbolt.Tx, height uint64,
|
|
||||||
fwdFilter *PkgFilter) error {
|
|
||||||
|
|
||||||
fwdPkgBkt := tx.Bucket(fwdPackagesKey)
|
|
||||||
if fwdPkgBkt == nil {
|
|
||||||
return ErrCorruptedFwdPkg
|
|
||||||
}
|
|
||||||
|
|
||||||
source := makeLogKey(p.source.ToUint64())
|
|
||||||
sourceBkt := fwdPkgBkt.Bucket(source[:])
|
|
||||||
if sourceBkt == nil {
|
|
||||||
return ErrCorruptedFwdPkg
|
|
||||||
}
|
|
||||||
|
|
||||||
heightKey := makeLogKey(height)
|
|
||||||
heightBkt := sourceBkt.Bucket(heightKey[:])
|
|
||||||
if heightBkt == nil {
|
|
||||||
return ErrCorruptedFwdPkg
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the fwd filter has already been written, we return early to avoid
|
|
||||||
// modifying the persistent state.
|
|
||||||
forwardedAddsBytes := heightBkt.Get(fwdFilterKey)
|
|
||||||
if forwardedAddsBytes != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Otherwise we serialize and write the provided fwd filter.
|
|
||||||
var b bytes.Buffer
|
|
||||||
if err := fwdFilter.Encode(&b); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return heightBkt.Put(fwdFilterKey, b.Bytes())
|
|
||||||
}
|
|
||||||
|
|
||||||
// AckAddHtlcs accepts a list of references to add htlcs, and updates the
|
|
||||||
// AckAddFilter of those forwarding packages to indicate that a settle or fail
|
|
||||||
// has been received in response to the add.
|
|
||||||
func (p *ChannelPackager) AckAddHtlcs(tx *bbolt.Tx, addRefs ...AddRef) error {
|
|
||||||
if len(addRefs) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
fwdPkgBkt := tx.Bucket(fwdPackagesKey)
|
|
||||||
if fwdPkgBkt == nil {
|
|
||||||
return ErrCorruptedFwdPkg
|
|
||||||
}
|
|
||||||
|
|
||||||
sourceKey := makeLogKey(p.source.ToUint64())
|
|
||||||
sourceBkt := fwdPkgBkt.Bucket(sourceKey[:])
|
|
||||||
if sourceBkt == nil {
|
|
||||||
return ErrCorruptedFwdPkg
|
|
||||||
}
|
|
||||||
|
|
||||||
// Organize the forward references such that we just get a single slice
|
|
||||||
// of indexes for each unique height.
|
|
||||||
heightDiffs := make(map[uint64][]uint16)
|
|
||||||
for _, addRef := range addRefs {
|
|
||||||
heightDiffs[addRef.Height] = append(
|
|
||||||
heightDiffs[addRef.Height],
|
|
||||||
addRef.Index,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load each height bucket once and remove all acked htlcs at that
|
|
||||||
// height.
|
|
||||||
for height, indexes := range heightDiffs {
|
|
||||||
err := ackAddHtlcsAtHeight(sourceBkt, height, indexes)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ackAddHtlcsAtHeight updates the AddAckFilter of a single forwarding package
|
|
||||||
// with a list of indexes, writing the resulting filter back in its place.
|
|
||||||
func ackAddHtlcsAtHeight(sourceBkt *bbolt.Bucket, height uint64,
|
|
||||||
indexes []uint16) error {
|
|
||||||
|
|
||||||
heightKey := makeLogKey(height)
|
|
||||||
heightBkt := sourceBkt.Bucket(heightKey[:])
|
|
||||||
if heightBkt == nil {
|
|
||||||
// If the height bucket isn't found, this could be because the
|
|
||||||
// forwarding package was already removed. We'll return nil to
|
|
||||||
// signal that the operation is successful, as there is nothing
|
|
||||||
// to ack.
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load ack filter from disk.
|
|
||||||
ackFilterBytes := heightBkt.Get(ackFilterKey)
|
|
||||||
if ackFilterBytes == nil {
|
|
||||||
return ErrCorruptedFwdPkg
|
|
||||||
}
|
|
||||||
|
|
||||||
ackFilter := &PkgFilter{}
|
|
||||||
ackFilterReader := bytes.NewReader(ackFilterBytes)
|
|
||||||
if err := ackFilter.Decode(ackFilterReader); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update the ack filter for this height.
|
|
||||||
for _, index := range indexes {
|
|
||||||
ackFilter.Set(index)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write the resulting filter to disk.
|
|
||||||
var ackFilterBuf bytes.Buffer
|
|
||||||
if err := ackFilter.Encode(&ackFilterBuf); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return heightBkt.Put(ackFilterKey, ackFilterBuf.Bytes())
|
|
||||||
}
|
|
||||||
|
|
||||||
// AckSettleFails persistently acknowledges settles or fails from a remote forwarding
|
|
||||||
// package. This should only be called after the source of the Add has locked in
|
|
||||||
// the settle/fail, or it becomes otherwise safe to forgo retransmitting the
|
|
||||||
// settle/fail after a restart.
|
|
||||||
func (p *ChannelPackager) AckSettleFails(tx *bbolt.Tx, settleFailRefs ...SettleFailRef) error {
|
|
||||||
return ackSettleFails(tx, settleFailRefs)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ackSettleFails persistently acknowledges a batch of settle fail references.
|
|
||||||
func ackSettleFails(tx *bbolt.Tx, settleFailRefs []SettleFailRef) error {
|
|
||||||
if len(settleFailRefs) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
fwdPkgBkt := tx.Bucket(fwdPackagesKey)
|
|
||||||
if fwdPkgBkt == nil {
|
|
||||||
return ErrCorruptedFwdPkg
|
|
||||||
}
|
|
||||||
|
|
||||||
// Organize the forward references such that we just get a single slice
|
|
||||||
// of indexes for each unique destination-height pair.
|
|
||||||
destHeightDiffs := make(map[lnwire.ShortChannelID]map[uint64][]uint16)
|
|
||||||
for _, settleFailRef := range settleFailRefs {
|
|
||||||
destHeights, ok := destHeightDiffs[settleFailRef.Source]
|
|
||||||
if !ok {
|
|
||||||
destHeights = make(map[uint64][]uint16)
|
|
||||||
destHeightDiffs[settleFailRef.Source] = destHeights
|
|
||||||
}
|
|
||||||
|
|
||||||
destHeights[settleFailRef.Height] = append(
|
|
||||||
destHeights[settleFailRef.Height],
|
|
||||||
settleFailRef.Index,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// With the references organized by destination and height, we now load
|
|
||||||
// each remote bucket, and update the settle fail filter for any
|
|
||||||
// settle/fail htlcs.
|
|
||||||
for dest, destHeights := range destHeightDiffs {
|
|
||||||
destKey := makeLogKey(dest.ToUint64())
|
|
||||||
destBkt := fwdPkgBkt.Bucket(destKey[:])
|
|
||||||
if destBkt == nil {
|
|
||||||
// If the destination bucket is not found, this is
|
|
||||||
// likely the result of the destination channel being
|
|
||||||
// closed and having it's forwarding packages wiped. We
|
|
||||||
// won't treat this as an error, because the response
|
|
||||||
// will no longer be retransmitted internally.
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
for height, indexes := range destHeights {
|
|
||||||
err := ackSettleFailsAtHeight(destBkt, height, indexes)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ackSettleFailsAtHeight given a destination bucket, acks the provided indexes
|
|
||||||
// at particular a height by updating the settle fail filter.
|
|
||||||
func ackSettleFailsAtHeight(destBkt *bbolt.Bucket, height uint64,
|
|
||||||
indexes []uint16) error {
|
|
||||||
|
|
||||||
heightKey := makeLogKey(height)
|
|
||||||
heightBkt := destBkt.Bucket(heightKey[:])
|
|
||||||
if heightBkt == nil {
|
|
||||||
// If the height bucket isn't found, this could be because the
|
|
||||||
// forwarding package was already removed. We'll return nil to
|
|
||||||
// signal that the operation is as there is nothing to ack.
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load ack filter from disk.
|
|
||||||
settleFailFilterBytes := heightBkt.Get(settleFailFilterKey)
|
|
||||||
if settleFailFilterBytes == nil {
|
|
||||||
return ErrCorruptedFwdPkg
|
|
||||||
}
|
|
||||||
|
|
||||||
settleFailFilter := &PkgFilter{}
|
|
||||||
settleFailFilterReader := bytes.NewReader(settleFailFilterBytes)
|
|
||||||
if err := settleFailFilter.Decode(settleFailFilterReader); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update the ack filter for this height.
|
|
||||||
for _, index := range indexes {
|
|
||||||
settleFailFilter.Set(index)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write the resulting filter to disk.
|
|
||||||
var settleFailFilterBuf bytes.Buffer
|
|
||||||
if err := settleFailFilter.Encode(&settleFailFilterBuf); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return heightBkt.Put(settleFailFilterKey, settleFailFilterBuf.Bytes())
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemovePkg deletes the forwarding package at the given height from the
|
|
||||||
// packager's source bucket.
|
|
||||||
func (p *ChannelPackager) RemovePkg(tx *bbolt.Tx, height uint64) error {
|
|
||||||
fwdPkgBkt := tx.Bucket(fwdPackagesKey)
|
|
||||||
if fwdPkgBkt == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
sourceBytes := makeLogKey(p.source.ToUint64())
|
|
||||||
sourceBkt := fwdPkgBkt.Bucket(sourceBytes[:])
|
|
||||||
if sourceBkt == nil {
|
|
||||||
return ErrCorruptedFwdPkg
|
|
||||||
}
|
|
||||||
|
|
||||||
heightKey := makeLogKey(height)
|
|
||||||
|
|
||||||
return sourceBkt.DeleteBucket(heightKey[:])
|
|
||||||
}
|
|
||||||
|
|
||||||
// uint16Key writes the provided 16-bit unsigned integer to a 2-byte slice.
|
|
||||||
func uint16Key(i uint16) []byte {
|
|
||||||
key := make([]byte, 2)
|
|
||||||
byteOrder.PutUint16(key, i)
|
|
||||||
return key
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compile-time constraint to ensure that ChannelPackager implements the public
|
|
||||||
// FwdPackager interface.
|
|
||||||
var _ FwdPackager = (*ChannelPackager)(nil)
|
|
||||||
|
|
||||||
// Compile-time constraint to ensure that SwitchPackager implements the public
|
|
||||||
// FwdOperator interface.
|
|
||||||
var _ FwdOperator = (*SwitchPackager)(nil)
|
|
@ -1,815 +0,0 @@
|
|||||||
package migration_01_to_11_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"io/ioutil"
|
|
||||||
"path/filepath"
|
|
||||||
"runtime"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/btcsuite/btcd/wire"
|
|
||||||
"github.com/coreos/bbolt"
|
|
||||||
"github.com/lightningnetwork/lnd/channeldb"
|
|
||||||
"github.com/lightningnetwork/lnd/lnwire"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TestPkgFilterBruteForce tests the behavior of a pkg filter up to size 1000,
|
|
||||||
// which is greater than the number of HTLCs we permit on a commitment txn.
|
|
||||||
// This should encapsulate every potential filter used in practice.
|
|
||||||
func TestPkgFilterBruteForce(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
checkPkgFilterRange(t, 1000)
|
|
||||||
}
|
|
||||||
|
|
||||||
// checkPkgFilterRange verifies the behavior of a pkg filter when doing a linear
|
|
||||||
// insertion of `high` elements. This is primarily to test that IsFull functions
|
|
||||||
// properly for all relevant sizes of `high`.
|
|
||||||
func checkPkgFilterRange(t *testing.T, high int) {
|
|
||||||
for i := uint16(0); i < uint16(high); i++ {
|
|
||||||
f := channeldb.NewPkgFilter(i)
|
|
||||||
|
|
||||||
if f.Count() != i {
|
|
||||||
t.Fatalf("pkg filter count=%d is actually %d",
|
|
||||||
i, f.Count())
|
|
||||||
}
|
|
||||||
checkPkgFilterEncodeDecode(t, i, f)
|
|
||||||
|
|
||||||
for j := uint16(0); j < i; j++ {
|
|
||||||
if f.Contains(j) {
|
|
||||||
t.Fatalf("pkg filter count=%d contains %d "+
|
|
||||||
"before being added", i, j)
|
|
||||||
}
|
|
||||||
|
|
||||||
f.Set(j)
|
|
||||||
checkPkgFilterEncodeDecode(t, i, f)
|
|
||||||
|
|
||||||
if !f.Contains(j) {
|
|
||||||
t.Fatalf("pkg filter count=%d missing %d "+
|
|
||||||
"after being added", i, j)
|
|
||||||
}
|
|
||||||
|
|
||||||
if j < i-1 && f.IsFull() {
|
|
||||||
t.Fatalf("pkg filter count=%d already full", i)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !f.IsFull() {
|
|
||||||
t.Fatalf("pkg filter count=%d not full", i)
|
|
||||||
}
|
|
||||||
checkPkgFilterEncodeDecode(t, i, f)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestPkgFilterRand uses a random permutation to verify the proper behavior of
|
|
||||||
// the pkg filter if the entries are not inserted in-order.
|
|
||||||
func TestPkgFilterRand(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
checkPkgFilterRand(t, 3, 17)
|
|
||||||
}
|
|
||||||
|
|
||||||
// checkPkgFilterRand checks the behavior of a pkg filter by randomly inserting
|
|
||||||
// indices and asserting the invariants. The order in which indices are inserted
|
|
||||||
// is parameterized by a base `b` coprime to `p`, and using modular
|
|
||||||
// exponentiation to generate all elements in [1,p).
|
|
||||||
func checkPkgFilterRand(t *testing.T, b, p uint16) {
|
|
||||||
f := channeldb.NewPkgFilter(p)
|
|
||||||
var j = b
|
|
||||||
for i := uint16(1); i < p; i++ {
|
|
||||||
if f.Contains(j) {
|
|
||||||
t.Fatalf("pkg filter contains %d-%d "+
|
|
||||||
"before being added", i, j)
|
|
||||||
}
|
|
||||||
|
|
||||||
f.Set(j)
|
|
||||||
checkPkgFilterEncodeDecode(t, i, f)
|
|
||||||
|
|
||||||
if !f.Contains(j) {
|
|
||||||
t.Fatalf("pkg filter missing %d-%d "+
|
|
||||||
"after being added", i, j)
|
|
||||||
}
|
|
||||||
|
|
||||||
if i < p-1 && f.IsFull() {
|
|
||||||
t.Fatalf("pkg filter %d already full", i)
|
|
||||||
}
|
|
||||||
checkPkgFilterEncodeDecode(t, i, f)
|
|
||||||
|
|
||||||
j = (b * j) % p
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set 0 independently, since it will never be emitted by the generator.
|
|
||||||
f.Set(0)
|
|
||||||
checkPkgFilterEncodeDecode(t, p, f)
|
|
||||||
|
|
||||||
if !f.IsFull() {
|
|
||||||
t.Fatalf("pkg filter count=%d not full", p)
|
|
||||||
}
|
|
||||||
checkPkgFilterEncodeDecode(t, p, f)
|
|
||||||
}
|
|
||||||
|
|
||||||
// checkPkgFilterEncodeDecode tests the serialization of a pkg filter by:
|
|
||||||
// 1) writing it to a buffer
|
|
||||||
// 2) verifying the number of bytes written matches the filter's Size()
|
|
||||||
// 3) reconstructing the filter decoding the bytes
|
|
||||||
// 4) checking that the two filters are the same according to Equal
|
|
||||||
func checkPkgFilterEncodeDecode(t *testing.T, i uint16, f *channeldb.PkgFilter) {
|
|
||||||
var b bytes.Buffer
|
|
||||||
if err := f.Encode(&b); err != nil {
|
|
||||||
t.Fatalf("unable to serialize pkg filter: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// +2 for uint16 length
|
|
||||||
size := uint16(len(b.Bytes()))
|
|
||||||
if size != f.Size() {
|
|
||||||
t.Fatalf("pkg filter count=%d serialized size differs, "+
|
|
||||||
"Size(): %d, len(bytes): %v", i, f.Size(), size)
|
|
||||||
}
|
|
||||||
|
|
||||||
reader := bytes.NewReader(b.Bytes())
|
|
||||||
|
|
||||||
f2 := &channeldb.PkgFilter{}
|
|
||||||
if err := f2.Decode(reader); err != nil {
|
|
||||||
t.Fatalf("unable to deserialize pkg filter: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !f.Equal(f2) {
|
|
||||||
t.Fatalf("pkg filter count=%v does is not equal "+
|
|
||||||
"after deserialization, want: %v, got %v",
|
|
||||||
i, f, f2)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
chanID = lnwire.NewChanIDFromOutPoint(&wire.OutPoint{})
|
|
||||||
|
|
||||||
adds = []channeldb.LogUpdate{
|
|
||||||
{
|
|
||||||
LogIndex: 0,
|
|
||||||
UpdateMsg: &lnwire.UpdateAddHTLC{
|
|
||||||
ChanID: chanID,
|
|
||||||
ID: 1,
|
|
||||||
Amount: 100,
|
|
||||||
Expiry: 1000,
|
|
||||||
PaymentHash: [32]byte{0},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
LogIndex: 1,
|
|
||||||
UpdateMsg: &lnwire.UpdateAddHTLC{
|
|
||||||
ChanID: chanID,
|
|
||||||
ID: 1,
|
|
||||||
Amount: 101,
|
|
||||||
Expiry: 1001,
|
|
||||||
PaymentHash: [32]byte{1},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
settleFails = []channeldb.LogUpdate{
|
|
||||||
{
|
|
||||||
LogIndex: 2,
|
|
||||||
UpdateMsg: &lnwire.UpdateFulfillHTLC{
|
|
||||||
ChanID: chanID,
|
|
||||||
ID: 0,
|
|
||||||
PaymentPreimage: [32]byte{0},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
LogIndex: 3,
|
|
||||||
UpdateMsg: &lnwire.UpdateFailHTLC{
|
|
||||||
ChanID: chanID,
|
|
||||||
ID: 1,
|
|
||||||
Reason: []byte{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
// TestPackagerEmptyFwdPkg checks that the state transitions exhibited by a
|
|
||||||
// forwarding package that contains no adds, fails or settles. We expect that
|
|
||||||
// the fwdpkg reaches FwdStateCompleted immediately after writing the forwarding
|
|
||||||
// decision via SetFwdFilter.
|
|
||||||
func TestPackagerEmptyFwdPkg(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
db := makeFwdPkgDB(t, "")
|
|
||||||
|
|
||||||
shortChanID := lnwire.NewShortChanIDFromInt(1)
|
|
||||||
packager := channeldb.NewChannelPackager(shortChanID)
|
|
||||||
|
|
||||||
// To begin, there should be no forwarding packages on disk.
|
|
||||||
fwdPkgs := loadFwdPkgs(t, db, packager)
|
|
||||||
if len(fwdPkgs) != 0 {
|
|
||||||
t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Next, create and write a new forwarding package with no htlcs.
|
|
||||||
fwdPkg := channeldb.NewFwdPkg(shortChanID, 0, nil, nil)
|
|
||||||
|
|
||||||
if err := db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
return packager.AddFwdPkg(tx, fwdPkg)
|
|
||||||
}); err != nil {
|
|
||||||
t.Fatalf("unable to add fwd pkg: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// There should now be one fwdpkg on disk. Since no forwarding decision
|
|
||||||
// has been written, we expect it to be FwdStateLockedIn. With no HTLCs,
|
|
||||||
// the ack filter will have no elements, and should always return true.
|
|
||||||
fwdPkgs = loadFwdPkgs(t, db, packager)
|
|
||||||
if len(fwdPkgs) != 1 {
|
|
||||||
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
|
|
||||||
}
|
|
||||||
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateLockedIn)
|
|
||||||
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], 0, 0)
|
|
||||||
assertAckFilterIsFull(t, fwdPkgs[0], true)
|
|
||||||
|
|
||||||
// Now, write the forwarding decision. In this case, its just an empty
|
|
||||||
// fwd filter.
|
|
||||||
if err := db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter)
|
|
||||||
}); err != nil {
|
|
||||||
t.Fatalf("unable to set fwdfiter: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// We should still have one package on disk. Since the forwarding
|
|
||||||
// decision has been written, it will minimally be in FwdStateProcessed.
|
|
||||||
// However with no htlcs, it should leap frog to FwdStateCompleted.
|
|
||||||
fwdPkgs = loadFwdPkgs(t, db, packager)
|
|
||||||
if len(fwdPkgs) != 1 {
|
|
||||||
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
|
|
||||||
}
|
|
||||||
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateCompleted)
|
|
||||||
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], 0, 0)
|
|
||||||
assertAckFilterIsFull(t, fwdPkgs[0], true)
|
|
||||||
|
|
||||||
// Lastly, remove the completed forwarding package from disk.
|
|
||||||
if err := db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
return packager.RemovePkg(tx, fwdPkg.Height)
|
|
||||||
}); err != nil {
|
|
||||||
t.Fatalf("unable to remove fwdpkg: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check that the fwd package was actually removed.
|
|
||||||
fwdPkgs = loadFwdPkgs(t, db, packager)
|
|
||||||
if len(fwdPkgs) != 0 {
|
|
||||||
t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestPackagerOnlyAdds checks that the fwdpkg does not reach FwdStateCompleted
|
|
||||||
// as soon as all the adds in the package have been acked using AckAddHtlcs.
|
|
||||||
func TestPackagerOnlyAdds(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
db := makeFwdPkgDB(t, "")
|
|
||||||
|
|
||||||
shortChanID := lnwire.NewShortChanIDFromInt(1)
|
|
||||||
packager := channeldb.NewChannelPackager(shortChanID)
|
|
||||||
|
|
||||||
// To begin, there should be no forwarding packages on disk.
|
|
||||||
fwdPkgs := loadFwdPkgs(t, db, packager)
|
|
||||||
if len(fwdPkgs) != 0 {
|
|
||||||
t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Next, create and write a new forwarding package that only has add
|
|
||||||
// htlcs.
|
|
||||||
fwdPkg := channeldb.NewFwdPkg(shortChanID, 0, adds, nil)
|
|
||||||
|
|
||||||
nAdds := len(adds)
|
|
||||||
|
|
||||||
if err := db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
return packager.AddFwdPkg(tx, fwdPkg)
|
|
||||||
}); err != nil {
|
|
||||||
t.Fatalf("unable to add fwd pkg: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// There should now be one fwdpkg on disk. Since no forwarding decision
|
|
||||||
// has been written, we expect it to be FwdStateLockedIn. The package
|
|
||||||
// has unacked add HTLCs, so the ack filter should not be full.
|
|
||||||
fwdPkgs = loadFwdPkgs(t, db, packager)
|
|
||||||
if len(fwdPkgs) != 1 {
|
|
||||||
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
|
|
||||||
}
|
|
||||||
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateLockedIn)
|
|
||||||
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, 0)
|
|
||||||
assertAckFilterIsFull(t, fwdPkgs[0], false)
|
|
||||||
|
|
||||||
// Now, write the forwarding decision. Since we have not explicitly
|
|
||||||
// added any adds to the fwdfilter, this would indicate that all of the
|
|
||||||
// adds were 1) settled locally by this link (exit hop), or 2) the htlc
|
|
||||||
// was failed locally.
|
|
||||||
if err := db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter)
|
|
||||||
}); err != nil {
|
|
||||||
t.Fatalf("unable to set fwdfiter: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range adds {
|
|
||||||
// We should still have one package on disk. Since the forwarding
|
|
||||||
// decision has been written, it will minimally be in FwdStateProcessed.
|
|
||||||
// However not allf of the HTLCs have been acked, so should not
|
|
||||||
// have advanced further.
|
|
||||||
fwdPkgs = loadFwdPkgs(t, db, packager)
|
|
||||||
if len(fwdPkgs) != 1 {
|
|
||||||
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
|
|
||||||
}
|
|
||||||
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed)
|
|
||||||
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, 0)
|
|
||||||
assertAckFilterIsFull(t, fwdPkgs[0], false)
|
|
||||||
|
|
||||||
addRef := channeldb.AddRef{
|
|
||||||
Height: fwdPkg.Height,
|
|
||||||
Index: uint16(i),
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
return packager.AckAddHtlcs(tx, addRef)
|
|
||||||
}); err != nil {
|
|
||||||
t.Fatalf("unable to ack add htlc: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// We should still have one package on disk. Now that all adds have been
|
|
||||||
// acked, the ack filter should return true and the package should be
|
|
||||||
// FwdStateCompleted since there are no other settle/fail packets.
|
|
||||||
fwdPkgs = loadFwdPkgs(t, db, packager)
|
|
||||||
if len(fwdPkgs) != 1 {
|
|
||||||
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
|
|
||||||
}
|
|
||||||
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateCompleted)
|
|
||||||
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, 0)
|
|
||||||
assertAckFilterIsFull(t, fwdPkgs[0], true)
|
|
||||||
|
|
||||||
// Lastly, remove the completed forwarding package from disk.
|
|
||||||
if err := db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
return packager.RemovePkg(tx, fwdPkg.Height)
|
|
||||||
}); err != nil {
|
|
||||||
t.Fatalf("unable to remove fwdpkg: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check that the fwd package was actually removed.
|
|
||||||
fwdPkgs = loadFwdPkgs(t, db, packager)
|
|
||||||
if len(fwdPkgs) != 0 {
|
|
||||||
t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestPackagerOnlySettleFails asserts that the fwdpkg remains in
|
|
||||||
// FwdStateProcessed after writing the forwarding decision when there are no
|
|
||||||
// adds in the fwdpkg. We expect this because an empty FwdFilter will always
|
|
||||||
// return true, but we are still waiting for the remaining fails and settles to
|
|
||||||
// be deleted.
|
|
||||||
func TestPackagerOnlySettleFails(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
db := makeFwdPkgDB(t, "")
|
|
||||||
|
|
||||||
shortChanID := lnwire.NewShortChanIDFromInt(1)
|
|
||||||
packager := channeldb.NewChannelPackager(shortChanID)
|
|
||||||
|
|
||||||
// To begin, there should be no forwarding packages on disk.
|
|
||||||
fwdPkgs := loadFwdPkgs(t, db, packager)
|
|
||||||
if len(fwdPkgs) != 0 {
|
|
||||||
t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Next, create and write a new forwarding package that only has add
|
|
||||||
// htlcs.
|
|
||||||
fwdPkg := channeldb.NewFwdPkg(shortChanID, 0, nil, settleFails)
|
|
||||||
|
|
||||||
nSettleFails := len(settleFails)
|
|
||||||
|
|
||||||
if err := db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
return packager.AddFwdPkg(tx, fwdPkg)
|
|
||||||
}); err != nil {
|
|
||||||
t.Fatalf("unable to add fwd pkg: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// There should now be one fwdpkg on disk. Since no forwarding decision
|
|
||||||
// has been written, we expect it to be FwdStateLockedIn. The package
|
|
||||||
// has unacked add HTLCs, so the ack filter should not be full.
|
|
||||||
fwdPkgs = loadFwdPkgs(t, db, packager)
|
|
||||||
if len(fwdPkgs) != 1 {
|
|
||||||
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
|
|
||||||
}
|
|
||||||
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateLockedIn)
|
|
||||||
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], 0, nSettleFails)
|
|
||||||
assertAckFilterIsFull(t, fwdPkgs[0], true)
|
|
||||||
|
|
||||||
// Now, write the forwarding decision. Since we have not explicitly
|
|
||||||
// added any adds to the fwdfilter, this would indicate that all of the
|
|
||||||
// adds were 1) settled locally by this link (exit hop), or 2) the htlc
|
|
||||||
// was failed locally.
|
|
||||||
if err := db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter)
|
|
||||||
}); err != nil {
|
|
||||||
t.Fatalf("unable to set fwdfiter: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range settleFails {
|
|
||||||
// We should still have one package on disk. Since the
|
|
||||||
// forwarding decision has been written, it will minimally be in
|
|
||||||
// FwdStateProcessed. However, not all of the HTLCs have been
|
|
||||||
// acked, so should not have advanced further.
|
|
||||||
fwdPkgs = loadFwdPkgs(t, db, packager)
|
|
||||||
if len(fwdPkgs) != 1 {
|
|
||||||
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
|
|
||||||
}
|
|
||||||
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed)
|
|
||||||
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], 0, nSettleFails)
|
|
||||||
assertSettleFailFilterIsFull(t, fwdPkgs[0], false)
|
|
||||||
assertAckFilterIsFull(t, fwdPkgs[0], true)
|
|
||||||
|
|
||||||
failSettleRef := channeldb.SettleFailRef{
|
|
||||||
Source: shortChanID,
|
|
||||||
Height: fwdPkg.Height,
|
|
||||||
Index: uint16(i),
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
return packager.AckSettleFails(tx, failSettleRef)
|
|
||||||
}); err != nil {
|
|
||||||
t.Fatalf("unable to ack add htlc: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// We should still have one package on disk. Now that all settles and
|
|
||||||
// fails have been removed, package should be FwdStateCompleted since
|
|
||||||
// there are no other add packets.
|
|
||||||
fwdPkgs = loadFwdPkgs(t, db, packager)
|
|
||||||
if len(fwdPkgs) != 1 {
|
|
||||||
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
|
|
||||||
}
|
|
||||||
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateCompleted)
|
|
||||||
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], 0, nSettleFails)
|
|
||||||
assertSettleFailFilterIsFull(t, fwdPkgs[0], true)
|
|
||||||
assertAckFilterIsFull(t, fwdPkgs[0], true)
|
|
||||||
|
|
||||||
// Lastly, remove the completed forwarding package from disk.
|
|
||||||
if err := db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
return packager.RemovePkg(tx, fwdPkg.Height)
|
|
||||||
}); err != nil {
|
|
||||||
t.Fatalf("unable to remove fwdpkg: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check that the fwd package was actually removed.
|
|
||||||
fwdPkgs = loadFwdPkgs(t, db, packager)
|
|
||||||
if len(fwdPkgs) != 0 {
|
|
||||||
t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestPackagerAddsThenSettleFails writes a fwdpkg containing both adds and
|
|
||||||
// settle/fails, then checks the behavior when the adds are acked before any of
|
|
||||||
// the settle fails. Here we expect pkg to remain in FwdStateProcessed while the
|
|
||||||
// remainder of the fail/settles are being deleted.
|
|
||||||
func TestPackagerAddsThenSettleFails(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
db := makeFwdPkgDB(t, "")
|
|
||||||
|
|
||||||
shortChanID := lnwire.NewShortChanIDFromInt(1)
|
|
||||||
packager := channeldb.NewChannelPackager(shortChanID)
|
|
||||||
|
|
||||||
// To begin, there should be no forwarding packages on disk.
|
|
||||||
fwdPkgs := loadFwdPkgs(t, db, packager)
|
|
||||||
if len(fwdPkgs) != 0 {
|
|
||||||
t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Next, create and write a new forwarding package that only has add
|
|
||||||
// htlcs.
|
|
||||||
fwdPkg := channeldb.NewFwdPkg(shortChanID, 0, adds, settleFails)
|
|
||||||
|
|
||||||
nAdds := len(adds)
|
|
||||||
nSettleFails := len(settleFails)
|
|
||||||
|
|
||||||
if err := db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
return packager.AddFwdPkg(tx, fwdPkg)
|
|
||||||
}); err != nil {
|
|
||||||
t.Fatalf("unable to add fwd pkg: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// There should now be one fwdpkg on disk. Since no forwarding decision
|
|
||||||
// has been written, we expect it to be FwdStateLockedIn. The package
|
|
||||||
// has unacked add HTLCs, so the ack filter should not be full.
|
|
||||||
fwdPkgs = loadFwdPkgs(t, db, packager)
|
|
||||||
if len(fwdPkgs) != 1 {
|
|
||||||
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
|
|
||||||
}
|
|
||||||
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateLockedIn)
|
|
||||||
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails)
|
|
||||||
assertAckFilterIsFull(t, fwdPkgs[0], false)
|
|
||||||
|
|
||||||
// Now, write the forwarding decision. Since we have not explicitly
|
|
||||||
// added any adds to the fwdfilter, this would indicate that all of the
|
|
||||||
// adds were 1) settled locally by this link (exit hop), or 2) the htlc
|
|
||||||
// was failed locally.
|
|
||||||
if err := db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter)
|
|
||||||
}); err != nil {
|
|
||||||
t.Fatalf("unable to set fwdfiter: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range adds {
|
|
||||||
// We should still have one package on disk. Since the forwarding
|
|
||||||
// decision has been written, it will minimally be in FwdStateProcessed.
|
|
||||||
// However not allf of the HTLCs have been acked, so should not
|
|
||||||
// have advanced further.
|
|
||||||
fwdPkgs = loadFwdPkgs(t, db, packager)
|
|
||||||
if len(fwdPkgs) != 1 {
|
|
||||||
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
|
|
||||||
}
|
|
||||||
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed)
|
|
||||||
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails)
|
|
||||||
assertSettleFailFilterIsFull(t, fwdPkgs[0], false)
|
|
||||||
assertAckFilterIsFull(t, fwdPkgs[0], false)
|
|
||||||
|
|
||||||
addRef := channeldb.AddRef{
|
|
||||||
Height: fwdPkg.Height,
|
|
||||||
Index: uint16(i),
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
return packager.AckAddHtlcs(tx, addRef)
|
|
||||||
}); err != nil {
|
|
||||||
t.Fatalf("unable to ack add htlc: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range settleFails {
|
|
||||||
// We should still have one package on disk. Since the
|
|
||||||
// forwarding decision has been written, it will minimally be in
|
|
||||||
// FwdStateProcessed. However not allf of the HTLCs have been
|
|
||||||
// acked, so should not have advanced further.
|
|
||||||
fwdPkgs = loadFwdPkgs(t, db, packager)
|
|
||||||
if len(fwdPkgs) != 1 {
|
|
||||||
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
|
|
||||||
}
|
|
||||||
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed)
|
|
||||||
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails)
|
|
||||||
assertSettleFailFilterIsFull(t, fwdPkgs[0], false)
|
|
||||||
assertAckFilterIsFull(t, fwdPkgs[0], true)
|
|
||||||
|
|
||||||
failSettleRef := channeldb.SettleFailRef{
|
|
||||||
Source: shortChanID,
|
|
||||||
Height: fwdPkg.Height,
|
|
||||||
Index: uint16(i),
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
return packager.AckSettleFails(tx, failSettleRef)
|
|
||||||
}); err != nil {
|
|
||||||
t.Fatalf("unable to remove settle/fail htlc: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// We should still have one package on disk. Now that all settles and
|
|
||||||
// fails have been removed, package should be FwdStateCompleted since
|
|
||||||
// there are no other add packets.
|
|
||||||
fwdPkgs = loadFwdPkgs(t, db, packager)
|
|
||||||
if len(fwdPkgs) != 1 {
|
|
||||||
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
|
|
||||||
}
|
|
||||||
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateCompleted)
|
|
||||||
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails)
|
|
||||||
assertSettleFailFilterIsFull(t, fwdPkgs[0], true)
|
|
||||||
assertAckFilterIsFull(t, fwdPkgs[0], true)
|
|
||||||
|
|
||||||
// Lastly, remove the completed forwarding package from disk.
|
|
||||||
if err := db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
return packager.RemovePkg(tx, fwdPkg.Height)
|
|
||||||
}); err != nil {
|
|
||||||
t.Fatalf("unable to remove fwdpkg: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check that the fwd package was actually removed.
|
|
||||||
fwdPkgs = loadFwdPkgs(t, db, packager)
|
|
||||||
if len(fwdPkgs) != 0 {
|
|
||||||
t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestPackagerSettleFailsThenAdds writes a fwdpkg with both adds and
|
|
||||||
// settle/fails, then checks the behavior when the settle/fails are removed
|
|
||||||
// before any of the adds have been acked. This should cause the fwdpkg to
|
|
||||||
// remain in FwdStateProcessed until the final ack is recorded, at which point
|
|
||||||
// it should be promoted directly to FwdStateCompleted.since all adds have been
|
|
||||||
// removed.
|
|
||||||
func TestPackagerSettleFailsThenAdds(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
db := makeFwdPkgDB(t, "")
|
|
||||||
|
|
||||||
shortChanID := lnwire.NewShortChanIDFromInt(1)
|
|
||||||
packager := channeldb.NewChannelPackager(shortChanID)
|
|
||||||
|
|
||||||
// To begin, there should be no forwarding packages on disk.
|
|
||||||
fwdPkgs := loadFwdPkgs(t, db, packager)
|
|
||||||
if len(fwdPkgs) != 0 {
|
|
||||||
t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Next, create and write a new forwarding package that has both add
|
|
||||||
// and settle/fail htlcs.
|
|
||||||
fwdPkg := channeldb.NewFwdPkg(shortChanID, 0, adds, settleFails)
|
|
||||||
|
|
||||||
nAdds := len(adds)
|
|
||||||
nSettleFails := len(settleFails)
|
|
||||||
|
|
||||||
if err := db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
return packager.AddFwdPkg(tx, fwdPkg)
|
|
||||||
}); err != nil {
|
|
||||||
t.Fatalf("unable to add fwd pkg: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// There should now be one fwdpkg on disk. Since no forwarding decision
|
|
||||||
// has been written, we expect it to be FwdStateLockedIn. The package
|
|
||||||
// has unacked add HTLCs, so the ack filter should not be full.
|
|
||||||
fwdPkgs = loadFwdPkgs(t, db, packager)
|
|
||||||
if len(fwdPkgs) != 1 {
|
|
||||||
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
|
|
||||||
}
|
|
||||||
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateLockedIn)
|
|
||||||
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails)
|
|
||||||
assertAckFilterIsFull(t, fwdPkgs[0], false)
|
|
||||||
|
|
||||||
// Now, write the forwarding decision. Since we have not explicitly
|
|
||||||
// added any adds to the fwdfilter, this would indicate that all of the
|
|
||||||
// adds were 1) settled locally by this link (exit hop), or 2) the htlc
|
|
||||||
// was failed locally.
|
|
||||||
if err := db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter)
|
|
||||||
}); err != nil {
|
|
||||||
t.Fatalf("unable to set fwdfiter: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Simulate another channel deleting the settle/fails it received from
|
|
||||||
// the original fwd pkg.
|
|
||||||
// TODO(conner): use different packager/s?
|
|
||||||
for i := range settleFails {
|
|
||||||
// We should still have one package on disk. Since the
|
|
||||||
// forwarding decision has been written, it will minimally be in
|
|
||||||
// FwdStateProcessed. However none all of the add HTLCs have
|
|
||||||
// been acked, so should not have advanced further.
|
|
||||||
fwdPkgs = loadFwdPkgs(t, db, packager)
|
|
||||||
if len(fwdPkgs) != 1 {
|
|
||||||
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
|
|
||||||
}
|
|
||||||
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed)
|
|
||||||
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails)
|
|
||||||
assertSettleFailFilterIsFull(t, fwdPkgs[0], false)
|
|
||||||
assertAckFilterIsFull(t, fwdPkgs[0], false)
|
|
||||||
|
|
||||||
failSettleRef := channeldb.SettleFailRef{
|
|
||||||
Source: shortChanID,
|
|
||||||
Height: fwdPkg.Height,
|
|
||||||
Index: uint16(i),
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
return packager.AckSettleFails(tx, failSettleRef)
|
|
||||||
}); err != nil {
|
|
||||||
t.Fatalf("unable to remove settle/fail htlc: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now simulate this channel receiving a fail/settle for the adds in the
|
|
||||||
// fwdpkg.
|
|
||||||
for i := range adds {
|
|
||||||
// Again, we should still have one package on disk and be in
|
|
||||||
// FwdStateProcessed. This should not change until all of the
|
|
||||||
// add htlcs have been acked.
|
|
||||||
fwdPkgs = loadFwdPkgs(t, db, packager)
|
|
||||||
if len(fwdPkgs) != 1 {
|
|
||||||
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
|
|
||||||
}
|
|
||||||
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed)
|
|
||||||
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails)
|
|
||||||
assertSettleFailFilterIsFull(t, fwdPkgs[0], true)
|
|
||||||
assertAckFilterIsFull(t, fwdPkgs[0], false)
|
|
||||||
|
|
||||||
addRef := channeldb.AddRef{
|
|
||||||
Height: fwdPkg.Height,
|
|
||||||
Index: uint16(i),
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
return packager.AckAddHtlcs(tx, addRef)
|
|
||||||
}); err != nil {
|
|
||||||
t.Fatalf("unable to ack add htlc: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// We should still have one package on disk. Now that all settles and
|
|
||||||
// fails have been removed, package should be FwdStateCompleted since
|
|
||||||
// there are no other add packets.
|
|
||||||
fwdPkgs = loadFwdPkgs(t, db, packager)
|
|
||||||
if len(fwdPkgs) != 1 {
|
|
||||||
t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs))
|
|
||||||
}
|
|
||||||
assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateCompleted)
|
|
||||||
assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails)
|
|
||||||
assertSettleFailFilterIsFull(t, fwdPkgs[0], true)
|
|
||||||
assertAckFilterIsFull(t, fwdPkgs[0], true)
|
|
||||||
|
|
||||||
// Lastly, remove the completed forwarding package from disk.
|
|
||||||
if err := db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
return packager.RemovePkg(tx, fwdPkg.Height)
|
|
||||||
}); err != nil {
|
|
||||||
t.Fatalf("unable to remove fwdpkg: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check that the fwd package was actually removed.
|
|
||||||
fwdPkgs = loadFwdPkgs(t, db, packager)
|
|
||||||
if len(fwdPkgs) != 0 {
|
|
||||||
t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// assertFwdPkgState checks the current state of a fwdpkg meets our
|
|
||||||
// expectations.
|
|
||||||
func assertFwdPkgState(t *testing.T, fwdPkg *channeldb.FwdPkg,
|
|
||||||
state channeldb.FwdState) {
|
|
||||||
_, _, line, _ := runtime.Caller(1)
|
|
||||||
if fwdPkg.State != state {
|
|
||||||
t.Fatalf("line %d: expected fwdpkg in state %v, found %v",
|
|
||||||
line, state, fwdPkg.State)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// assertFwdPkgNumAddsSettleFails checks that the number of adds and
|
|
||||||
// settle/fail log updates are correct.
|
|
||||||
func assertFwdPkgNumAddsSettleFails(t *testing.T, fwdPkg *channeldb.FwdPkg,
|
|
||||||
expectedNumAdds, expectedNumSettleFails int) {
|
|
||||||
_, _, line, _ := runtime.Caller(1)
|
|
||||||
if len(fwdPkg.Adds) != expectedNumAdds {
|
|
||||||
t.Fatalf("line %d: expected fwdpkg to have %d adds, found %d",
|
|
||||||
line, expectedNumAdds, len(fwdPkg.Adds))
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(fwdPkg.SettleFails) != expectedNumSettleFails {
|
|
||||||
t.Fatalf("line %d: expected fwdpkg to have %d settle/fails, found %d",
|
|
||||||
line, expectedNumSettleFails, len(fwdPkg.SettleFails))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// assertAckFilterIsFull checks whether or not a fwdpkg's ack filter matches our
|
|
||||||
// expected full-ness.
|
|
||||||
func assertAckFilterIsFull(t *testing.T, fwdPkg *channeldb.FwdPkg, expected bool) {
|
|
||||||
_, _, line, _ := runtime.Caller(1)
|
|
||||||
if fwdPkg.AckFilter.IsFull() != expected {
|
|
||||||
t.Fatalf("line %d: expected fwdpkg ack filter IsFull to be %v, "+
|
|
||||||
"found %v", line, expected, fwdPkg.AckFilter.IsFull())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// assertSettleFailFilterIsFull checks whether or not a fwdpkg's settle fail
|
|
||||||
// filter matches our expected full-ness.
|
|
||||||
func assertSettleFailFilterIsFull(t *testing.T, fwdPkg *channeldb.FwdPkg, expected bool) {
|
|
||||||
_, _, line, _ := runtime.Caller(1)
|
|
||||||
if fwdPkg.SettleFailFilter.IsFull() != expected {
|
|
||||||
t.Fatalf("line %d: expected fwdpkg settle/fail filter IsFull to be %v, "+
|
|
||||||
"found %v", line, expected, fwdPkg.SettleFailFilter.IsFull())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// loadFwdPkgs is a helper method that reads all forwarding packages for a
|
|
||||||
// particular packager.
|
|
||||||
func loadFwdPkgs(t *testing.T, db *bbolt.DB,
|
|
||||||
packager channeldb.FwdPackager) []*channeldb.FwdPkg {
|
|
||||||
|
|
||||||
var fwdPkgs []*channeldb.FwdPkg
|
|
||||||
if err := db.View(func(tx *bbolt.Tx) error {
|
|
||||||
var err error
|
|
||||||
fwdPkgs, err = packager.LoadFwdPkgs(tx)
|
|
||||||
return err
|
|
||||||
}); err != nil {
|
|
||||||
t.Fatalf("unable to load fwd pkgs: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return fwdPkgs
|
|
||||||
}
|
|
||||||
|
|
||||||
// makeFwdPkgDB initializes a test database for forwarding packages. If the
|
|
||||||
// provided path is an empty, it will create a temp dir/file to use.
|
|
||||||
func makeFwdPkgDB(t *testing.T, path string) *bbolt.DB {
|
|
||||||
if path == "" {
|
|
||||||
var err error
|
|
||||||
path, err = ioutil.TempDir("", "fwdpkgdb")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create temp path: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
path = filepath.Join(path, "fwdpkg.db")
|
|
||||||
}
|
|
||||||
|
|
||||||
db, err := bbolt.Open(path, 0600, nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to open boltdb: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return db
|
|
||||||
}
|
|
@ -2,20 +2,15 @@ package migration_01_to_11
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"image/color"
|
"image/color"
|
||||||
"io"
|
"io"
|
||||||
"math"
|
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/btcsuite/btcd/btcec"
|
"github.com/btcsuite/btcd/btcec"
|
||||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||||
"github.com/btcsuite/btcd/txscript"
|
|
||||||
"github.com/btcsuite/btcd/wire"
|
"github.com/btcsuite/btcd/wire"
|
||||||
"github.com/btcsuite/btcutil"
|
"github.com/btcsuite/btcutil"
|
||||||
"github.com/coreos/bbolt"
|
"github.com/coreos/bbolt"
|
||||||
@ -74,11 +69,6 @@ var (
|
|||||||
// lookup of incoming channel edges.
|
// lookup of incoming channel edges.
|
||||||
unknownPolicy = []byte{}
|
unknownPolicy = []byte{}
|
||||||
|
|
||||||
// chanStart is an array of all zero bytes which is used to perform
|
|
||||||
// range scans within the edgeBucket to obtain all of the outgoing
|
|
||||||
// edges for a particular node.
|
|
||||||
chanStart [8]byte
|
|
||||||
|
|
||||||
// edgeIndexBucket is an index which can be used to iterate all edges
|
// edgeIndexBucket is an index which can be used to iterate all edges
|
||||||
// in the bucket, grouping them according to their in/out nodes.
|
// in the bucket, grouping them according to their in/out nodes.
|
||||||
// Additionally, the items in this bucket also contain the complete
|
// Additionally, the items in this bucket also contain the complete
|
||||||
@ -155,9 +145,6 @@ const (
|
|||||||
// would be possible for a node to create a ton of updates and slowly
|
// would be possible for a node to create a ton of updates and slowly
|
||||||
// fill our disk, and also waste bandwidth due to relaying.
|
// fill our disk, and also waste bandwidth due to relaying.
|
||||||
MaxAllowedExtraOpaqueBytes = 10000
|
MaxAllowedExtraOpaqueBytes = 10000
|
||||||
|
|
||||||
// feeRateParts is the total number of parts used to express fee rates.
|
|
||||||
feeRateParts = 1e6
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ChannelGraph is a persistent, on-disk graph representation of the Lightning
|
// ChannelGraph is a persistent, on-disk graph representation of the Lightning
|
||||||
@ -172,10 +159,6 @@ const (
|
|||||||
// for that edge.
|
// for that edge.
|
||||||
type ChannelGraph struct {
|
type ChannelGraph struct {
|
||||||
db *DB
|
db *DB
|
||||||
|
|
||||||
cacheMu sync.RWMutex
|
|
||||||
rejectCache *rejectCache
|
|
||||||
chanCache *channelCache
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// newChannelGraph allocates a new ChannelGraph backed by a DB instance. The
|
// newChannelGraph allocates a new ChannelGraph backed by a DB instance. The
|
||||||
@ -183,189 +166,9 @@ type ChannelGraph struct {
|
|||||||
func newChannelGraph(db *DB, rejectCacheSize, chanCacheSize int) *ChannelGraph {
|
func newChannelGraph(db *DB, rejectCacheSize, chanCacheSize int) *ChannelGraph {
|
||||||
return &ChannelGraph{
|
return &ChannelGraph{
|
||||||
db: db,
|
db: db,
|
||||||
rejectCache: newRejectCache(rejectCacheSize),
|
|
||||||
chanCache: newChannelCache(chanCacheSize),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Database returns a pointer to the underlying database.
|
|
||||||
func (c *ChannelGraph) Database() *DB {
|
|
||||||
return c.db
|
|
||||||
}
|
|
||||||
|
|
||||||
// ForEachChannel iterates through all the channel edges stored within the
|
|
||||||
// graph and invokes the passed callback for each edge. The callback takes two
|
|
||||||
// edges as since this is a directed graph, both the in/out edges are visited.
|
|
||||||
// If the callback returns an error, then the transaction is aborted and the
|
|
||||||
// iteration stops early.
|
|
||||||
//
|
|
||||||
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
|
|
||||||
// for that particular channel edge routing policy will be passed into the
|
|
||||||
// callback.
|
|
||||||
func (c *ChannelGraph) ForEachChannel(cb func(*ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error {
|
|
||||||
// TODO(roasbeef): ptr map to reduce # of allocs? no duplicates
|
|
||||||
|
|
||||||
return c.db.View(func(tx *bbolt.Tx) error {
|
|
||||||
// First, grab the node bucket. This will be used to populate
|
|
||||||
// the Node pointers in each edge read from disk.
|
|
||||||
nodes := tx.Bucket(nodeBucket)
|
|
||||||
if nodes == nil {
|
|
||||||
return ErrGraphNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
// Next, grab the edge bucket which stores the edges, and also
|
|
||||||
// the index itself so we can group the directed edges together
|
|
||||||
// logically.
|
|
||||||
edges := tx.Bucket(edgeBucket)
|
|
||||||
if edges == nil {
|
|
||||||
return ErrGraphNoEdgesFound
|
|
||||||
}
|
|
||||||
edgeIndex := edges.Bucket(edgeIndexBucket)
|
|
||||||
if edgeIndex == nil {
|
|
||||||
return ErrGraphNoEdgesFound
|
|
||||||
}
|
|
||||||
|
|
||||||
// For each edge pair within the edge index, we fetch each edge
|
|
||||||
// itself and also the node information in order to fully
|
|
||||||
// populated the object.
|
|
||||||
return edgeIndex.ForEach(func(chanID, edgeInfoBytes []byte) error {
|
|
||||||
infoReader := bytes.NewReader(edgeInfoBytes)
|
|
||||||
edgeInfo, err := deserializeChanEdgeInfo(infoReader)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
edgeInfo.db = c.db
|
|
||||||
|
|
||||||
edge1, edge2, err := fetchChanEdgePolicies(
|
|
||||||
edgeIndex, edges, nodes, chanID, c.db,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// With both edges read, execute the call back. IF this
|
|
||||||
// function returns an error then the transaction will
|
|
||||||
// be aborted.
|
|
||||||
return cb(&edgeInfo, edge1, edge2)
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// ForEachNodeChannel iterates through all channels of a given node, executing the
|
|
||||||
// passed callback with an edge info structure and the policies of each end
|
|
||||||
// of the channel. The first edge policy is the outgoing edge *to* the
|
|
||||||
// the connecting node, while the second is the incoming edge *from* the
|
|
||||||
// connecting node. If the callback returns an error, then the iteration is
|
|
||||||
// halted with the error propagated back up to the caller.
|
|
||||||
//
|
|
||||||
// Unknown policies are passed into the callback as nil values.
|
|
||||||
//
|
|
||||||
// If the caller wishes to re-use an existing boltdb transaction, then it
|
|
||||||
// should be passed as the first argument. Otherwise the first argument should
|
|
||||||
// be nil and a fresh transaction will be created to execute the graph
|
|
||||||
// traversal.
|
|
||||||
func (c *ChannelGraph) ForEachNodeChannel(tx *bbolt.Tx, nodePub []byte,
|
|
||||||
cb func(*bbolt.Tx, *ChannelEdgeInfo, *ChannelEdgePolicy,
|
|
||||||
*ChannelEdgePolicy) error) error {
|
|
||||||
|
|
||||||
db := c.db
|
|
||||||
|
|
||||||
return nodeTraversal(tx, nodePub, db, cb)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DisabledChannelIDs returns the channel ids of disabled channels.
|
|
||||||
// A channel is disabled when two of the associated ChanelEdgePolicies
|
|
||||||
// have their disabled bit on.
|
|
||||||
func (c *ChannelGraph) DisabledChannelIDs() ([]uint64, error) {
|
|
||||||
var disabledChanIDs []uint64
|
|
||||||
chanEdgeFound := make(map[uint64]struct{})
|
|
||||||
|
|
||||||
err := c.db.View(func(tx *bbolt.Tx) error {
|
|
||||||
edges := tx.Bucket(edgeBucket)
|
|
||||||
if edges == nil {
|
|
||||||
return ErrGraphNoEdgesFound
|
|
||||||
}
|
|
||||||
|
|
||||||
disabledEdgePolicyIndex := edges.Bucket(disabledEdgePolicyBucket)
|
|
||||||
if disabledEdgePolicyIndex == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// We iterate over all disabled policies and we add each channel that
|
|
||||||
// has more than one disabled policy to disabledChanIDs array.
|
|
||||||
return disabledEdgePolicyIndex.ForEach(func(k, v []byte) error {
|
|
||||||
chanID := byteOrder.Uint64(k[:8])
|
|
||||||
_, edgeFound := chanEdgeFound[chanID]
|
|
||||||
if edgeFound {
|
|
||||||
delete(chanEdgeFound, chanID)
|
|
||||||
disabledChanIDs = append(disabledChanIDs, chanID)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
chanEdgeFound[chanID] = struct{}{}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return disabledChanIDs, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ForEachNode iterates through all the stored vertices/nodes in the graph,
|
|
||||||
// executing the passed callback with each node encountered. If the callback
|
|
||||||
// returns an error, then the transaction is aborted and the iteration stops
|
|
||||||
// early.
|
|
||||||
//
|
|
||||||
// If the caller wishes to re-use an existing boltdb transaction, then it
|
|
||||||
// should be passed as the first argument. Otherwise the first argument should
|
|
||||||
// be nil and a fresh transaction will be created to execute the graph
|
|
||||||
// traversal
|
|
||||||
//
|
|
||||||
// TODO(roasbeef): add iterator interface to allow for memory efficient graph
|
|
||||||
// traversal when graph gets mega
|
|
||||||
func (c *ChannelGraph) ForEachNode(tx *bbolt.Tx, cb func(*bbolt.Tx, *LightningNode) error) error {
|
|
||||||
traversal := func(tx *bbolt.Tx) error {
|
|
||||||
// First grab the nodes bucket which stores the mapping from
|
|
||||||
// pubKey to node information.
|
|
||||||
nodes := tx.Bucket(nodeBucket)
|
|
||||||
if nodes == nil {
|
|
||||||
return ErrGraphNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
return nodes.ForEach(func(pubKey, nodeBytes []byte) error {
|
|
||||||
// If this is the source key, then we skip this
|
|
||||||
// iteration as the value for this key is a pubKey
|
|
||||||
// rather than raw node information.
|
|
||||||
if bytes.Equal(pubKey, sourceKey) || len(pubKey) != 33 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
nodeReader := bytes.NewReader(nodeBytes)
|
|
||||||
node, err := deserializeLightningNode(nodeReader)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
node.db = c.db
|
|
||||||
|
|
||||||
// Execute the callback, the transaction will abort if
|
|
||||||
// this returns an error.
|
|
||||||
return cb(tx, &node)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// If no transaction was provided, then we'll create a new transaction
|
|
||||||
// to execute the transaction within.
|
|
||||||
if tx == nil {
|
|
||||||
return c.db.View(traversal)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Otherwise, we re-use the existing transaction to execute the graph
|
|
||||||
// traversal.
|
|
||||||
return traversal(tx)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SourceNode returns the source node of the graph. The source node is treated
|
// SourceNode returns the source node of the graph. The source node is treated
|
||||||
// as the center node within a star-graph. This method may be used to kick off
|
// as the center node within a star-graph. This method may be used to kick off
|
||||||
// a path finding algorithm in order to explore the reachability of another
|
// a path finding algorithm in order to explore the reachability of another
|
||||||
@ -442,20 +245,6 @@ func (c *ChannelGraph) SetSourceNode(node *LightningNode) error {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddLightningNode adds a vertex/node to the graph database. If the node is not
|
|
||||||
// in the database from before, this will add a new, unconnected one to the
|
|
||||||
// graph. If it is present from before, this will update that node's
|
|
||||||
// information. Note that this method is expected to only be called to update
|
|
||||||
// an already present node from a node announcement, or to insert a node found
|
|
||||||
// in a channel update.
|
|
||||||
//
|
|
||||||
// TODO(roasbeef): also need sig of announcement
|
|
||||||
func (c *ChannelGraph) AddLightningNode(node *LightningNode) error {
|
|
||||||
return c.db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
return addLightningNode(tx, node)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func addLightningNode(tx *bbolt.Tx, node *LightningNode) error {
|
func addLightningNode(tx *bbolt.Tx, node *LightningNode) error {
|
||||||
nodes, err := tx.CreateBucketIfNotExists(nodeBucket)
|
nodes, err := tx.CreateBucketIfNotExists(nodeBucket)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -477,1487 +266,6 @@ func addLightningNode(tx *bbolt.Tx, node *LightningNode) error {
|
|||||||
return putLightningNode(nodes, aliases, updateIndex, node)
|
return putLightningNode(nodes, aliases, updateIndex, node)
|
||||||
}
|
}
|
||||||
|
|
||||||
// LookupAlias attempts to return the alias as advertised by the target node.
|
|
||||||
// TODO(roasbeef): currently assumes that aliases are unique...
|
|
||||||
func (c *ChannelGraph) LookupAlias(pub *btcec.PublicKey) (string, error) {
|
|
||||||
var alias string
|
|
||||||
|
|
||||||
err := c.db.View(func(tx *bbolt.Tx) error {
|
|
||||||
nodes := tx.Bucket(nodeBucket)
|
|
||||||
if nodes == nil {
|
|
||||||
return ErrGraphNodesNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
aliases := nodes.Bucket(aliasIndexBucket)
|
|
||||||
if aliases == nil {
|
|
||||||
return ErrGraphNodesNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
nodePub := pub.SerializeCompressed()
|
|
||||||
a := aliases.Get(nodePub)
|
|
||||||
if a == nil {
|
|
||||||
return ErrNodeAliasNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(roasbeef): should actually be using the utf-8
|
|
||||||
// package...
|
|
||||||
alias = string(a)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
return alias, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteLightningNode starts a new database transaction to remove a vertex/node
|
|
||||||
// from the database according to the node's public key.
|
|
||||||
func (c *ChannelGraph) DeleteLightningNode(nodePub *btcec.PublicKey) error {
|
|
||||||
// TODO(roasbeef): ensure dangling edges are removed...
|
|
||||||
return c.db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
nodes := tx.Bucket(nodeBucket)
|
|
||||||
if nodes == nil {
|
|
||||||
return ErrGraphNodeNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
return c.deleteLightningNode(
|
|
||||||
nodes, nodePub.SerializeCompressed(),
|
|
||||||
)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// deleteLightningNode uses an existing database transaction to remove a
|
|
||||||
// vertex/node from the database according to the node's public key.
|
|
||||||
func (c *ChannelGraph) deleteLightningNode(nodes *bbolt.Bucket,
|
|
||||||
compressedPubKey []byte) error {
|
|
||||||
|
|
||||||
aliases := nodes.Bucket(aliasIndexBucket)
|
|
||||||
if aliases == nil {
|
|
||||||
return ErrGraphNodesNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := aliases.Delete(compressedPubKey); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Before we delete the node, we'll fetch its current state so we can
|
|
||||||
// determine when its last update was to clear out the node update
|
|
||||||
// index.
|
|
||||||
node, err := fetchLightningNode(nodes, compressedPubKey)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := nodes.Delete(compressedPubKey); err != nil {
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Finally, we'll delete the index entry for the node within the
|
|
||||||
// nodeUpdateIndexBucket as this node is no longer active, so we don't
|
|
||||||
// need to track its last update.
|
|
||||||
nodeUpdateIndex := nodes.Bucket(nodeUpdateIndexBucket)
|
|
||||||
if nodeUpdateIndex == nil {
|
|
||||||
return ErrGraphNodesNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
// In order to delete the entry, we'll need to reconstruct the key for
|
|
||||||
// its last update.
|
|
||||||
updateUnix := uint64(node.LastUpdate.Unix())
|
|
||||||
var indexKey [8 + 33]byte
|
|
||||||
byteOrder.PutUint64(indexKey[:8], updateUnix)
|
|
||||||
copy(indexKey[8:], compressedPubKey)
|
|
||||||
|
|
||||||
return nodeUpdateIndex.Delete(indexKey[:])
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
|
|
||||||
// undirected edge from the two target nodes are created. The information
|
|
||||||
// stored denotes the static attributes of the channel, such as the channelID,
|
|
||||||
// the keys involved in creation of the channel, and the set of features that
|
|
||||||
// the channel supports. The chanPoint and chanID are used to uniquely identify
|
|
||||||
// the edge globally within the database.
|
|
||||||
func (c *ChannelGraph) AddChannelEdge(edge *ChannelEdgeInfo) error {
|
|
||||||
c.cacheMu.Lock()
|
|
||||||
defer c.cacheMu.Unlock()
|
|
||||||
|
|
||||||
err := c.db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
return c.addChannelEdge(tx, edge)
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.rejectCache.remove(edge.ChannelID)
|
|
||||||
c.chanCache.remove(edge.ChannelID)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// addChannelEdge is the private form of AddChannelEdge that allows callers to
|
|
||||||
// utilize an existing db transaction.
|
|
||||||
func (c *ChannelGraph) addChannelEdge(tx *bbolt.Tx, edge *ChannelEdgeInfo) error {
|
|
||||||
// Construct the channel's primary key which is the 8-byte channel ID.
|
|
||||||
var chanKey [8]byte
|
|
||||||
binary.BigEndian.PutUint64(chanKey[:], edge.ChannelID)
|
|
||||||
|
|
||||||
nodes, err := tx.CreateBucketIfNotExists(nodeBucket)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
edges, err := tx.CreateBucketIfNotExists(edgeBucket)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
edgeIndex, err := edges.CreateBucketIfNotExists(edgeIndexBucket)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
chanIndex, err := edges.CreateBucketIfNotExists(channelPointBucket)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// First, attempt to check if this edge has already been created. If
|
|
||||||
// so, then we can exit early as this method is meant to be idempotent.
|
|
||||||
if edgeInfo := edgeIndex.Get(chanKey[:]); edgeInfo != nil {
|
|
||||||
return ErrEdgeAlreadyExist
|
|
||||||
}
|
|
||||||
|
|
||||||
// Before we insert the channel into the database, we'll ensure that
|
|
||||||
// both nodes already exist in the channel graph. If either node
|
|
||||||
// doesn't, then we'll insert a "shell" node that just includes its
|
|
||||||
// public key, so subsequent validation and queries can work properly.
|
|
||||||
_, node1Err := fetchLightningNode(nodes, edge.NodeKey1Bytes[:])
|
|
||||||
switch {
|
|
||||||
case node1Err == ErrGraphNodeNotFound:
|
|
||||||
node1Shell := LightningNode{
|
|
||||||
PubKeyBytes: edge.NodeKey1Bytes,
|
|
||||||
HaveNodeAnnouncement: false,
|
|
||||||
}
|
|
||||||
err := addLightningNode(tx, &node1Shell)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to create shell node "+
|
|
||||||
"for: %x", edge.NodeKey1Bytes)
|
|
||||||
|
|
||||||
}
|
|
||||||
case node1Err != nil:
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
_, node2Err := fetchLightningNode(nodes, edge.NodeKey2Bytes[:])
|
|
||||||
switch {
|
|
||||||
case node2Err == ErrGraphNodeNotFound:
|
|
||||||
node2Shell := LightningNode{
|
|
||||||
PubKeyBytes: edge.NodeKey2Bytes,
|
|
||||||
HaveNodeAnnouncement: false,
|
|
||||||
}
|
|
||||||
err := addLightningNode(tx, &node2Shell)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to create shell node "+
|
|
||||||
"for: %x", edge.NodeKey2Bytes)
|
|
||||||
|
|
||||||
}
|
|
||||||
case node2Err != nil:
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the edge hasn't been created yet, then we'll first add it to the
|
|
||||||
// edge index in order to associate the edge between two nodes and also
|
|
||||||
// store the static components of the channel.
|
|
||||||
if err := putChanEdgeInfo(edgeIndex, edge, chanKey); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Mark edge policies for both sides as unknown. This is to enable
|
|
||||||
// efficient incoming channel lookup for a node.
|
|
||||||
for _, key := range []*[33]byte{&edge.NodeKey1Bytes,
|
|
||||||
&edge.NodeKey2Bytes} {
|
|
||||||
|
|
||||||
err := putChanEdgePolicyUnknown(edges, edge.ChannelID,
|
|
||||||
key[:])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Finally we add it to the channel index which maps channel points
|
|
||||||
// (outpoints) to the shorter channel ID's.
|
|
||||||
var b bytes.Buffer
|
|
||||||
if err := writeOutpoint(&b, &edge.ChannelPoint); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return chanIndex.Put(b.Bytes(), chanKey[:])
|
|
||||||
}
|
|
||||||
|
|
||||||
// HasChannelEdge returns true if the database knows of a channel edge with the
|
|
||||||
// passed channel ID, and false otherwise. If an edge with that ID is found
|
|
||||||
// within the graph, then two time stamps representing the last time the edge
|
|
||||||
// was updated for both directed edges are returned along with the boolean. If
|
|
||||||
// it is not found, then the zombie index is checked and its result is returned
|
|
||||||
// as the second boolean.
|
|
||||||
func (c *ChannelGraph) HasChannelEdge(
|
|
||||||
chanID uint64) (time.Time, time.Time, bool, bool, error) {
|
|
||||||
|
|
||||||
var (
|
|
||||||
upd1Time time.Time
|
|
||||||
upd2Time time.Time
|
|
||||||
exists bool
|
|
||||||
isZombie bool
|
|
||||||
)
|
|
||||||
|
|
||||||
// We'll query the cache with the shared lock held to allow multiple
|
|
||||||
// readers to access values in the cache concurrently if they exist.
|
|
||||||
c.cacheMu.RLock()
|
|
||||||
if entry, ok := c.rejectCache.get(chanID); ok {
|
|
||||||
c.cacheMu.RUnlock()
|
|
||||||
upd1Time = time.Unix(entry.upd1Time, 0)
|
|
||||||
upd2Time = time.Unix(entry.upd2Time, 0)
|
|
||||||
exists, isZombie = entry.flags.unpack()
|
|
||||||
return upd1Time, upd2Time, exists, isZombie, nil
|
|
||||||
}
|
|
||||||
c.cacheMu.RUnlock()
|
|
||||||
|
|
||||||
c.cacheMu.Lock()
|
|
||||||
defer c.cacheMu.Unlock()
|
|
||||||
|
|
||||||
// The item was not found with the shared lock, so we'll acquire the
|
|
||||||
// exclusive lock and check the cache again in case another method added
|
|
||||||
// the entry to the cache while no lock was held.
|
|
||||||
if entry, ok := c.rejectCache.get(chanID); ok {
|
|
||||||
upd1Time = time.Unix(entry.upd1Time, 0)
|
|
||||||
upd2Time = time.Unix(entry.upd2Time, 0)
|
|
||||||
exists, isZombie = entry.flags.unpack()
|
|
||||||
return upd1Time, upd2Time, exists, isZombie, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.db.View(func(tx *bbolt.Tx) error {
|
|
||||||
edges := tx.Bucket(edgeBucket)
|
|
||||||
if edges == nil {
|
|
||||||
return ErrGraphNoEdgesFound
|
|
||||||
}
|
|
||||||
edgeIndex := edges.Bucket(edgeIndexBucket)
|
|
||||||
if edgeIndex == nil {
|
|
||||||
return ErrGraphNoEdgesFound
|
|
||||||
}
|
|
||||||
|
|
||||||
var channelID [8]byte
|
|
||||||
byteOrder.PutUint64(channelID[:], chanID)
|
|
||||||
|
|
||||||
// If the edge doesn't exist, then we'll also check our zombie
|
|
||||||
// index.
|
|
||||||
if edgeIndex.Get(channelID[:]) == nil {
|
|
||||||
exists = false
|
|
||||||
zombieIndex := edges.Bucket(zombieBucket)
|
|
||||||
if zombieIndex != nil {
|
|
||||||
isZombie, _, _ = isZombieEdge(
|
|
||||||
zombieIndex, chanID,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
exists = true
|
|
||||||
isZombie = false
|
|
||||||
|
|
||||||
// If the channel has been found in the graph, then retrieve
|
|
||||||
// the edges itself so we can return the last updated
|
|
||||||
// timestamps.
|
|
||||||
nodes := tx.Bucket(nodeBucket)
|
|
||||||
if nodes == nil {
|
|
||||||
return ErrGraphNodeNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
e1, e2, err := fetchChanEdgePolicies(edgeIndex, edges, nodes,
|
|
||||||
channelID[:], c.db)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// As we may have only one of the edges populated, only set the
|
|
||||||
// update time if the edge was found in the database.
|
|
||||||
if e1 != nil {
|
|
||||||
upd1Time = e1.LastUpdate
|
|
||||||
}
|
|
||||||
if e2 != nil {
|
|
||||||
upd2Time = e2.LastUpdate
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}); err != nil {
|
|
||||||
return time.Time{}, time.Time{}, exists, isZombie, err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.rejectCache.insert(chanID, rejectCacheEntry{
|
|
||||||
upd1Time: upd1Time.Unix(),
|
|
||||||
upd2Time: upd2Time.Unix(),
|
|
||||||
flags: packRejectFlags(exists, isZombie),
|
|
||||||
})
|
|
||||||
|
|
||||||
return upd1Time, upd2Time, exists, isZombie, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateChannelEdge retrieves and update edge of the graph database. Method
|
|
||||||
// only reserved for updating an edge info after its already been created.
|
|
||||||
// In order to maintain this constraints, we return an error in the scenario
|
|
||||||
// that an edge info hasn't yet been created yet, but someone attempts to update
|
|
||||||
// it.
|
|
||||||
func (c *ChannelGraph) UpdateChannelEdge(edge *ChannelEdgeInfo) error {
|
|
||||||
// Construct the channel's primary key which is the 8-byte channel ID.
|
|
||||||
var chanKey [8]byte
|
|
||||||
binary.BigEndian.PutUint64(chanKey[:], edge.ChannelID)
|
|
||||||
|
|
||||||
return c.db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
edges := tx.Bucket(edgeBucket)
|
|
||||||
if edge == nil {
|
|
||||||
return ErrEdgeNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
edgeIndex := edges.Bucket(edgeIndexBucket)
|
|
||||||
if edgeIndex == nil {
|
|
||||||
return ErrEdgeNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
if edgeInfo := edgeIndex.Get(chanKey[:]); edgeInfo == nil {
|
|
||||||
return ErrEdgeNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
return putChanEdgeInfo(edgeIndex, edge, chanKey)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
// pruneTipBytes is the total size of the value which stores a prune
|
|
||||||
// entry of the graph in the prune log. The "prune tip" is the last
|
|
||||||
// entry in the prune log, and indicates if the channel graph is in
|
|
||||||
// sync with the current UTXO state. The structure of the value
|
|
||||||
// is: blockHash, taking 32 bytes total.
|
|
||||||
pruneTipBytes = 32
|
|
||||||
)
|
|
||||||
|
|
||||||
// PruneGraph prunes newly closed channels from the channel graph in response
|
|
||||||
// to a new block being solved on the network. Any transactions which spend the
|
|
||||||
// funding output of any known channels within he graph will be deleted.
|
|
||||||
// Additionally, the "prune tip", or the last block which has been used to
|
|
||||||
// prune the graph is stored so callers can ensure the graph is fully in sync
|
|
||||||
// with the current UTXO state. A slice of channels that have been closed by
|
|
||||||
// the target block are returned if the function succeeds without error.
|
|
||||||
func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint,
|
|
||||||
blockHash *chainhash.Hash, blockHeight uint32) ([]*ChannelEdgeInfo, error) {
|
|
||||||
|
|
||||||
c.cacheMu.Lock()
|
|
||||||
defer c.cacheMu.Unlock()
|
|
||||||
|
|
||||||
var chansClosed []*ChannelEdgeInfo
|
|
||||||
|
|
||||||
err := c.db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
// First grab the edges bucket which houses the information
|
|
||||||
// we'd like to delete
|
|
||||||
edges, err := tx.CreateBucketIfNotExists(edgeBucket)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Next grab the two edge indexes which will also need to be updated.
|
|
||||||
edgeIndex, err := edges.CreateBucketIfNotExists(edgeIndexBucket)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
chanIndex, err := edges.CreateBucketIfNotExists(channelPointBucket)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
nodes := tx.Bucket(nodeBucket)
|
|
||||||
if nodes == nil {
|
|
||||||
return ErrSourceNodeNotSet
|
|
||||||
}
|
|
||||||
zombieIndex, err := edges.CreateBucketIfNotExists(zombieBucket)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// For each of the outpoints that have been spent within the
|
|
||||||
// block, we attempt to delete them from the graph as if that
|
|
||||||
// outpoint was a channel, then it has now been closed.
|
|
||||||
for _, chanPoint := range spentOutputs {
|
|
||||||
// TODO(roasbeef): load channel bloom filter, continue
|
|
||||||
// if NOT if filter
|
|
||||||
|
|
||||||
var opBytes bytes.Buffer
|
|
||||||
if err := writeOutpoint(&opBytes, chanPoint); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// First attempt to see if the channel exists within
|
|
||||||
// the database, if not, then we can exit early.
|
|
||||||
chanID := chanIndex.Get(opBytes.Bytes())
|
|
||||||
if chanID == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// However, if it does, then we'll read out the full
|
|
||||||
// version so we can add it to the set of deleted
|
|
||||||
// channels.
|
|
||||||
edgeInfo, err := fetchChanEdgeInfo(edgeIndex, chanID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Attempt to delete the channel, an ErrEdgeNotFound
|
|
||||||
// will be returned if that outpoint isn't known to be
|
|
||||||
// a channel. If no error is returned, then a channel
|
|
||||||
// was successfully pruned.
|
|
||||||
err = delChannelEdge(
|
|
||||||
edges, edgeIndex, chanIndex, zombieIndex, nodes,
|
|
||||||
chanID, false,
|
|
||||||
)
|
|
||||||
if err != nil && err != ErrEdgeNotFound {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
chansClosed = append(chansClosed, &edgeInfo)
|
|
||||||
}
|
|
||||||
|
|
||||||
metaBucket, err := tx.CreateBucketIfNotExists(graphMetaBucket)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
pruneBucket, err := metaBucket.CreateBucketIfNotExists(pruneLogBucket)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// With the graph pruned, add a new entry to the prune log,
|
|
||||||
// which can be used to check if the graph is fully synced with
|
|
||||||
// the current UTXO state.
|
|
||||||
var blockHeightBytes [4]byte
|
|
||||||
byteOrder.PutUint32(blockHeightBytes[:], blockHeight)
|
|
||||||
|
|
||||||
var newTip [pruneTipBytes]byte
|
|
||||||
copy(newTip[:], blockHash[:])
|
|
||||||
|
|
||||||
err = pruneBucket.Put(blockHeightBytes[:], newTip[:])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now that the graph has been pruned, we'll also attempt to
|
|
||||||
// prune any nodes that have had a channel closed within the
|
|
||||||
// latest block.
|
|
||||||
return c.pruneGraphNodes(nodes, edgeIndex)
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, channel := range chansClosed {
|
|
||||||
c.rejectCache.remove(channel.ChannelID)
|
|
||||||
c.chanCache.remove(channel.ChannelID)
|
|
||||||
}
|
|
||||||
|
|
||||||
return chansClosed, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// PruneGraphNodes is a garbage collection method which attempts to prune out
|
|
||||||
// any nodes from the channel graph that are currently unconnected. This ensure
|
|
||||||
// that we only maintain a graph of reachable nodes. In the event that a pruned
|
|
||||||
// node gains more channels, it will be re-added back to the graph.
|
|
||||||
func (c *ChannelGraph) PruneGraphNodes() error {
|
|
||||||
return c.db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
nodes := tx.Bucket(nodeBucket)
|
|
||||||
if nodes == nil {
|
|
||||||
return ErrGraphNodesNotFound
|
|
||||||
}
|
|
||||||
edges := tx.Bucket(edgeBucket)
|
|
||||||
if edges == nil {
|
|
||||||
return ErrGraphNotFound
|
|
||||||
}
|
|
||||||
edgeIndex := edges.Bucket(edgeIndexBucket)
|
|
||||||
if edgeIndex == nil {
|
|
||||||
return ErrGraphNoEdgesFound
|
|
||||||
}
|
|
||||||
|
|
||||||
return c.pruneGraphNodes(nodes, edgeIndex)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// pruneGraphNodes attempts to remove any nodes from the graph who have had a
|
|
||||||
// channel closed within the current block. If the node still has existing
|
|
||||||
// channels in the graph, this will act as a no-op.
|
|
||||||
func (c *ChannelGraph) pruneGraphNodes(nodes *bbolt.Bucket,
|
|
||||||
edgeIndex *bbolt.Bucket) error {
|
|
||||||
|
|
||||||
log.Trace("Pruning nodes from graph with no open channels")
|
|
||||||
|
|
||||||
// We'll retrieve the graph's source node to ensure we don't remove it
|
|
||||||
// even if it no longer has any open channels.
|
|
||||||
sourceNode, err := c.sourceNode(nodes)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// We'll use this map to keep count the number of references to a node
|
|
||||||
// in the graph. A node should only be removed once it has no more
|
|
||||||
// references in the graph.
|
|
||||||
nodeRefCounts := make(map[[33]byte]int)
|
|
||||||
err = nodes.ForEach(func(pubKey, nodeBytes []byte) error {
|
|
||||||
// If this is the source key, then we skip this
|
|
||||||
// iteration as the value for this key is a pubKey
|
|
||||||
// rather than raw node information.
|
|
||||||
if bytes.Equal(pubKey, sourceKey) || len(pubKey) != 33 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var nodePub [33]byte
|
|
||||||
copy(nodePub[:], pubKey)
|
|
||||||
nodeRefCounts[nodePub] = 0
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// To ensure we never delete the source node, we'll start off by
|
|
||||||
// bumping its ref count to 1.
|
|
||||||
nodeRefCounts[sourceNode.PubKeyBytes] = 1
|
|
||||||
|
|
||||||
// Next, we'll run through the edgeIndex which maps a channel ID to the
|
|
||||||
// edge info. We'll use this scan to populate our reference count map
|
|
||||||
// above.
|
|
||||||
err = edgeIndex.ForEach(func(chanID, edgeInfoBytes []byte) error {
|
|
||||||
// The first 66 bytes of the edge info contain the pubkeys of
|
|
||||||
// the nodes that this edge attaches. We'll extract them, and
|
|
||||||
// add them to the ref count map.
|
|
||||||
var node1, node2 [33]byte
|
|
||||||
copy(node1[:], edgeInfoBytes[:33])
|
|
||||||
copy(node2[:], edgeInfoBytes[33:])
|
|
||||||
|
|
||||||
// With the nodes extracted, we'll increase the ref count of
|
|
||||||
// each of the nodes.
|
|
||||||
nodeRefCounts[node1]++
|
|
||||||
nodeRefCounts[node2]++
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Finally, we'll make a second pass over the set of nodes, and delete
|
|
||||||
// any nodes that have a ref count of zero.
|
|
||||||
var numNodesPruned int
|
|
||||||
for nodePubKey, refCount := range nodeRefCounts {
|
|
||||||
// If the ref count of the node isn't zero, then we can safely
|
|
||||||
// skip it as it still has edges to or from it within the
|
|
||||||
// graph.
|
|
||||||
if refCount != 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we reach this point, then there are no longer any edges
|
|
||||||
// that connect this node, so we can delete it.
|
|
||||||
if err := c.deleteLightningNode(nodes, nodePubKey[:]); err != nil {
|
|
||||||
log.Warnf("Unable to prune node %x from the "+
|
|
||||||
"graph: %v", nodePubKey, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("Pruned unconnected node %x from channel graph",
|
|
||||||
nodePubKey[:])
|
|
||||||
|
|
||||||
numNodesPruned++
|
|
||||||
}
|
|
||||||
|
|
||||||
if numNodesPruned > 0 {
|
|
||||||
log.Infof("Pruned %v unconnected nodes from the channel graph",
|
|
||||||
numNodesPruned)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DisconnectBlockAtHeight is used to indicate that the block specified
|
|
||||||
// by the passed height has been disconnected from the main chain. This
|
|
||||||
// will "rewind" the graph back to the height below, deleting channels
|
|
||||||
// that are no longer confirmed from the graph. The prune log will be
|
|
||||||
// set to the last prune height valid for the remaining chain.
|
|
||||||
// Channels that were removed from the graph resulting from the
|
|
||||||
// disconnected block are returned.
|
|
||||||
func (c *ChannelGraph) DisconnectBlockAtHeight(height uint32) ([]*ChannelEdgeInfo,
|
|
||||||
error) {
|
|
||||||
|
|
||||||
// Every channel having a ShortChannelID starting at 'height'
|
|
||||||
// will no longer be confirmed.
|
|
||||||
startShortChanID := lnwire.ShortChannelID{
|
|
||||||
BlockHeight: height,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete everything after this height from the db.
|
|
||||||
endShortChanID := lnwire.ShortChannelID{
|
|
||||||
BlockHeight: math.MaxUint32 & 0x00ffffff,
|
|
||||||
TxIndex: math.MaxUint32 & 0x00ffffff,
|
|
||||||
TxPosition: math.MaxUint16,
|
|
||||||
}
|
|
||||||
// The block height will be the 3 first bytes of the channel IDs.
|
|
||||||
var chanIDStart [8]byte
|
|
||||||
byteOrder.PutUint64(chanIDStart[:], startShortChanID.ToUint64())
|
|
||||||
var chanIDEnd [8]byte
|
|
||||||
byteOrder.PutUint64(chanIDEnd[:], endShortChanID.ToUint64())
|
|
||||||
|
|
||||||
c.cacheMu.Lock()
|
|
||||||
defer c.cacheMu.Unlock()
|
|
||||||
|
|
||||||
// Keep track of the channels that are removed from the graph.
|
|
||||||
var removedChans []*ChannelEdgeInfo
|
|
||||||
|
|
||||||
if err := c.db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
edges, err := tx.CreateBucketIfNotExists(edgeBucket)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
edgeIndex, err := edges.CreateBucketIfNotExists(edgeIndexBucket)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
chanIndex, err := edges.CreateBucketIfNotExists(channelPointBucket)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
zombieIndex, err := edges.CreateBucketIfNotExists(zombieBucket)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
nodes, err := tx.CreateBucketIfNotExists(nodeBucket)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Scan from chanIDStart to chanIDEnd, deleting every
|
|
||||||
// found edge.
|
|
||||||
// NOTE: we must delete the edges after the cursor loop, since
|
|
||||||
// modifying the bucket while traversing is not safe.
|
|
||||||
var keys [][]byte
|
|
||||||
cursor := edgeIndex.Cursor()
|
|
||||||
for k, v := cursor.Seek(chanIDStart[:]); k != nil &&
|
|
||||||
bytes.Compare(k, chanIDEnd[:]) <= 0; k, v = cursor.Next() {
|
|
||||||
|
|
||||||
edgeInfoReader := bytes.NewReader(v)
|
|
||||||
edgeInfo, err := deserializeChanEdgeInfo(edgeInfoReader)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
keys = append(keys, k)
|
|
||||||
removedChans = append(removedChans, &edgeInfo)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, k := range keys {
|
|
||||||
err = delChannelEdge(
|
|
||||||
edges, edgeIndex, chanIndex, zombieIndex, nodes,
|
|
||||||
k, false,
|
|
||||||
)
|
|
||||||
if err != nil && err != ErrEdgeNotFound {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete all the entries in the prune log having a height
|
|
||||||
// greater or equal to the block disconnected.
|
|
||||||
metaBucket, err := tx.CreateBucketIfNotExists(graphMetaBucket)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
pruneBucket, err := metaBucket.CreateBucketIfNotExists(pruneLogBucket)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
var pruneKeyStart [4]byte
|
|
||||||
byteOrder.PutUint32(pruneKeyStart[:], height)
|
|
||||||
|
|
||||||
var pruneKeyEnd [4]byte
|
|
||||||
byteOrder.PutUint32(pruneKeyEnd[:], math.MaxUint32)
|
|
||||||
|
|
||||||
// To avoid modifying the bucket while traversing, we delete
|
|
||||||
// the keys in a second loop.
|
|
||||||
var pruneKeys [][]byte
|
|
||||||
pruneCursor := pruneBucket.Cursor()
|
|
||||||
for k, _ := pruneCursor.Seek(pruneKeyStart[:]); k != nil &&
|
|
||||||
bytes.Compare(k, pruneKeyEnd[:]) <= 0; k, _ = pruneCursor.Next() {
|
|
||||||
|
|
||||||
pruneKeys = append(pruneKeys, k)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, k := range pruneKeys {
|
|
||||||
if err := pruneBucket.Delete(k); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, channel := range removedChans {
|
|
||||||
c.rejectCache.remove(channel.ChannelID)
|
|
||||||
c.chanCache.remove(channel.ChannelID)
|
|
||||||
}
|
|
||||||
|
|
||||||
return removedChans, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// PruneTip returns the block height and hash of the latest block that has been
|
|
||||||
// used to prune channels in the graph. Knowing the "prune tip" allows callers
|
|
||||||
// to tell if the graph is currently in sync with the current best known UTXO
|
|
||||||
// state.
|
|
||||||
func (c *ChannelGraph) PruneTip() (*chainhash.Hash, uint32, error) {
|
|
||||||
var (
|
|
||||||
tipHash chainhash.Hash
|
|
||||||
tipHeight uint32
|
|
||||||
)
|
|
||||||
|
|
||||||
err := c.db.View(func(tx *bbolt.Tx) error {
|
|
||||||
graphMeta := tx.Bucket(graphMetaBucket)
|
|
||||||
if graphMeta == nil {
|
|
||||||
return ErrGraphNotFound
|
|
||||||
}
|
|
||||||
pruneBucket := graphMeta.Bucket(pruneLogBucket)
|
|
||||||
if pruneBucket == nil {
|
|
||||||
return ErrGraphNeverPruned
|
|
||||||
}
|
|
||||||
|
|
||||||
pruneCursor := pruneBucket.Cursor()
|
|
||||||
|
|
||||||
// The prune key with the largest block height will be our
|
|
||||||
// prune tip.
|
|
||||||
k, v := pruneCursor.Last()
|
|
||||||
if k == nil {
|
|
||||||
return ErrGraphNeverPruned
|
|
||||||
}
|
|
||||||
|
|
||||||
// Once we have the prune tip, the value will be the block hash,
|
|
||||||
// and the key the block height.
|
|
||||||
copy(tipHash[:], v[:])
|
|
||||||
tipHeight = byteOrder.Uint32(k[:])
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &tipHash, tipHeight, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteChannelEdges removes edges with the given channel IDs from the database
|
|
||||||
// and marks them as zombies. This ensures that we're unable to re-add it to our
|
|
||||||
// database once again. If an edge does not exist within the database, then
|
|
||||||
// ErrEdgeNotFound will be returned.
|
|
||||||
func (c *ChannelGraph) DeleteChannelEdges(chanIDs ...uint64) error {
|
|
||||||
// TODO(roasbeef): possibly delete from node bucket if node has no more
|
|
||||||
// channels
|
|
||||||
// TODO(roasbeef): don't delete both edges?
|
|
||||||
|
|
||||||
c.cacheMu.Lock()
|
|
||||||
defer c.cacheMu.Unlock()
|
|
||||||
|
|
||||||
err := c.db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
edges := tx.Bucket(edgeBucket)
|
|
||||||
if edges == nil {
|
|
||||||
return ErrEdgeNotFound
|
|
||||||
}
|
|
||||||
edgeIndex := edges.Bucket(edgeIndexBucket)
|
|
||||||
if edgeIndex == nil {
|
|
||||||
return ErrEdgeNotFound
|
|
||||||
}
|
|
||||||
chanIndex := edges.Bucket(channelPointBucket)
|
|
||||||
if chanIndex == nil {
|
|
||||||
return ErrEdgeNotFound
|
|
||||||
}
|
|
||||||
nodes := tx.Bucket(nodeBucket)
|
|
||||||
if nodes == nil {
|
|
||||||
return ErrGraphNodeNotFound
|
|
||||||
}
|
|
||||||
zombieIndex, err := edges.CreateBucketIfNotExists(zombieBucket)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
var rawChanID [8]byte
|
|
||||||
for _, chanID := range chanIDs {
|
|
||||||
byteOrder.PutUint64(rawChanID[:], chanID)
|
|
||||||
err := delChannelEdge(
|
|
||||||
edges, edgeIndex, chanIndex, zombieIndex, nodes,
|
|
||||||
rawChanID[:], true,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, chanID := range chanIDs {
|
|
||||||
c.rejectCache.remove(chanID)
|
|
||||||
c.chanCache.remove(chanID)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChannelID attempt to lookup the 8-byte compact channel ID which maps to the
|
|
||||||
// passed channel point (outpoint). If the passed channel doesn't exist within
|
|
||||||
// the database, then ErrEdgeNotFound is returned.
|
|
||||||
func (c *ChannelGraph) ChannelID(chanPoint *wire.OutPoint) (uint64, error) {
|
|
||||||
var chanID uint64
|
|
||||||
if err := c.db.View(func(tx *bbolt.Tx) error {
|
|
||||||
var err error
|
|
||||||
chanID, err = getChanID(tx, chanPoint)
|
|
||||||
return err
|
|
||||||
}); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return chanID, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getChanID returns the assigned channel ID for a given channel point.
|
|
||||||
func getChanID(tx *bbolt.Tx, chanPoint *wire.OutPoint) (uint64, error) {
|
|
||||||
var b bytes.Buffer
|
|
||||||
if err := writeOutpoint(&b, chanPoint); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
edges := tx.Bucket(edgeBucket)
|
|
||||||
if edges == nil {
|
|
||||||
return 0, ErrGraphNoEdgesFound
|
|
||||||
}
|
|
||||||
chanIndex := edges.Bucket(channelPointBucket)
|
|
||||||
if chanIndex == nil {
|
|
||||||
return 0, ErrGraphNoEdgesFound
|
|
||||||
}
|
|
||||||
|
|
||||||
chanIDBytes := chanIndex.Get(b.Bytes())
|
|
||||||
if chanIDBytes == nil {
|
|
||||||
return 0, ErrEdgeNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
chanID := byteOrder.Uint64(chanIDBytes)
|
|
||||||
|
|
||||||
return chanID, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(roasbeef): allow updates to use Batch?
|
|
||||||
|
|
||||||
// HighestChanID returns the "highest" known channel ID in the channel graph.
|
|
||||||
// This represents the "newest" channel from the PoV of the chain. This method
|
|
||||||
// can be used by peers to quickly determine if they're graphs are in sync.
|
|
||||||
func (c *ChannelGraph) HighestChanID() (uint64, error) {
|
|
||||||
var cid uint64
|
|
||||||
|
|
||||||
err := c.db.View(func(tx *bbolt.Tx) error {
|
|
||||||
edges := tx.Bucket(edgeBucket)
|
|
||||||
if edges == nil {
|
|
||||||
return ErrGraphNoEdgesFound
|
|
||||||
}
|
|
||||||
edgeIndex := edges.Bucket(edgeIndexBucket)
|
|
||||||
if edgeIndex == nil {
|
|
||||||
return ErrGraphNoEdgesFound
|
|
||||||
}
|
|
||||||
|
|
||||||
// In order to find the highest chan ID, we'll fetch a cursor
|
|
||||||
// and use that to seek to the "end" of our known rage.
|
|
||||||
cidCursor := edgeIndex.Cursor()
|
|
||||||
|
|
||||||
lastChanID, _ := cidCursor.Last()
|
|
||||||
|
|
||||||
// If there's no key, then this means that we don't actually
|
|
||||||
// know of any channels, so we'll return a predicable error.
|
|
||||||
if lastChanID == nil {
|
|
||||||
return ErrGraphNoEdgesFound
|
|
||||||
}
|
|
||||||
|
|
||||||
// Otherwise, we'll de serialize the channel ID and return it
|
|
||||||
// to the caller.
|
|
||||||
cid = byteOrder.Uint64(lastChanID)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil && err != ErrGraphNoEdgesFound {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return cid, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChannelEdge represents the complete set of information for a channel edge in
|
|
||||||
// the known channel graph. This struct couples the core information of the
|
|
||||||
// edge as well as each of the known advertised edge policies.
|
|
||||||
type ChannelEdge struct {
|
|
||||||
// Info contains all the static information describing the channel.
|
|
||||||
Info *ChannelEdgeInfo
|
|
||||||
|
|
||||||
// Policy1 points to the "first" edge policy of the channel containing
|
|
||||||
// the dynamic information required to properly route through the edge.
|
|
||||||
Policy1 *ChannelEdgePolicy
|
|
||||||
|
|
||||||
// Policy2 points to the "second" edge policy of the channel containing
|
|
||||||
// the dynamic information required to properly route through the edge.
|
|
||||||
Policy2 *ChannelEdgePolicy
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChanUpdatesInHorizon returns all the known channel edges which have at least
|
|
||||||
// one edge that has an update timestamp within the specified horizon.
|
|
||||||
func (c *ChannelGraph) ChanUpdatesInHorizon(startTime, endTime time.Time) ([]ChannelEdge, error) {
|
|
||||||
// To ensure we don't return duplicate ChannelEdges, we'll use an
|
|
||||||
// additional map to keep track of the edges already seen to prevent
|
|
||||||
// re-adding it.
|
|
||||||
edgesSeen := make(map[uint64]struct{})
|
|
||||||
edgesToCache := make(map[uint64]ChannelEdge)
|
|
||||||
var edgesInHorizon []ChannelEdge
|
|
||||||
|
|
||||||
c.cacheMu.Lock()
|
|
||||||
defer c.cacheMu.Unlock()
|
|
||||||
|
|
||||||
var hits int
|
|
||||||
err := c.db.View(func(tx *bbolt.Tx) error {
|
|
||||||
edges := tx.Bucket(edgeBucket)
|
|
||||||
if edges == nil {
|
|
||||||
return ErrGraphNoEdgesFound
|
|
||||||
}
|
|
||||||
edgeIndex := edges.Bucket(edgeIndexBucket)
|
|
||||||
if edgeIndex == nil {
|
|
||||||
return ErrGraphNoEdgesFound
|
|
||||||
}
|
|
||||||
edgeUpdateIndex := edges.Bucket(edgeUpdateIndexBucket)
|
|
||||||
if edgeUpdateIndex == nil {
|
|
||||||
return ErrGraphNoEdgesFound
|
|
||||||
}
|
|
||||||
|
|
||||||
nodes := tx.Bucket(nodeBucket)
|
|
||||||
if nodes == nil {
|
|
||||||
return ErrGraphNodesNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
// We'll now obtain a cursor to perform a range query within
|
|
||||||
// the index to find all channels within the horizon.
|
|
||||||
updateCursor := edgeUpdateIndex.Cursor()
|
|
||||||
|
|
||||||
var startTimeBytes, endTimeBytes [8 + 8]byte
|
|
||||||
byteOrder.PutUint64(
|
|
||||||
startTimeBytes[:8], uint64(startTime.Unix()),
|
|
||||||
)
|
|
||||||
byteOrder.PutUint64(
|
|
||||||
endTimeBytes[:8], uint64(endTime.Unix()),
|
|
||||||
)
|
|
||||||
|
|
||||||
// With our start and end times constructed, we'll step through
|
|
||||||
// the index collecting the info and policy of each update of
|
|
||||||
// each channel that has a last update within the time range.
|
|
||||||
for indexKey, _ := updateCursor.Seek(startTimeBytes[:]); indexKey != nil &&
|
|
||||||
bytes.Compare(indexKey, endTimeBytes[:]) <= 0; indexKey, _ = updateCursor.Next() {
|
|
||||||
|
|
||||||
// We have a new eligible entry, so we'll slice of the
|
|
||||||
// chan ID so we can query it in the DB.
|
|
||||||
chanID := indexKey[8:]
|
|
||||||
|
|
||||||
// If we've already retrieved the info and policies for
|
|
||||||
// this edge, then we can skip it as we don't need to do
|
|
||||||
// so again.
|
|
||||||
chanIDInt := byteOrder.Uint64(chanID)
|
|
||||||
if _, ok := edgesSeen[chanIDInt]; ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if channel, ok := c.chanCache.get(chanIDInt); ok {
|
|
||||||
hits++
|
|
||||||
edgesSeen[chanIDInt] = struct{}{}
|
|
||||||
edgesInHorizon = append(edgesInHorizon, channel)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// First, we'll fetch the static edge information.
|
|
||||||
edgeInfo, err := fetchChanEdgeInfo(edgeIndex, chanID)
|
|
||||||
if err != nil {
|
|
||||||
chanID := byteOrder.Uint64(chanID)
|
|
||||||
return fmt.Errorf("unable to fetch info for "+
|
|
||||||
"edge with chan_id=%v: %v", chanID, err)
|
|
||||||
}
|
|
||||||
edgeInfo.db = c.db
|
|
||||||
|
|
||||||
// With the static information obtained, we'll now
|
|
||||||
// fetch the dynamic policy info.
|
|
||||||
edge1, edge2, err := fetchChanEdgePolicies(
|
|
||||||
edgeIndex, edges, nodes, chanID, c.db,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
chanID := byteOrder.Uint64(chanID)
|
|
||||||
return fmt.Errorf("unable to fetch policies "+
|
|
||||||
"for edge with chan_id=%v: %v", chanID,
|
|
||||||
err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Finally, we'll collate this edge with the rest of
|
|
||||||
// edges to be returned.
|
|
||||||
edgesSeen[chanIDInt] = struct{}{}
|
|
||||||
channel := ChannelEdge{
|
|
||||||
Info: &edgeInfo,
|
|
||||||
Policy1: edge1,
|
|
||||||
Policy2: edge2,
|
|
||||||
}
|
|
||||||
edgesInHorizon = append(edgesInHorizon, channel)
|
|
||||||
edgesToCache[chanIDInt] = channel
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
switch {
|
|
||||||
case err == ErrGraphNoEdgesFound:
|
|
||||||
fallthrough
|
|
||||||
case err == ErrGraphNodesNotFound:
|
|
||||||
break
|
|
||||||
|
|
||||||
case err != nil:
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Insert any edges loaded from disk into the cache.
|
|
||||||
for chanid, channel := range edgesToCache {
|
|
||||||
c.chanCache.insert(chanid, channel)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("ChanUpdatesInHorizon hit percentage: %f (%d/%d)",
|
|
||||||
float64(hits)/float64(len(edgesInHorizon)), hits,
|
|
||||||
len(edgesInHorizon))
|
|
||||||
|
|
||||||
return edgesInHorizon, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// NodeUpdatesInHorizon returns all the known lightning node which have an
|
|
||||||
// update timestamp within the passed range. This method can be used by two
|
|
||||||
// nodes to quickly determine if they have the same set of up to date node
|
|
||||||
// announcements.
|
|
||||||
func (c *ChannelGraph) NodeUpdatesInHorizon(startTime, endTime time.Time) ([]LightningNode, error) {
|
|
||||||
var nodesInHorizon []LightningNode
|
|
||||||
|
|
||||||
err := c.db.View(func(tx *bbolt.Tx) error {
|
|
||||||
nodes := tx.Bucket(nodeBucket)
|
|
||||||
if nodes == nil {
|
|
||||||
return ErrGraphNodesNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
nodeUpdateIndex := nodes.Bucket(nodeUpdateIndexBucket)
|
|
||||||
if nodeUpdateIndex == nil {
|
|
||||||
return ErrGraphNodesNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
// We'll now obtain a cursor to perform a range query within
|
|
||||||
// the index to find all node announcements within the horizon.
|
|
||||||
updateCursor := nodeUpdateIndex.Cursor()
|
|
||||||
|
|
||||||
var startTimeBytes, endTimeBytes [8 + 33]byte
|
|
||||||
byteOrder.PutUint64(
|
|
||||||
startTimeBytes[:8], uint64(startTime.Unix()),
|
|
||||||
)
|
|
||||||
byteOrder.PutUint64(
|
|
||||||
endTimeBytes[:8], uint64(endTime.Unix()),
|
|
||||||
)
|
|
||||||
|
|
||||||
// With our start and end times constructed, we'll step through
|
|
||||||
// the index collecting info for each node within the time
|
|
||||||
// range.
|
|
||||||
for indexKey, _ := updateCursor.Seek(startTimeBytes[:]); indexKey != nil &&
|
|
||||||
bytes.Compare(indexKey, endTimeBytes[:]) <= 0; indexKey, _ = updateCursor.Next() {
|
|
||||||
|
|
||||||
nodePub := indexKey[8:]
|
|
||||||
node, err := fetchLightningNode(nodes, nodePub)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
node.db = c.db
|
|
||||||
|
|
||||||
nodesInHorizon = append(nodesInHorizon, node)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
switch {
|
|
||||||
case err == ErrGraphNoEdgesFound:
|
|
||||||
fallthrough
|
|
||||||
case err == ErrGraphNodesNotFound:
|
|
||||||
break
|
|
||||||
|
|
||||||
case err != nil:
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nodesInHorizon, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan
|
|
||||||
// ID's that we don't know and are not known zombies of the passed set. In other
|
|
||||||
// words, we perform a set difference of our set of chan ID's and the ones
|
|
||||||
// passed in. This method can be used by callers to determine the set of
|
|
||||||
// channels another peer knows of that we don't.
|
|
||||||
func (c *ChannelGraph) FilterKnownChanIDs(chanIDs []uint64) ([]uint64, error) {
|
|
||||||
var newChanIDs []uint64
|
|
||||||
|
|
||||||
err := c.db.View(func(tx *bbolt.Tx) error {
|
|
||||||
edges := tx.Bucket(edgeBucket)
|
|
||||||
if edges == nil {
|
|
||||||
return ErrGraphNoEdgesFound
|
|
||||||
}
|
|
||||||
edgeIndex := edges.Bucket(edgeIndexBucket)
|
|
||||||
if edgeIndex == nil {
|
|
||||||
return ErrGraphNoEdgesFound
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fetch the zombie index, it may not exist if no edges have
|
|
||||||
// ever been marked as zombies. If the index has been
|
|
||||||
// initialized, we will use it later to skip known zombie edges.
|
|
||||||
zombieIndex := edges.Bucket(zombieBucket)
|
|
||||||
|
|
||||||
// We'll run through the set of chanIDs and collate only the
|
|
||||||
// set of channel that are unable to be found within our db.
|
|
||||||
var cidBytes [8]byte
|
|
||||||
for _, cid := range chanIDs {
|
|
||||||
byteOrder.PutUint64(cidBytes[:], cid)
|
|
||||||
|
|
||||||
// If the edge is already known, skip it.
|
|
||||||
if v := edgeIndex.Get(cidBytes[:]); v != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the edge is a known zombie, skip it.
|
|
||||||
if zombieIndex != nil {
|
|
||||||
isZombie, _, _ := isZombieEdge(zombieIndex, cid)
|
|
||||||
if isZombie {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
newChanIDs = append(newChanIDs, cid)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
switch {
|
|
||||||
// If we don't know of any edges yet, then we'll return the entire set
|
|
||||||
// of chan IDs specified.
|
|
||||||
case err == ErrGraphNoEdgesFound:
|
|
||||||
return chanIDs, nil
|
|
||||||
|
|
||||||
case err != nil:
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return newChanIDs, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// FilterChannelRange returns the channel ID's of all known channels which were
|
|
||||||
// mined in a block height within the passed range. This method can be used to
|
|
||||||
// quickly share with a peer the set of channels we know of within a particular
|
|
||||||
// range to catch them up after a period of time offline.
|
|
||||||
func (c *ChannelGraph) FilterChannelRange(startHeight, endHeight uint32) ([]uint64, error) {
|
|
||||||
var chanIDs []uint64
|
|
||||||
|
|
||||||
startChanID := &lnwire.ShortChannelID{
|
|
||||||
BlockHeight: startHeight,
|
|
||||||
}
|
|
||||||
|
|
||||||
endChanID := lnwire.ShortChannelID{
|
|
||||||
BlockHeight: endHeight,
|
|
||||||
TxIndex: math.MaxUint32 & 0x00ffffff,
|
|
||||||
TxPosition: math.MaxUint16,
|
|
||||||
}
|
|
||||||
|
|
||||||
// As we need to perform a range scan, we'll convert the starting and
|
|
||||||
// ending height to their corresponding values when encoded using short
|
|
||||||
// channel ID's.
|
|
||||||
var chanIDStart, chanIDEnd [8]byte
|
|
||||||
byteOrder.PutUint64(chanIDStart[:], startChanID.ToUint64())
|
|
||||||
byteOrder.PutUint64(chanIDEnd[:], endChanID.ToUint64())
|
|
||||||
|
|
||||||
err := c.db.View(func(tx *bbolt.Tx) error {
|
|
||||||
edges := tx.Bucket(edgeBucket)
|
|
||||||
if edges == nil {
|
|
||||||
return ErrGraphNoEdgesFound
|
|
||||||
}
|
|
||||||
edgeIndex := edges.Bucket(edgeIndexBucket)
|
|
||||||
if edgeIndex == nil {
|
|
||||||
return ErrGraphNoEdgesFound
|
|
||||||
}
|
|
||||||
|
|
||||||
cursor := edgeIndex.Cursor()
|
|
||||||
|
|
||||||
// We'll now iterate through the database, and find each
|
|
||||||
// channel ID that resides within the specified range.
|
|
||||||
var cid uint64
|
|
||||||
for k, _ := cursor.Seek(chanIDStart[:]); k != nil &&
|
|
||||||
bytes.Compare(k, chanIDEnd[:]) <= 0; k, _ = cursor.Next() {
|
|
||||||
|
|
||||||
// This channel ID rests within the target range, so
|
|
||||||
// we'll convert it into an integer and add it to our
|
|
||||||
// returned set.
|
|
||||||
cid = byteOrder.Uint64(k)
|
|
||||||
chanIDs = append(chanIDs, cid)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
switch {
|
|
||||||
// If we don't know of any channels yet, then there's nothing to
|
|
||||||
// filter, so we'll return an empty slice.
|
|
||||||
case err == ErrGraphNoEdgesFound:
|
|
||||||
return chanIDs, nil
|
|
||||||
|
|
||||||
case err != nil:
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return chanIDs, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// FetchChanInfos returns the set of channel edges that correspond to the passed
|
|
||||||
// channel ID's. If an edge is the query is unknown to the database, it will
|
|
||||||
// skipped and the result will contain only those edges that exist at the time
|
|
||||||
// of the query. This can be used to respond to peer queries that are seeking to
|
|
||||||
// fill in gaps in their view of the channel graph.
|
|
||||||
func (c *ChannelGraph) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
|
|
||||||
// TODO(roasbeef): sort cids?
|
|
||||||
|
|
||||||
var (
|
|
||||||
chanEdges []ChannelEdge
|
|
||||||
cidBytes [8]byte
|
|
||||||
)
|
|
||||||
|
|
||||||
err := c.db.View(func(tx *bbolt.Tx) error {
|
|
||||||
edges := tx.Bucket(edgeBucket)
|
|
||||||
if edges == nil {
|
|
||||||
return ErrGraphNoEdgesFound
|
|
||||||
}
|
|
||||||
edgeIndex := edges.Bucket(edgeIndexBucket)
|
|
||||||
if edgeIndex == nil {
|
|
||||||
return ErrGraphNoEdgesFound
|
|
||||||
}
|
|
||||||
nodes := tx.Bucket(nodeBucket)
|
|
||||||
if nodes == nil {
|
|
||||||
return ErrGraphNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, cid := range chanIDs {
|
|
||||||
byteOrder.PutUint64(cidBytes[:], cid)
|
|
||||||
|
|
||||||
// First, we'll fetch the static edge information. If
|
|
||||||
// the edge is unknown, we will skip the edge and
|
|
||||||
// continue gathering all known edges.
|
|
||||||
edgeInfo, err := fetchChanEdgeInfo(
|
|
||||||
edgeIndex, cidBytes[:],
|
|
||||||
)
|
|
||||||
switch {
|
|
||||||
case err == ErrEdgeNotFound:
|
|
||||||
continue
|
|
||||||
case err != nil:
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
edgeInfo.db = c.db
|
|
||||||
|
|
||||||
// With the static information obtained, we'll now
|
|
||||||
// fetch the dynamic policy info.
|
|
||||||
edge1, edge2, err := fetchChanEdgePolicies(
|
|
||||||
edgeIndex, edges, nodes, cidBytes[:], c.db,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
chanEdges = append(chanEdges, ChannelEdge{
|
|
||||||
Info: &edgeInfo,
|
|
||||||
Policy1: edge1,
|
|
||||||
Policy2: edge2,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return chanEdges, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func delEdgeUpdateIndexEntry(edgesBucket *bbolt.Bucket, chanID uint64,
|
|
||||||
edge1, edge2 *ChannelEdgePolicy) error {
|
|
||||||
|
|
||||||
// First, we'll fetch the edge update index bucket which currently
|
|
||||||
// stores an entry for the channel we're about to delete.
|
|
||||||
updateIndex := edgesBucket.Bucket(edgeUpdateIndexBucket)
|
|
||||||
if updateIndex == nil {
|
|
||||||
// No edges in bucket, return early.
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now that we have the bucket, we'll attempt to construct a template
|
|
||||||
// for the index key: updateTime || chanid.
|
|
||||||
var indexKey [8 + 8]byte
|
|
||||||
byteOrder.PutUint64(indexKey[8:], chanID)
|
|
||||||
|
|
||||||
// With the template constructed, we'll attempt to delete an entry that
|
|
||||||
// would have been created by both edges: we'll alternate the update
|
|
||||||
// times, as one may had overridden the other.
|
|
||||||
if edge1 != nil {
|
|
||||||
byteOrder.PutUint64(indexKey[:8], uint64(edge1.LastUpdate.Unix()))
|
|
||||||
if err := updateIndex.Delete(indexKey[:]); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// We'll also attempt to delete the entry that may have been created by
|
|
||||||
// the second edge.
|
|
||||||
if edge2 != nil {
|
|
||||||
byteOrder.PutUint64(indexKey[:8], uint64(edge2.LastUpdate.Unix()))
|
|
||||||
if err := updateIndex.Delete(indexKey[:]); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func delChannelEdge(edges, edgeIndex, chanIndex, zombieIndex,
|
|
||||||
nodes *bbolt.Bucket, chanID []byte, isZombie bool) error {
|
|
||||||
|
|
||||||
edgeInfo, err := fetchChanEdgeInfo(edgeIndex, chanID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// We'll also remove the entry in the edge update index bucket before
|
|
||||||
// we delete the edges themselves so we can access their last update
|
|
||||||
// times.
|
|
||||||
cid := byteOrder.Uint64(chanID)
|
|
||||||
edge1, edge2, err := fetchChanEdgePolicies(
|
|
||||||
edgeIndex, edges, nodes, chanID, nil,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = delEdgeUpdateIndexEntry(edges, cid, edge1, edge2)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// The edge key is of the format pubKey || chanID. First we construct
|
|
||||||
// the latter half, populating the channel ID.
|
|
||||||
var edgeKey [33 + 8]byte
|
|
||||||
copy(edgeKey[33:], chanID)
|
|
||||||
|
|
||||||
// With the latter half constructed, copy over the first public key to
|
|
||||||
// delete the edge in this direction, then the second to delete the
|
|
||||||
// edge in the opposite direction.
|
|
||||||
copy(edgeKey[:33], edgeInfo.NodeKey1Bytes[:])
|
|
||||||
if edges.Get(edgeKey[:]) != nil {
|
|
||||||
if err := edges.Delete(edgeKey[:]); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
copy(edgeKey[:33], edgeInfo.NodeKey2Bytes[:])
|
|
||||||
if edges.Get(edgeKey[:]) != nil {
|
|
||||||
if err := edges.Delete(edgeKey[:]); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// As part of deleting the edge we also remove all disabled entries
|
|
||||||
// from the edgePolicyDisabledIndex bucket. We do that for both directions.
|
|
||||||
updateEdgePolicyDisabledIndex(edges, cid, false, false)
|
|
||||||
updateEdgePolicyDisabledIndex(edges, cid, true, false)
|
|
||||||
|
|
||||||
// With the edge data deleted, we can purge the information from the two
|
|
||||||
// edge indexes.
|
|
||||||
if err := edgeIndex.Delete(chanID); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
var b bytes.Buffer
|
|
||||||
if err := writeOutpoint(&b, &edgeInfo.ChannelPoint); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := chanIndex.Delete(b.Bytes()); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Finally, we'll mark the edge as a zombie within our index if it's
|
|
||||||
// being removed due to the channel becoming a zombie. We do this to
|
|
||||||
// ensure we don't store unnecessary data for spent channels.
|
|
||||||
if !isZombie {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return markEdgeZombie(
|
|
||||||
zombieIndex, byteOrder.Uint64(chanID), edgeInfo.NodeKey1Bytes,
|
|
||||||
edgeInfo.NodeKey2Bytes,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateEdgePolicy updates the edge routing policy for a single directed edge
|
|
||||||
// within the database for the referenced channel. The `flags` attribute within
|
|
||||||
// the ChannelEdgePolicy determines which of the directed edges are being
|
|
||||||
// updated. If the flag is 1, then the first node's information is being
|
|
||||||
// updated, otherwise it's the second node's information. The node ordering is
|
|
||||||
// determined by the lexicographical ordering of the identity public keys of
|
|
||||||
// the nodes on either side of the channel.
|
|
||||||
func (c *ChannelGraph) UpdateEdgePolicy(edge *ChannelEdgePolicy) error {
|
|
||||||
c.cacheMu.Lock()
|
|
||||||
defer c.cacheMu.Unlock()
|
|
||||||
|
|
||||||
var isUpdate1 bool
|
|
||||||
err := c.db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
var err error
|
|
||||||
isUpdate1, err = updateEdgePolicy(tx, edge)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// If an entry for this channel is found in reject cache, we'll modify
|
|
||||||
// the entry with the updated timestamp for the direction that was just
|
|
||||||
// written. If the edge doesn't exist, we'll load the cache entry lazily
|
|
||||||
// during the next query for this edge.
|
|
||||||
if entry, ok := c.rejectCache.get(edge.ChannelID); ok {
|
|
||||||
if isUpdate1 {
|
|
||||||
entry.upd1Time = edge.LastUpdate.Unix()
|
|
||||||
} else {
|
|
||||||
entry.upd2Time = edge.LastUpdate.Unix()
|
|
||||||
}
|
|
||||||
c.rejectCache.insert(edge.ChannelID, entry)
|
|
||||||
}
|
|
||||||
|
|
||||||
// If an entry for this channel is found in channel cache, we'll modify
|
|
||||||
// the entry with the updated policy for the direction that was just
|
|
||||||
// written. If the edge doesn't exist, we'll defer loading the info and
|
|
||||||
// policies and lazily read from disk during the next query.
|
|
||||||
if channel, ok := c.chanCache.get(edge.ChannelID); ok {
|
|
||||||
if isUpdate1 {
|
|
||||||
channel.Policy1 = edge
|
|
||||||
} else {
|
|
||||||
channel.Policy2 = edge
|
|
||||||
}
|
|
||||||
c.chanCache.insert(edge.ChannelID, channel)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// updateEdgePolicy attempts to update an edge's policy within the relevant
|
// updateEdgePolicy attempts to update an edge's policy within the relevant
|
||||||
// buckets using an existing database transaction. The returned boolean will be
|
// buckets using an existing database transaction. The returned boolean will be
|
||||||
// true if the updated policy belongs to node1, and false if the policy belonged
|
// true if the updated policy belongs to node1, and false if the policy belonged
|
||||||
@ -2083,297 +391,6 @@ func (l *LightningNode) PubKey() (*btcec.PublicKey, error) {
|
|||||||
return key, nil
|
return key, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthSig is a signature under the advertised public key which serves to
|
|
||||||
// authenticate the attributes announced by this node.
|
|
||||||
//
|
|
||||||
// NOTE: By having this method to access an attribute, we ensure we only need
|
|
||||||
// to fully deserialize the signature if absolutely necessary.
|
|
||||||
func (l *LightningNode) AuthSig() (*btcec.Signature, error) {
|
|
||||||
return btcec.ParseSignature(l.AuthSigBytes, btcec.S256())
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddPubKey is a setter-link method that can be used to swap out the public
|
|
||||||
// key for a node.
|
|
||||||
func (l *LightningNode) AddPubKey(key *btcec.PublicKey) {
|
|
||||||
l.pubKey = key
|
|
||||||
copy(l.PubKeyBytes[:], key.SerializeCompressed())
|
|
||||||
}
|
|
||||||
|
|
||||||
// NodeAnnouncement retrieves the latest node announcement of the node.
|
|
||||||
func (l *LightningNode) NodeAnnouncement(signed bool) (*lnwire.NodeAnnouncement,
|
|
||||||
error) {
|
|
||||||
|
|
||||||
if !l.HaveNodeAnnouncement {
|
|
||||||
return nil, fmt.Errorf("node does not have node announcement")
|
|
||||||
}
|
|
||||||
|
|
||||||
alias, err := lnwire.NewNodeAlias(l.Alias)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
nodeAnn := &lnwire.NodeAnnouncement{
|
|
||||||
Features: l.Features.RawFeatureVector,
|
|
||||||
NodeID: l.PubKeyBytes,
|
|
||||||
RGBColor: l.Color,
|
|
||||||
Alias: alias,
|
|
||||||
Addresses: l.Addresses,
|
|
||||||
Timestamp: uint32(l.LastUpdate.Unix()),
|
|
||||||
ExtraOpaqueData: l.ExtraOpaqueData,
|
|
||||||
}
|
|
||||||
|
|
||||||
if !signed {
|
|
||||||
return nodeAnn, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
sig, err := lnwire.NewSigFromRawSignature(l.AuthSigBytes)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
nodeAnn.Signature = sig
|
|
||||||
|
|
||||||
return nodeAnn, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// isPublic determines whether the node is seen as public within the graph from
|
|
||||||
// the source node's point of view. An existing database transaction can also be
|
|
||||||
// specified.
|
|
||||||
func (l *LightningNode) isPublic(tx *bbolt.Tx, sourcePubKey []byte) (bool, error) {
|
|
||||||
// In order to determine whether this node is publicly advertised within
|
|
||||||
// the graph, we'll need to look at all of its edges and check whether
|
|
||||||
// they extend to any other node than the source node. errDone will be
|
|
||||||
// used to terminate the check early.
|
|
||||||
nodeIsPublic := false
|
|
||||||
errDone := errors.New("done")
|
|
||||||
err := l.ForEachChannel(tx, func(_ *bbolt.Tx, info *ChannelEdgeInfo,
|
|
||||||
_, _ *ChannelEdgePolicy) error {
|
|
||||||
|
|
||||||
// If this edge doesn't extend to the source node, we'll
|
|
||||||
// terminate our search as we can now conclude that the node is
|
|
||||||
// publicly advertised within the graph due to the local node
|
|
||||||
// knowing of the current edge.
|
|
||||||
if !bytes.Equal(info.NodeKey1Bytes[:], sourcePubKey) &&
|
|
||||||
!bytes.Equal(info.NodeKey2Bytes[:], sourcePubKey) {
|
|
||||||
|
|
||||||
nodeIsPublic = true
|
|
||||||
return errDone
|
|
||||||
}
|
|
||||||
|
|
||||||
// Since the edge _does_ extend to the source node, we'll also
|
|
||||||
// need to ensure that this is a public edge.
|
|
||||||
if info.AuthProof != nil {
|
|
||||||
nodeIsPublic = true
|
|
||||||
return errDone
|
|
||||||
}
|
|
||||||
|
|
||||||
// Otherwise, we'll continue our search.
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil && err != errDone {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nodeIsPublic, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// FetchLightningNode attempts to look up a target node by its identity public
|
|
||||||
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
|
|
||||||
// returned.
|
|
||||||
func (c *ChannelGraph) FetchLightningNode(pub *btcec.PublicKey) (*LightningNode, error) {
|
|
||||||
var node *LightningNode
|
|
||||||
nodePub := pub.SerializeCompressed()
|
|
||||||
err := c.db.View(func(tx *bbolt.Tx) error {
|
|
||||||
// First grab the nodes bucket which stores the mapping from
|
|
||||||
// pubKey to node information.
|
|
||||||
nodes := tx.Bucket(nodeBucket)
|
|
||||||
if nodes == nil {
|
|
||||||
return ErrGraphNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
// If a key for this serialized public key isn't found, then
|
|
||||||
// the target node doesn't exist within the database.
|
|
||||||
nodeBytes := nodes.Get(nodePub)
|
|
||||||
if nodeBytes == nil {
|
|
||||||
return ErrGraphNodeNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the node is found, then we can de deserialize the node
|
|
||||||
// information to return to the user.
|
|
||||||
nodeReader := bytes.NewReader(nodeBytes)
|
|
||||||
n, err := deserializeLightningNode(nodeReader)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
n.db = c.db
|
|
||||||
|
|
||||||
node = &n
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return node, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// HasLightningNode determines if the graph has a vertex identified by the
|
|
||||||
// target node identity public key. If the node exists in the database, a
|
|
||||||
// timestamp of when the data for the node was lasted updated is returned along
|
|
||||||
// with a true boolean. Otherwise, an empty time.Time is returned with a false
|
|
||||||
// boolean.
|
|
||||||
func (c *ChannelGraph) HasLightningNode(nodePub [33]byte) (time.Time, bool, error) {
|
|
||||||
var (
|
|
||||||
updateTime time.Time
|
|
||||||
exists bool
|
|
||||||
)
|
|
||||||
|
|
||||||
err := c.db.View(func(tx *bbolt.Tx) error {
|
|
||||||
// First grab the nodes bucket which stores the mapping from
|
|
||||||
// pubKey to node information.
|
|
||||||
nodes := tx.Bucket(nodeBucket)
|
|
||||||
if nodes == nil {
|
|
||||||
return ErrGraphNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
// If a key for this serialized public key isn't found, we can
|
|
||||||
// exit early.
|
|
||||||
nodeBytes := nodes.Get(nodePub[:])
|
|
||||||
if nodeBytes == nil {
|
|
||||||
exists = false
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Otherwise we continue on to obtain the time stamp
|
|
||||||
// representing the last time the data for this node was
|
|
||||||
// updated.
|
|
||||||
nodeReader := bytes.NewReader(nodeBytes)
|
|
||||||
node, err := deserializeLightningNode(nodeReader)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
exists = true
|
|
||||||
updateTime = node.LastUpdate
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return time.Time{}, exists, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return updateTime, exists, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// nodeTraversal is used to traverse all channels of a node given by its
|
|
||||||
// public key and passes channel information into the specified callback.
|
|
||||||
func nodeTraversal(tx *bbolt.Tx, nodePub []byte, db *DB,
|
|
||||||
cb func(*bbolt.Tx, *ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error {
|
|
||||||
|
|
||||||
traversal := func(tx *bbolt.Tx) error {
|
|
||||||
nodes := tx.Bucket(nodeBucket)
|
|
||||||
if nodes == nil {
|
|
||||||
return ErrGraphNotFound
|
|
||||||
}
|
|
||||||
edges := tx.Bucket(edgeBucket)
|
|
||||||
if edges == nil {
|
|
||||||
return ErrGraphNotFound
|
|
||||||
}
|
|
||||||
edgeIndex := edges.Bucket(edgeIndexBucket)
|
|
||||||
if edgeIndex == nil {
|
|
||||||
return ErrGraphNoEdgesFound
|
|
||||||
}
|
|
||||||
|
|
||||||
// In order to reach all the edges for this node, we take
|
|
||||||
// advantage of the construction of the key-space within the
|
|
||||||
// edge bucket. The keys are stored in the form: pubKey ||
|
|
||||||
// chanID. Therefore, starting from a chanID of zero, we can
|
|
||||||
// scan forward in the bucket, grabbing all the edges for the
|
|
||||||
// node. Once the prefix no longer matches, then we know we're
|
|
||||||
// done.
|
|
||||||
var nodeStart [33 + 8]byte
|
|
||||||
copy(nodeStart[:], nodePub)
|
|
||||||
copy(nodeStart[33:], chanStart[:])
|
|
||||||
|
|
||||||
// Starting from the key pubKey || 0, we seek forward in the
|
|
||||||
// bucket until the retrieved key no longer has the public key
|
|
||||||
// as its prefix. This indicates that we've stepped over into
|
|
||||||
// another node's edges, so we can terminate our scan.
|
|
||||||
edgeCursor := edges.Cursor()
|
|
||||||
for nodeEdge, _ := edgeCursor.Seek(nodeStart[:]); bytes.HasPrefix(nodeEdge, nodePub); nodeEdge, _ = edgeCursor.Next() {
|
|
||||||
// If the prefix still matches, the channel id is
|
|
||||||
// returned in nodeEdge. Channel id is used to lookup
|
|
||||||
// the node at the other end of the channel and both
|
|
||||||
// edge policies.
|
|
||||||
chanID := nodeEdge[33:]
|
|
||||||
edgeInfo, err := fetchChanEdgeInfo(edgeIndex, chanID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
edgeInfo.db = db
|
|
||||||
|
|
||||||
outgoingPolicy, err := fetchChanEdgePolicy(
|
|
||||||
edges, chanID, nodePub, nodes,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
otherNode, err := edgeInfo.OtherNodeKeyBytes(nodePub)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
incomingPolicy, err := fetchChanEdgePolicy(
|
|
||||||
edges, chanID, otherNode[:], nodes,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Finally, we execute the callback.
|
|
||||||
err = cb(tx, &edgeInfo, outgoingPolicy, incomingPolicy)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// If no transaction was provided, then we'll create a new transaction
|
|
||||||
// to execute the transaction within.
|
|
||||||
if tx == nil {
|
|
||||||
return db.View(traversal)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Otherwise, we re-use the existing transaction to execute the graph
|
|
||||||
// traversal.
|
|
||||||
return traversal(tx)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ForEachChannel iterates through all channels of this node, executing the
|
|
||||||
// passed callback with an edge info structure and the policies of each end
|
|
||||||
// of the channel. The first edge policy is the outgoing edge *to* the
|
|
||||||
// the connecting node, while the second is the incoming edge *from* the
|
|
||||||
// connecting node. If the callback returns an error, then the iteration is
|
|
||||||
// halted with the error propagated back up to the caller.
|
|
||||||
//
|
|
||||||
// Unknown policies are passed into the callback as nil values.
|
|
||||||
//
|
|
||||||
// If the caller wishes to re-use an existing boltdb transaction, then it
|
|
||||||
// should be passed as the first argument. Otherwise the first argument should
|
|
||||||
// be nil and a fresh transaction will be created to execute the graph
|
|
||||||
// traversal.
|
|
||||||
func (l *LightningNode) ForEachChannel(tx *bbolt.Tx,
|
|
||||||
cb func(*bbolt.Tx, *ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error {
|
|
||||||
|
|
||||||
nodePub := l.PubKeyBytes[:]
|
|
||||||
db := l.db
|
|
||||||
|
|
||||||
return nodeTraversal(tx, nodePub, db, cb)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChannelEdgeInfo represents a fully authenticated channel along with all its
|
// ChannelEdgeInfo represents a fully authenticated channel along with all its
|
||||||
// unique attributes. Once an authenticated channel announcement has been
|
// unique attributes. Once an authenticated channel announcement has been
|
||||||
// processed on the network, then an instance of ChannelEdgeInfo encapsulating
|
// processed on the network, then an instance of ChannelEdgeInfo encapsulating
|
||||||
@ -2395,19 +412,15 @@ type ChannelEdgeInfo struct {
|
|||||||
|
|
||||||
// NodeKey1Bytes is the raw public key of the first node.
|
// NodeKey1Bytes is the raw public key of the first node.
|
||||||
NodeKey1Bytes [33]byte
|
NodeKey1Bytes [33]byte
|
||||||
nodeKey1 *btcec.PublicKey
|
|
||||||
|
|
||||||
// NodeKey2Bytes is the raw public key of the first node.
|
// NodeKey2Bytes is the raw public key of the first node.
|
||||||
NodeKey2Bytes [33]byte
|
NodeKey2Bytes [33]byte
|
||||||
nodeKey2 *btcec.PublicKey
|
|
||||||
|
|
||||||
// BitcoinKey1Bytes is the raw public key of the first node.
|
// BitcoinKey1Bytes is the raw public key of the first node.
|
||||||
BitcoinKey1Bytes [33]byte
|
BitcoinKey1Bytes [33]byte
|
||||||
bitcoinKey1 *btcec.PublicKey
|
|
||||||
|
|
||||||
// BitcoinKey2Bytes is the raw public key of the first node.
|
// BitcoinKey2Bytes is the raw public key of the first node.
|
||||||
BitcoinKey2Bytes [33]byte
|
BitcoinKey2Bytes [33]byte
|
||||||
bitcoinKey2 *btcec.PublicKey
|
|
||||||
|
|
||||||
// Features is an opaque byte slice that encodes the set of channel
|
// Features is an opaque byte slice that encodes the set of channel
|
||||||
// specific features that this channel edge supports.
|
// specific features that this channel edge supports.
|
||||||
@ -2433,173 +446,6 @@ type ChannelEdgeInfo struct {
|
|||||||
// and ensure we're able to make upgrades to the network in a forwards
|
// and ensure we're able to make upgrades to the network in a forwards
|
||||||
// compatible manner.
|
// compatible manner.
|
||||||
ExtraOpaqueData []byte
|
ExtraOpaqueData []byte
|
||||||
|
|
||||||
db *DB
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddNodeKeys is a setter-like method that can be used to replace the set of
|
|
||||||
// keys for the target ChannelEdgeInfo.
|
|
||||||
func (c *ChannelEdgeInfo) AddNodeKeys(nodeKey1, nodeKey2, bitcoinKey1,
|
|
||||||
bitcoinKey2 *btcec.PublicKey) {
|
|
||||||
|
|
||||||
c.nodeKey1 = nodeKey1
|
|
||||||
copy(c.NodeKey1Bytes[:], c.nodeKey1.SerializeCompressed())
|
|
||||||
|
|
||||||
c.nodeKey2 = nodeKey2
|
|
||||||
copy(c.NodeKey2Bytes[:], nodeKey2.SerializeCompressed())
|
|
||||||
|
|
||||||
c.bitcoinKey1 = bitcoinKey1
|
|
||||||
copy(c.BitcoinKey1Bytes[:], c.bitcoinKey1.SerializeCompressed())
|
|
||||||
|
|
||||||
c.bitcoinKey2 = bitcoinKey2
|
|
||||||
copy(c.BitcoinKey2Bytes[:], bitcoinKey2.SerializeCompressed())
|
|
||||||
}
|
|
||||||
|
|
||||||
// NodeKey1 is the identity public key of the "first" node that was involved in
|
|
||||||
// the creation of this channel. A node is considered "first" if the
|
|
||||||
// lexicographical ordering the its serialized public key is "smaller" than
|
|
||||||
// that of the other node involved in channel creation.
|
|
||||||
//
|
|
||||||
// NOTE: By having this method to access an attribute, we ensure we only need
|
|
||||||
// to fully deserialize the pubkey if absolutely necessary.
|
|
||||||
func (c *ChannelEdgeInfo) NodeKey1() (*btcec.PublicKey, error) {
|
|
||||||
if c.nodeKey1 != nil {
|
|
||||||
return c.nodeKey1, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
key, err := btcec.ParsePubKey(c.NodeKey1Bytes[:], btcec.S256())
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
c.nodeKey1 = key
|
|
||||||
|
|
||||||
return key, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// NodeKey2 is the identity public key of the "second" node that was
|
|
||||||
// involved in the creation of this channel. A node is considered
|
|
||||||
// "second" if the lexicographical ordering the its serialized public
|
|
||||||
// key is "larger" than that of the other node involved in channel
|
|
||||||
// creation.
|
|
||||||
//
|
|
||||||
// NOTE: By having this method to access an attribute, we ensure we only need
|
|
||||||
// to fully deserialize the pubkey if absolutely necessary.
|
|
||||||
func (c *ChannelEdgeInfo) NodeKey2() (*btcec.PublicKey, error) {
|
|
||||||
if c.nodeKey2 != nil {
|
|
||||||
return c.nodeKey2, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
key, err := btcec.ParsePubKey(c.NodeKey2Bytes[:], btcec.S256())
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
c.nodeKey2 = key
|
|
||||||
|
|
||||||
return key, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// BitcoinKey1 is the Bitcoin multi-sig key belonging to the first
|
|
||||||
// node, that was involved in the funding transaction that originally
|
|
||||||
// created the channel that this struct represents.
|
|
||||||
//
|
|
||||||
// NOTE: By having this method to access an attribute, we ensure we only need
|
|
||||||
// to fully deserialize the pubkey if absolutely necessary.
|
|
||||||
func (c *ChannelEdgeInfo) BitcoinKey1() (*btcec.PublicKey, error) {
|
|
||||||
if c.bitcoinKey1 != nil {
|
|
||||||
return c.bitcoinKey1, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
key, err := btcec.ParsePubKey(c.BitcoinKey1Bytes[:], btcec.S256())
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
c.bitcoinKey1 = key
|
|
||||||
|
|
||||||
return key, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// BitcoinKey2 is the Bitcoin multi-sig key belonging to the second
|
|
||||||
// node, that was involved in the funding transaction that originally
|
|
||||||
// created the channel that this struct represents.
|
|
||||||
//
|
|
||||||
// NOTE: By having this method to access an attribute, we ensure we only need
|
|
||||||
// to fully deserialize the pubkey if absolutely necessary.
|
|
||||||
func (c *ChannelEdgeInfo) BitcoinKey2() (*btcec.PublicKey, error) {
|
|
||||||
if c.bitcoinKey2 != nil {
|
|
||||||
return c.bitcoinKey2, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
key, err := btcec.ParsePubKey(c.BitcoinKey2Bytes[:], btcec.S256())
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
c.bitcoinKey2 = key
|
|
||||||
|
|
||||||
return key, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// OtherNodeKeyBytes returns the node key bytes of the other end of
|
|
||||||
// the channel.
|
|
||||||
func (c *ChannelEdgeInfo) OtherNodeKeyBytes(thisNodeKey []byte) (
|
|
||||||
[33]byte, error) {
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case bytes.Equal(c.NodeKey1Bytes[:], thisNodeKey):
|
|
||||||
return c.NodeKey2Bytes, nil
|
|
||||||
case bytes.Equal(c.NodeKey2Bytes[:], thisNodeKey):
|
|
||||||
return c.NodeKey1Bytes, nil
|
|
||||||
default:
|
|
||||||
return [33]byte{}, fmt.Errorf("node not participating in this channel")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// FetchOtherNode attempts to fetch the full LightningNode that's opposite of
|
|
||||||
// the target node in the channel. This is useful when one knows the pubkey of
|
|
||||||
// one of the nodes, and wishes to obtain the full LightningNode for the other
|
|
||||||
// end of the channel.
|
|
||||||
func (c *ChannelEdgeInfo) FetchOtherNode(tx *bbolt.Tx, thisNodeKey []byte) (*LightningNode, error) {
|
|
||||||
|
|
||||||
// Ensure that the node passed in is actually a member of the channel.
|
|
||||||
var targetNodeBytes [33]byte
|
|
||||||
switch {
|
|
||||||
case bytes.Equal(c.NodeKey1Bytes[:], thisNodeKey):
|
|
||||||
targetNodeBytes = c.NodeKey2Bytes
|
|
||||||
case bytes.Equal(c.NodeKey2Bytes[:], thisNodeKey):
|
|
||||||
targetNodeBytes = c.NodeKey1Bytes
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("node not participating in this channel")
|
|
||||||
}
|
|
||||||
|
|
||||||
var targetNode *LightningNode
|
|
||||||
fetchNodeFunc := func(tx *bbolt.Tx) error {
|
|
||||||
// First grab the nodes bucket which stores the mapping from
|
|
||||||
// pubKey to node information.
|
|
||||||
nodes := tx.Bucket(nodeBucket)
|
|
||||||
if nodes == nil {
|
|
||||||
return ErrGraphNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
node, err := fetchLightningNode(nodes, targetNodeBytes[:])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
node.db = c.db
|
|
||||||
|
|
||||||
targetNode = &node
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the transaction is nil, then we'll need to create a new one,
|
|
||||||
// otherwise we can use the existing db transaction.
|
|
||||||
var err error
|
|
||||||
if tx == nil {
|
|
||||||
err = c.db.View(fetchNodeFunc)
|
|
||||||
} else {
|
|
||||||
err = fetchNodeFunc(tx)
|
|
||||||
}
|
|
||||||
|
|
||||||
return targetNode, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ChannelAuthProof is the authentication proof (the signature portion) for a
|
// ChannelAuthProof is the authentication proof (the signature portion) for a
|
||||||
@ -2610,117 +456,23 @@ func (c *ChannelEdgeInfo) FetchOtherNode(tx *bbolt.Tx, thisNodeKey []byte) (*Lig
|
|||||||
// nodeID1 || nodeID2 || bitcoinKey1|| bitcoinKey2 || 2-byte-feature-len ||
|
// nodeID1 || nodeID2 || bitcoinKey1|| bitcoinKey2 || 2-byte-feature-len ||
|
||||||
// features.
|
// features.
|
||||||
type ChannelAuthProof struct {
|
type ChannelAuthProof struct {
|
||||||
// nodeSig1 is a cached instance of the first node signature.
|
|
||||||
nodeSig1 *btcec.Signature
|
|
||||||
|
|
||||||
// NodeSig1Bytes are the raw bytes of the first node signature encoded
|
// NodeSig1Bytes are the raw bytes of the first node signature encoded
|
||||||
// in DER format.
|
// in DER format.
|
||||||
NodeSig1Bytes []byte
|
NodeSig1Bytes []byte
|
||||||
|
|
||||||
// nodeSig2 is a cached instance of the second node signature.
|
|
||||||
nodeSig2 *btcec.Signature
|
|
||||||
|
|
||||||
// NodeSig2Bytes are the raw bytes of the second node signature
|
// NodeSig2Bytes are the raw bytes of the second node signature
|
||||||
// encoded in DER format.
|
// encoded in DER format.
|
||||||
NodeSig2Bytes []byte
|
NodeSig2Bytes []byte
|
||||||
|
|
||||||
// bitcoinSig1 is a cached instance of the first bitcoin signature.
|
|
||||||
bitcoinSig1 *btcec.Signature
|
|
||||||
|
|
||||||
// BitcoinSig1Bytes are the raw bytes of the first bitcoin signature
|
// BitcoinSig1Bytes are the raw bytes of the first bitcoin signature
|
||||||
// encoded in DER format.
|
// encoded in DER format.
|
||||||
BitcoinSig1Bytes []byte
|
BitcoinSig1Bytes []byte
|
||||||
|
|
||||||
// bitcoinSig2 is a cached instance of the second bitcoin signature.
|
|
||||||
bitcoinSig2 *btcec.Signature
|
|
||||||
|
|
||||||
// BitcoinSig2Bytes are the raw bytes of the second bitcoin signature
|
// BitcoinSig2Bytes are the raw bytes of the second bitcoin signature
|
||||||
// encoded in DER format.
|
// encoded in DER format.
|
||||||
BitcoinSig2Bytes []byte
|
BitcoinSig2Bytes []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
// Node1Sig is the signature using the identity key of the node that is first
|
|
||||||
// in a lexicographical ordering of the serialized public keys of the two nodes
|
|
||||||
// that created the channel.
|
|
||||||
//
|
|
||||||
// NOTE: By having this method to access an attribute, we ensure we only need
|
|
||||||
// to fully deserialize the signature if absolutely necessary.
|
|
||||||
func (c *ChannelAuthProof) Node1Sig() (*btcec.Signature, error) {
|
|
||||||
if c.nodeSig1 != nil {
|
|
||||||
return c.nodeSig1, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
sig, err := btcec.ParseSignature(c.NodeSig1Bytes, btcec.S256())
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.nodeSig1 = sig
|
|
||||||
|
|
||||||
return sig, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Node2Sig is the signature using the identity key of the node that is second
|
|
||||||
// in a lexicographical ordering of the serialized public keys of the two nodes
|
|
||||||
// that created the channel.
|
|
||||||
//
|
|
||||||
// NOTE: By having this method to access an attribute, we ensure we only need
|
|
||||||
// to fully deserialize the signature if absolutely necessary.
|
|
||||||
func (c *ChannelAuthProof) Node2Sig() (*btcec.Signature, error) {
|
|
||||||
if c.nodeSig2 != nil {
|
|
||||||
return c.nodeSig2, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
sig, err := btcec.ParseSignature(c.NodeSig2Bytes, btcec.S256())
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.nodeSig2 = sig
|
|
||||||
|
|
||||||
return sig, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// BitcoinSig1 is the signature using the public key of the first node that was
|
|
||||||
// used in the channel's multi-sig output.
|
|
||||||
//
|
|
||||||
// NOTE: By having this method to access an attribute, we ensure we only need
|
|
||||||
// to fully deserialize the signature if absolutely necessary.
|
|
||||||
func (c *ChannelAuthProof) BitcoinSig1() (*btcec.Signature, error) {
|
|
||||||
if c.bitcoinSig1 != nil {
|
|
||||||
return c.bitcoinSig1, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
sig, err := btcec.ParseSignature(c.BitcoinSig1Bytes, btcec.S256())
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.bitcoinSig1 = sig
|
|
||||||
|
|
||||||
return sig, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// BitcoinSig2 is the signature using the public key of the second node that
|
|
||||||
// was used in the channel's multi-sig output.
|
|
||||||
//
|
|
||||||
// NOTE: By having this method to access an attribute, we ensure we only need
|
|
||||||
// to fully deserialize the signature if absolutely necessary.
|
|
||||||
func (c *ChannelAuthProof) BitcoinSig2() (*btcec.Signature, error) {
|
|
||||||
if c.bitcoinSig2 != nil {
|
|
||||||
return c.bitcoinSig2, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
sig, err := btcec.ParseSignature(c.BitcoinSig2Bytes, btcec.S256())
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.bitcoinSig2 = sig
|
|
||||||
|
|
||||||
return sig, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsEmpty check is the authentication proof is empty Proof is empty if at
|
// IsEmpty check is the authentication proof is empty Proof is empty if at
|
||||||
// least one of the signatures are equal to nil.
|
// least one of the signatures are equal to nil.
|
||||||
func (c *ChannelAuthProof) IsEmpty() bool {
|
func (c *ChannelAuthProof) IsEmpty() bool {
|
||||||
@ -2742,9 +494,6 @@ type ChannelEdgePolicy struct {
|
|||||||
// use SetSigBytes instead to make sure that the cache is invalidated.
|
// use SetSigBytes instead to make sure that the cache is invalidated.
|
||||||
SigBytes []byte
|
SigBytes []byte
|
||||||
|
|
||||||
// sig is a cached fully parsed signature.
|
|
||||||
sig *btcec.Signature
|
|
||||||
|
|
||||||
// ChannelID is the unique channel ID for the channel. The first 3
|
// ChannelID is the unique channel ID for the channel. The first 3
|
||||||
// bytes are the block height, the next 3 the index within the block,
|
// bytes are the block height, the next 3 the index within the block,
|
||||||
// and the last 2 bytes are the output index for the channel.
|
// and the last 2 bytes are the output index for the channel.
|
||||||
@ -2794,35 +543,6 @@ type ChannelEdgePolicy struct {
|
|||||||
// and ensure we're able to make upgrades to the network in a forwards
|
// and ensure we're able to make upgrades to the network in a forwards
|
||||||
// compatible manner.
|
// compatible manner.
|
||||||
ExtraOpaqueData []byte
|
ExtraOpaqueData []byte
|
||||||
|
|
||||||
db *DB
|
|
||||||
}
|
|
||||||
|
|
||||||
// Signature is a channel announcement signature, which is needed for proper
|
|
||||||
// edge policy announcement.
|
|
||||||
//
|
|
||||||
// NOTE: By having this method to access an attribute, we ensure we only need
|
|
||||||
// to fully deserialize the signature if absolutely necessary.
|
|
||||||
func (c *ChannelEdgePolicy) Signature() (*btcec.Signature, error) {
|
|
||||||
if c.sig != nil {
|
|
||||||
return c.sig, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
sig, err := btcec.ParseSignature(c.SigBytes, btcec.S256())
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.sig = sig
|
|
||||||
|
|
||||||
return sig, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSigBytes updates the signature and invalidates the cached parsed
|
|
||||||
// signature.
|
|
||||||
func (c *ChannelEdgePolicy) SetSigBytes(sig []byte) {
|
|
||||||
c.SigBytes = sig
|
|
||||||
c.sig = nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsDisabled determines whether the edge has the disabled bit set.
|
// IsDisabled determines whether the edge has the disabled bit set.
|
||||||
@ -2831,488 +551,6 @@ func (c *ChannelEdgePolicy) IsDisabled() bool {
|
|||||||
lnwire.ChanUpdateDisabled
|
lnwire.ChanUpdateDisabled
|
||||||
}
|
}
|
||||||
|
|
||||||
// ComputeFee computes the fee to forward an HTLC of `amt` milli-satoshis over
|
|
||||||
// the passed active payment channel. This value is currently computed as
|
|
||||||
// specified in BOLT07, but will likely change in the near future.
|
|
||||||
func (c *ChannelEdgePolicy) ComputeFee(
|
|
||||||
amt lnwire.MilliSatoshi) lnwire.MilliSatoshi {
|
|
||||||
|
|
||||||
return c.FeeBaseMSat + (amt*c.FeeProportionalMillionths)/feeRateParts
|
|
||||||
}
|
|
||||||
|
|
||||||
// divideCeil divides dividend by factor and rounds the result up.
|
|
||||||
func divideCeil(dividend, factor lnwire.MilliSatoshi) lnwire.MilliSatoshi {
|
|
||||||
return (dividend + factor - 1) / factor
|
|
||||||
}
|
|
||||||
|
|
||||||
// ComputeFeeFromIncoming computes the fee to forward an HTLC given the incoming
|
|
||||||
// amount.
|
|
||||||
func (c *ChannelEdgePolicy) ComputeFeeFromIncoming(
|
|
||||||
incomingAmt lnwire.MilliSatoshi) lnwire.MilliSatoshi {
|
|
||||||
|
|
||||||
return incomingAmt - divideCeil(
|
|
||||||
feeRateParts*(incomingAmt-c.FeeBaseMSat),
|
|
||||||
feeRateParts+c.FeeProportionalMillionths,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for
|
|
||||||
// the channel identified by the funding outpoint. If the channel can't be
|
|
||||||
// found, then ErrEdgeNotFound is returned. A struct which houses the general
|
|
||||||
// information for the channel itself is returned as well as two structs that
|
|
||||||
// contain the routing policies for the channel in either direction.
|
|
||||||
func (c *ChannelGraph) FetchChannelEdgesByOutpoint(op *wire.OutPoint,
|
|
||||||
) (*ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy, error) {
|
|
||||||
|
|
||||||
var (
|
|
||||||
edgeInfo *ChannelEdgeInfo
|
|
||||||
policy1 *ChannelEdgePolicy
|
|
||||||
policy2 *ChannelEdgePolicy
|
|
||||||
)
|
|
||||||
|
|
||||||
err := c.db.View(func(tx *bbolt.Tx) error {
|
|
||||||
// First, grab the node bucket. This will be used to populate
|
|
||||||
// the Node pointers in each edge read from disk.
|
|
||||||
nodes := tx.Bucket(nodeBucket)
|
|
||||||
if nodes == nil {
|
|
||||||
return ErrGraphNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
// Next, grab the edge bucket which stores the edges, and also
|
|
||||||
// the index itself so we can group the directed edges together
|
|
||||||
// logically.
|
|
||||||
edges := tx.Bucket(edgeBucket)
|
|
||||||
if edges == nil {
|
|
||||||
return ErrGraphNoEdgesFound
|
|
||||||
}
|
|
||||||
edgeIndex := edges.Bucket(edgeIndexBucket)
|
|
||||||
if edgeIndex == nil {
|
|
||||||
return ErrGraphNoEdgesFound
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the channel's outpoint doesn't exist within the outpoint
|
|
||||||
// index, then the edge does not exist.
|
|
||||||
chanIndex := edges.Bucket(channelPointBucket)
|
|
||||||
if chanIndex == nil {
|
|
||||||
return ErrGraphNoEdgesFound
|
|
||||||
}
|
|
||||||
var b bytes.Buffer
|
|
||||||
if err := writeOutpoint(&b, op); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
chanID := chanIndex.Get(b.Bytes())
|
|
||||||
if chanID == nil {
|
|
||||||
return ErrEdgeNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the channel is found to exists, then we'll first retrieve
|
|
||||||
// the general information for the channel.
|
|
||||||
edge, err := fetchChanEdgeInfo(edgeIndex, chanID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
edgeInfo = &edge
|
|
||||||
edgeInfo.db = c.db
|
|
||||||
|
|
||||||
// Once we have the information about the channels' parameters,
|
|
||||||
// we'll fetch the routing policies for each for the directed
|
|
||||||
// edges.
|
|
||||||
e1, e2, err := fetchChanEdgePolicies(
|
|
||||||
edgeIndex, edges, nodes, chanID, c.db,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
policy1 = e1
|
|
||||||
policy2 = e2
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return edgeInfo, policy1, policy2, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// FetchChannelEdgesByID attempts to lookup the two directed edges for the
|
|
||||||
// channel identified by the channel ID. If the channel can't be found, then
|
|
||||||
// ErrEdgeNotFound is returned. A struct which houses the general information
|
|
||||||
// for the channel itself is returned as well as two structs that contain the
|
|
||||||
// routing policies for the channel in either direction.
|
|
||||||
//
|
|
||||||
// ErrZombieEdge an be returned if the edge is currently marked as a zombie
|
|
||||||
// within the database. In this case, the ChannelEdgePolicy's will be nil, and
|
|
||||||
// the ChannelEdgeInfo will only include the public keys of each node.
|
|
||||||
func (c *ChannelGraph) FetchChannelEdgesByID(chanID uint64,
|
|
||||||
) (*ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy, error) {
|
|
||||||
|
|
||||||
var (
|
|
||||||
edgeInfo *ChannelEdgeInfo
|
|
||||||
policy1 *ChannelEdgePolicy
|
|
||||||
policy2 *ChannelEdgePolicy
|
|
||||||
channelID [8]byte
|
|
||||||
)
|
|
||||||
|
|
||||||
err := c.db.View(func(tx *bbolt.Tx) error {
|
|
||||||
// First, grab the node bucket. This will be used to populate
|
|
||||||
// the Node pointers in each edge read from disk.
|
|
||||||
nodes := tx.Bucket(nodeBucket)
|
|
||||||
if nodes == nil {
|
|
||||||
return ErrGraphNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
// Next, grab the edge bucket which stores the edges, and also
|
|
||||||
// the index itself so we can group the directed edges together
|
|
||||||
// logically.
|
|
||||||
edges := tx.Bucket(edgeBucket)
|
|
||||||
if edges == nil {
|
|
||||||
return ErrGraphNoEdgesFound
|
|
||||||
}
|
|
||||||
edgeIndex := edges.Bucket(edgeIndexBucket)
|
|
||||||
if edgeIndex == nil {
|
|
||||||
return ErrGraphNoEdgesFound
|
|
||||||
}
|
|
||||||
|
|
||||||
byteOrder.PutUint64(channelID[:], chanID)
|
|
||||||
|
|
||||||
// Now, attempt to fetch edge.
|
|
||||||
edge, err := fetchChanEdgeInfo(edgeIndex, channelID[:])
|
|
||||||
|
|
||||||
// If it doesn't exist, we'll quickly check our zombie index to
|
|
||||||
// see if we've previously marked it as so.
|
|
||||||
if err == ErrEdgeNotFound {
|
|
||||||
// If the zombie index doesn't exist, or the edge is not
|
|
||||||
// marked as a zombie within it, then we'll return the
|
|
||||||
// original ErrEdgeNotFound error.
|
|
||||||
zombieIndex := edges.Bucket(zombieBucket)
|
|
||||||
if zombieIndex == nil {
|
|
||||||
return ErrEdgeNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
isZombie, pubKey1, pubKey2 := isZombieEdge(
|
|
||||||
zombieIndex, chanID,
|
|
||||||
)
|
|
||||||
if !isZombie {
|
|
||||||
return ErrEdgeNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
// Otherwise, the edge is marked as a zombie, so we'll
|
|
||||||
// populate the edge info with the public keys of each
|
|
||||||
// party as this is the only information we have about
|
|
||||||
// it and return an error signaling so.
|
|
||||||
edgeInfo = &ChannelEdgeInfo{
|
|
||||||
NodeKey1Bytes: pubKey1,
|
|
||||||
NodeKey2Bytes: pubKey2,
|
|
||||||
}
|
|
||||||
return ErrZombieEdge
|
|
||||||
}
|
|
||||||
|
|
||||||
// Otherwise, we'll just return the error if any.
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
edgeInfo = &edge
|
|
||||||
edgeInfo.db = c.db
|
|
||||||
|
|
||||||
// Then we'll attempt to fetch the accompanying policies of this
|
|
||||||
// edge.
|
|
||||||
e1, e2, err := fetchChanEdgePolicies(
|
|
||||||
edgeIndex, edges, nodes, channelID[:], c.db,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
policy1 = e1
|
|
||||||
policy2 = e2
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err == ErrZombieEdge {
|
|
||||||
return edgeInfo, nil, nil, err
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return edgeInfo, policy1, policy2, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsPublicNode is a helper method that determines whether the node with the
|
|
||||||
// given public key is seen as a public node in the graph from the graph's
|
|
||||||
// source node's point of view.
|
|
||||||
func (c *ChannelGraph) IsPublicNode(pubKey [33]byte) (bool, error) {
|
|
||||||
var nodeIsPublic bool
|
|
||||||
err := c.db.View(func(tx *bbolt.Tx) error {
|
|
||||||
nodes := tx.Bucket(nodeBucket)
|
|
||||||
if nodes == nil {
|
|
||||||
return ErrGraphNodesNotFound
|
|
||||||
}
|
|
||||||
ourPubKey := nodes.Get(sourceKey)
|
|
||||||
if ourPubKey == nil {
|
|
||||||
return ErrSourceNodeNotSet
|
|
||||||
}
|
|
||||||
node, err := fetchLightningNode(nodes, pubKey[:])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
nodeIsPublic, err = node.isPublic(tx, ourPubKey)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nodeIsPublic, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// genMultiSigP2WSH generates the p2wsh'd multisig script for 2 of 2 pubkeys.
|
|
||||||
func genMultiSigP2WSH(aPub, bPub []byte) ([]byte, error) {
|
|
||||||
if len(aPub) != 33 || len(bPub) != 33 {
|
|
||||||
return nil, fmt.Errorf("Pubkey size error. Compressed " +
|
|
||||||
"pubkeys only")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Swap to sort pubkeys if needed. Keys are sorted in lexicographical
|
|
||||||
// order. The signatures within the scriptSig must also adhere to the
|
|
||||||
// order, ensuring that the signatures for each public key appears in
|
|
||||||
// the proper order on the stack.
|
|
||||||
if bytes.Compare(aPub, bPub) == 1 {
|
|
||||||
aPub, bPub = bPub, aPub
|
|
||||||
}
|
|
||||||
|
|
||||||
// First, we'll generate the witness script for the multi-sig.
|
|
||||||
bldr := txscript.NewScriptBuilder()
|
|
||||||
bldr.AddOp(txscript.OP_2)
|
|
||||||
bldr.AddData(aPub) // Add both pubkeys (sorted).
|
|
||||||
bldr.AddData(bPub)
|
|
||||||
bldr.AddOp(txscript.OP_2)
|
|
||||||
bldr.AddOp(txscript.OP_CHECKMULTISIG)
|
|
||||||
witnessScript, err := bldr.Script()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// With the witness script generated, we'll now turn it into a p2sh
|
|
||||||
// script:
|
|
||||||
// * OP_0 <sha256(script)>
|
|
||||||
bldr = txscript.NewScriptBuilder()
|
|
||||||
bldr.AddOp(txscript.OP_0)
|
|
||||||
scriptHash := sha256.Sum256(witnessScript)
|
|
||||||
bldr.AddData(scriptHash[:])
|
|
||||||
|
|
||||||
return bldr.Script()
|
|
||||||
}
|
|
||||||
|
|
||||||
// EdgePoint couples the outpoint of a channel with the funding script that it
|
|
||||||
// creates. The FilteredChainView will use this to watch for spends of this
|
|
||||||
// edge point on chain. We require both of these values as depending on the
|
|
||||||
// concrete implementation, either the pkScript, or the out point will be used.
|
|
||||||
type EdgePoint struct {
|
|
||||||
// FundingPkScript is the p2wsh multi-sig script of the target channel.
|
|
||||||
FundingPkScript []byte
|
|
||||||
|
|
||||||
// OutPoint is the outpoint of the target channel.
|
|
||||||
OutPoint wire.OutPoint
|
|
||||||
}
|
|
||||||
|
|
||||||
// String returns a human readable version of the target EdgePoint. We return
|
|
||||||
// the outpoint directly as it is enough to uniquely identify the edge point.
|
|
||||||
func (e *EdgePoint) String() string {
|
|
||||||
return e.OutPoint.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChannelView returns the verifiable edge information for each active channel
|
|
||||||
// within the known channel graph. The set of UTXO's (along with their scripts)
|
|
||||||
// returned are the ones that need to be watched on chain to detect channel
|
|
||||||
// closes on the resident blockchain.
|
|
||||||
func (c *ChannelGraph) ChannelView() ([]EdgePoint, error) {
|
|
||||||
var edgePoints []EdgePoint
|
|
||||||
if err := c.db.View(func(tx *bbolt.Tx) error {
|
|
||||||
// We're going to iterate over the entire channel index, so
|
|
||||||
// we'll need to fetch the edgeBucket to get to the index as
|
|
||||||
// it's a sub-bucket.
|
|
||||||
edges := tx.Bucket(edgeBucket)
|
|
||||||
if edges == nil {
|
|
||||||
return ErrGraphNoEdgesFound
|
|
||||||
}
|
|
||||||
chanIndex := edges.Bucket(channelPointBucket)
|
|
||||||
if chanIndex == nil {
|
|
||||||
return ErrGraphNoEdgesFound
|
|
||||||
}
|
|
||||||
edgeIndex := edges.Bucket(edgeIndexBucket)
|
|
||||||
if edgeIndex == nil {
|
|
||||||
return ErrGraphNoEdgesFound
|
|
||||||
}
|
|
||||||
|
|
||||||
// Once we have the proper bucket, we'll range over each key
|
|
||||||
// (which is the channel point for the channel) and decode it,
|
|
||||||
// accumulating each entry.
|
|
||||||
return chanIndex.ForEach(func(chanPointBytes, chanID []byte) error {
|
|
||||||
chanPointReader := bytes.NewReader(chanPointBytes)
|
|
||||||
|
|
||||||
var chanPoint wire.OutPoint
|
|
||||||
err := readOutpoint(chanPointReader, &chanPoint)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
edgeInfo, err := fetchChanEdgeInfo(
|
|
||||||
edgeIndex, chanID,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
pkScript, err := genMultiSigP2WSH(
|
|
||||||
edgeInfo.BitcoinKey1Bytes[:],
|
|
||||||
edgeInfo.BitcoinKey2Bytes[:],
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
edgePoints = append(edgePoints, EdgePoint{
|
|
||||||
FundingPkScript: pkScript,
|
|
||||||
OutPoint: chanPoint,
|
|
||||||
})
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return edgePoints, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewChannelEdgePolicy returns a new blank ChannelEdgePolicy.
|
|
||||||
func (c *ChannelGraph) NewChannelEdgePolicy() *ChannelEdgePolicy {
|
|
||||||
return &ChannelEdgePolicy{db: c.db}
|
|
||||||
}
|
|
||||||
|
|
||||||
// markEdgeZombie marks an edge as a zombie within our zombie index. The public
|
|
||||||
// keys should represent the node public keys of the two parties involved in the
|
|
||||||
// edge.
|
|
||||||
func markEdgeZombie(zombieIndex *bbolt.Bucket, chanID uint64, pubKey1,
|
|
||||||
pubKey2 [33]byte) error {
|
|
||||||
|
|
||||||
var k [8]byte
|
|
||||||
byteOrder.PutUint64(k[:], chanID)
|
|
||||||
|
|
||||||
var v [66]byte
|
|
||||||
copy(v[:33], pubKey1[:])
|
|
||||||
copy(v[33:], pubKey2[:])
|
|
||||||
|
|
||||||
return zombieIndex.Put(k[:], v[:])
|
|
||||||
}
|
|
||||||
|
|
||||||
// MarkEdgeLive clears an edge from our zombie index, deeming it as live.
|
|
||||||
func (c *ChannelGraph) MarkEdgeLive(chanID uint64) error {
|
|
||||||
c.cacheMu.Lock()
|
|
||||||
defer c.cacheMu.Unlock()
|
|
||||||
|
|
||||||
err := c.db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
edges := tx.Bucket(edgeBucket)
|
|
||||||
if edges == nil {
|
|
||||||
return ErrGraphNoEdgesFound
|
|
||||||
}
|
|
||||||
zombieIndex := edges.Bucket(zombieBucket)
|
|
||||||
if zombieIndex == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var k [8]byte
|
|
||||||
byteOrder.PutUint64(k[:], chanID)
|
|
||||||
return zombieIndex.Delete(k[:])
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.rejectCache.remove(chanID)
|
|
||||||
c.chanCache.remove(chanID)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsZombieEdge returns whether the edge is considered zombie. If it is a
|
|
||||||
// zombie, then the two node public keys corresponding to this edge are also
|
|
||||||
// returned.
|
|
||||||
func (c *ChannelGraph) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte) {
|
|
||||||
var (
|
|
||||||
isZombie bool
|
|
||||||
pubKey1, pubKey2 [33]byte
|
|
||||||
)
|
|
||||||
|
|
||||||
err := c.db.View(func(tx *bbolt.Tx) error {
|
|
||||||
edges := tx.Bucket(edgeBucket)
|
|
||||||
if edges == nil {
|
|
||||||
return ErrGraphNoEdgesFound
|
|
||||||
}
|
|
||||||
zombieIndex := edges.Bucket(zombieBucket)
|
|
||||||
if zombieIndex == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
isZombie, pubKey1, pubKey2 = isZombieEdge(zombieIndex, chanID)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return false, [33]byte{}, [33]byte{}
|
|
||||||
}
|
|
||||||
|
|
||||||
return isZombie, pubKey1, pubKey2
|
|
||||||
}
|
|
||||||
|
|
||||||
// isZombieEdge returns whether an entry exists for the given channel in the
|
|
||||||
// zombie index. If an entry exists, then the two node public keys corresponding
|
|
||||||
// to this edge are also returned.
|
|
||||||
func isZombieEdge(zombieIndex *bbolt.Bucket,
|
|
||||||
chanID uint64) (bool, [33]byte, [33]byte) {
|
|
||||||
|
|
||||||
var k [8]byte
|
|
||||||
byteOrder.PutUint64(k[:], chanID)
|
|
||||||
|
|
||||||
v := zombieIndex.Get(k[:])
|
|
||||||
if v == nil {
|
|
||||||
return false, [33]byte{}, [33]byte{}
|
|
||||||
}
|
|
||||||
|
|
||||||
var pubKey1, pubKey2 [33]byte
|
|
||||||
copy(pubKey1[:], v[:33])
|
|
||||||
copy(pubKey2[:], v[33:])
|
|
||||||
|
|
||||||
return true, pubKey1, pubKey2
|
|
||||||
}
|
|
||||||
|
|
||||||
// NumZombies returns the current number of zombie channels in the graph.
|
|
||||||
func (c *ChannelGraph) NumZombies() (uint64, error) {
|
|
||||||
var numZombies uint64
|
|
||||||
err := c.db.View(func(tx *bbolt.Tx) error {
|
|
||||||
edges := tx.Bucket(edgeBucket)
|
|
||||||
if edges == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
zombieIndex := edges.Bucket(zombieBucket)
|
|
||||||
if zombieIndex == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return zombieIndex.ForEach(func(_, _ []byte) error {
|
|
||||||
numZombies++
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return numZombies, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func putLightningNode(nodeBucket *bbolt.Bucket, aliasBucket *bbolt.Bucket,
|
func putLightningNode(nodeBucket *bbolt.Bucket, aliasBucket *bbolt.Bucket,
|
||||||
updateIndex *bbolt.Bucket, node *LightningNode) error {
|
updateIndex *bbolt.Bucket, node *LightningNode) error {
|
||||||
|
|
||||||
@ -3548,84 +786,6 @@ func deserializeLightningNode(r io.Reader) (LightningNode, error) {
|
|||||||
return node, nil
|
return node, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func putChanEdgeInfo(edgeIndex *bbolt.Bucket, edgeInfo *ChannelEdgeInfo, chanID [8]byte) error {
|
|
||||||
var b bytes.Buffer
|
|
||||||
|
|
||||||
if _, err := b.Write(edgeInfo.NodeKey1Bytes[:]); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if _, err := b.Write(edgeInfo.NodeKey2Bytes[:]); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if _, err := b.Write(edgeInfo.BitcoinKey1Bytes[:]); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if _, err := b.Write(edgeInfo.BitcoinKey2Bytes[:]); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := wire.WriteVarBytes(&b, 0, edgeInfo.Features); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
authProof := edgeInfo.AuthProof
|
|
||||||
var nodeSig1, nodeSig2, bitcoinSig1, bitcoinSig2 []byte
|
|
||||||
if authProof != nil {
|
|
||||||
nodeSig1 = authProof.NodeSig1Bytes
|
|
||||||
nodeSig2 = authProof.NodeSig2Bytes
|
|
||||||
bitcoinSig1 = authProof.BitcoinSig1Bytes
|
|
||||||
bitcoinSig2 = authProof.BitcoinSig2Bytes
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := wire.WriteVarBytes(&b, 0, nodeSig1); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := wire.WriteVarBytes(&b, 0, nodeSig2); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := wire.WriteVarBytes(&b, 0, bitcoinSig1); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := wire.WriteVarBytes(&b, 0, bitcoinSig2); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := writeOutpoint(&b, &edgeInfo.ChannelPoint); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := binary.Write(&b, byteOrder, uint64(edgeInfo.Capacity)); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if _, err := b.Write(chanID[:]); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if _, err := b.Write(edgeInfo.ChainHash[:]); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(edgeInfo.ExtraOpaqueData) > MaxAllowedExtraOpaqueBytes {
|
|
||||||
return ErrTooManyExtraOpaqueBytes(len(edgeInfo.ExtraOpaqueData))
|
|
||||||
}
|
|
||||||
err := wire.WriteVarBytes(&b, 0, edgeInfo.ExtraOpaqueData)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return edgeIndex.Put(chanID[:], b.Bytes())
|
|
||||||
}
|
|
||||||
|
|
||||||
func fetchChanEdgeInfo(edgeIndex *bbolt.Bucket,
|
|
||||||
chanID []byte) (ChannelEdgeInfo, error) {
|
|
||||||
|
|
||||||
edgeInfoBytes := edgeIndex.Get(chanID)
|
|
||||||
if edgeInfoBytes == nil {
|
|
||||||
return ChannelEdgeInfo{}, ErrEdgeNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
edgeInfoReader := bytes.NewReader(edgeInfoBytes)
|
|
||||||
return deserializeChanEdgeInfo(edgeInfoReader)
|
|
||||||
}
|
|
||||||
|
|
||||||
func deserializeChanEdgeInfo(r io.Reader) (ChannelEdgeInfo, error) {
|
func deserializeChanEdgeInfo(r io.Reader) (ChannelEdgeInfo, error) {
|
||||||
var (
|
var (
|
||||||
err error
|
err error
|
||||||
@ -3856,47 +1016,6 @@ func fetchChanEdgePolicy(edges *bbolt.Bucket, chanID []byte,
|
|||||||
return ep, nil
|
return ep, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func fetchChanEdgePolicies(edgeIndex *bbolt.Bucket, edges *bbolt.Bucket,
|
|
||||||
nodes *bbolt.Bucket, chanID []byte,
|
|
||||||
db *DB) (*ChannelEdgePolicy, *ChannelEdgePolicy, error) {
|
|
||||||
|
|
||||||
edgeInfo := edgeIndex.Get(chanID)
|
|
||||||
if edgeInfo == nil {
|
|
||||||
return nil, nil, ErrEdgeNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
// The first node is contained within the first half of the edge
|
|
||||||
// information. We only propagate the error here and below if it's
|
|
||||||
// something other than edge non-existence.
|
|
||||||
node1Pub := edgeInfo[:33]
|
|
||||||
edge1, err := fetchChanEdgePolicy(edges, chanID, node1Pub, nodes)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// As we may have a single direction of the edge but not the other,
|
|
||||||
// only fill in the database pointers if the edge is found.
|
|
||||||
if edge1 != nil {
|
|
||||||
edge1.db = db
|
|
||||||
edge1.Node.db = db
|
|
||||||
}
|
|
||||||
|
|
||||||
// Similarly, the second node is contained within the latter
|
|
||||||
// half of the edge information.
|
|
||||||
node2Pub := edgeInfo[33:66]
|
|
||||||
edge2, err := fetchChanEdgePolicy(edges, chanID, node2Pub, nodes)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if edge2 != nil {
|
|
||||||
edge2.db = db
|
|
||||||
edge2.Node.db = db
|
|
||||||
}
|
|
||||||
|
|
||||||
return edge1, edge2, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func serializeChanEdgePolicy(w io.Writer, edge *ChannelEdgePolicy,
|
func serializeChanEdgePolicy(w io.Writer, edge *ChannelEdgePolicy,
|
||||||
to []byte) error {
|
to []byte) error {
|
||||||
|
|
||||||
|
@ -1,24 +1,13 @@
|
|||||||
package migration_01_to_11
|
package migration_01_to_11
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"crypto/sha256"
|
|
||||||
"fmt"
|
|
||||||
"image/color"
|
"image/color"
|
||||||
"math"
|
|
||||||
"math/big"
|
"math/big"
|
||||||
prand "math/rand"
|
prand "math/rand"
|
||||||
"net"
|
"net"
|
||||||
"reflect"
|
|
||||||
"runtime"
|
|
||||||
"testing"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/btcsuite/btcd/btcec"
|
"github.com/btcsuite/btcd/btcec"
|
||||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
|
||||||
"github.com/btcsuite/btcd/wire"
|
|
||||||
"github.com/coreos/bbolt"
|
|
||||||
"github.com/davecgh/go-spew/spew"
|
|
||||||
"github.com/lightningnetwork/lnd/lnwire"
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -66,3132 +55,3 @@ func createTestVertex(db *DB) (*LightningNode, error) {
|
|||||||
|
|
||||||
return createLightningNode(db, priv)
|
return createLightningNode(db, priv)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNodeInsertionAndDeletion(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
db, cleanUp, err := makeTestDB()
|
|
||||||
defer cleanUp()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test database: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
graph := db.ChannelGraph()
|
|
||||||
|
|
||||||
// We'd like to test basic insertion/deletion for vertexes from the
|
|
||||||
// graph, so we'll create a test vertex to start with.
|
|
||||||
_, testPub := btcec.PrivKeyFromBytes(btcec.S256(), key[:])
|
|
||||||
node := &LightningNode{
|
|
||||||
HaveNodeAnnouncement: true,
|
|
||||||
AuthSigBytes: testSig.Serialize(),
|
|
||||||
LastUpdate: time.Unix(1232342, 0),
|
|
||||||
Color: color.RGBA{1, 2, 3, 0},
|
|
||||||
Alias: "kek",
|
|
||||||
Features: testFeatures,
|
|
||||||
Addresses: testAddrs,
|
|
||||||
ExtraOpaqueData: []byte("extra new data"),
|
|
||||||
db: db,
|
|
||||||
}
|
|
||||||
copy(node.PubKeyBytes[:], testPub.SerializeCompressed())
|
|
||||||
|
|
||||||
// First, insert the node into the graph DB. This should succeed
|
|
||||||
// without any errors.
|
|
||||||
if err := graph.AddLightningNode(node); err != nil {
|
|
||||||
t.Fatalf("unable to add node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Next, fetch the node from the database to ensure everything was
|
|
||||||
// serialized properly.
|
|
||||||
dbNode, err := graph.FetchLightningNode(testPub)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to locate node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, exists, err := graph.HasLightningNode(dbNode.PubKeyBytes); err != nil {
|
|
||||||
t.Fatalf("unable to query for node: %v", err)
|
|
||||||
} else if !exists {
|
|
||||||
t.Fatalf("node should be found but wasn't")
|
|
||||||
}
|
|
||||||
|
|
||||||
// The two nodes should match exactly!
|
|
||||||
if err := compareNodes(node, dbNode); err != nil {
|
|
||||||
t.Fatalf("nodes don't match: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Next, delete the node from the graph, this should purge all data
|
|
||||||
// related to the node.
|
|
||||||
if err := graph.DeleteLightningNode(testPub); err != nil {
|
|
||||||
t.Fatalf("unable to delete node; %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Finally, attempt to fetch the node again. This should fail as the
|
|
||||||
// node should have been deleted from the database.
|
|
||||||
_, err = graph.FetchLightningNode(testPub)
|
|
||||||
if err != ErrGraphNodeNotFound {
|
|
||||||
t.Fatalf("fetch after delete should fail!")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestPartialNode checks that we can add and retrieve a LightningNode where
|
|
||||||
// where only the pubkey is known to the database.
|
|
||||||
func TestPartialNode(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
db, cleanUp, err := makeTestDB()
|
|
||||||
defer cleanUp()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test database: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
graph := db.ChannelGraph()
|
|
||||||
|
|
||||||
// We want to be able to insert nodes into the graph that only has the
|
|
||||||
// PubKey set.
|
|
||||||
_, testPub := btcec.PrivKeyFromBytes(btcec.S256(), key[:])
|
|
||||||
node := &LightningNode{
|
|
||||||
HaveNodeAnnouncement: false,
|
|
||||||
}
|
|
||||||
copy(node.PubKeyBytes[:], testPub.SerializeCompressed())
|
|
||||||
|
|
||||||
if err := graph.AddLightningNode(node); err != nil {
|
|
||||||
t.Fatalf("unable to add node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Next, fetch the node from the database to ensure everything was
|
|
||||||
// serialized properly.
|
|
||||||
dbNode, err := graph.FetchLightningNode(testPub)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to locate node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, exists, err := graph.HasLightningNode(dbNode.PubKeyBytes); err != nil {
|
|
||||||
t.Fatalf("unable to query for node: %v", err)
|
|
||||||
} else if !exists {
|
|
||||||
t.Fatalf("node should be found but wasn't")
|
|
||||||
}
|
|
||||||
|
|
||||||
// The two nodes should match exactly! (with default values for
|
|
||||||
// LastUpdate and db set to satisfy compareNodes())
|
|
||||||
node = &LightningNode{
|
|
||||||
HaveNodeAnnouncement: false,
|
|
||||||
LastUpdate: time.Unix(0, 0),
|
|
||||||
db: db,
|
|
||||||
}
|
|
||||||
copy(node.PubKeyBytes[:], testPub.SerializeCompressed())
|
|
||||||
|
|
||||||
if err := compareNodes(node, dbNode); err != nil {
|
|
||||||
t.Fatalf("nodes don't match: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Next, delete the node from the graph, this should purge all data
|
|
||||||
// related to the node.
|
|
||||||
if err := graph.DeleteLightningNode(testPub); err != nil {
|
|
||||||
t.Fatalf("unable to delete node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Finally, attempt to fetch the node again. This should fail as the
|
|
||||||
// node should have been deleted from the database.
|
|
||||||
_, err = graph.FetchLightningNode(testPub)
|
|
||||||
if err != ErrGraphNodeNotFound {
|
|
||||||
t.Fatalf("fetch after delete should fail!")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAliasLookup(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
db, cleanUp, err := makeTestDB()
|
|
||||||
defer cleanUp()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test database: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
graph := db.ChannelGraph()
|
|
||||||
|
|
||||||
// We'd like to test the alias index within the database, so first
|
|
||||||
// create a new test node.
|
|
||||||
testNode, err := createTestVertex(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create test node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add the node to the graph's database, this should also insert an
|
|
||||||
// entry into the alias index for this node.
|
|
||||||
if err := graph.AddLightningNode(testNode); err != nil {
|
|
||||||
t.Fatalf("unable to add node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Next, attempt to lookup the alias. The alias should exactly match
|
|
||||||
// the one which the test node was assigned.
|
|
||||||
nodePub, err := testNode.PubKey()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to generate pubkey: %v", err)
|
|
||||||
}
|
|
||||||
dbAlias, err := graph.LookupAlias(nodePub)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to find alias: %v", err)
|
|
||||||
}
|
|
||||||
if dbAlias != testNode.Alias {
|
|
||||||
t.Fatalf("aliases don't match, expected %v got %v",
|
|
||||||
testNode.Alias, dbAlias)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure that looking up a non-existent alias results in an error.
|
|
||||||
node, err := createTestVertex(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create test node: %v", err)
|
|
||||||
}
|
|
||||||
nodePub, err = node.PubKey()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to generate pubkey: %v", err)
|
|
||||||
}
|
|
||||||
_, err = graph.LookupAlias(nodePub)
|
|
||||||
if err != ErrNodeAliasNotFound {
|
|
||||||
t.Fatalf("alias lookup should fail for non-existent pubkey")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSourceNode(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
db, cleanUp, err := makeTestDB()
|
|
||||||
defer cleanUp()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test database: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
graph := db.ChannelGraph()
|
|
||||||
|
|
||||||
// We'd like to test the setting/getting of the source node, so we
|
|
||||||
// first create a fake node to use within the test.
|
|
||||||
testNode, err := createTestVertex(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create test node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Attempt to fetch the source node, this should return an error as the
|
|
||||||
// source node hasn't yet been set.
|
|
||||||
if _, err := graph.SourceNode(); err != ErrSourceNodeNotSet {
|
|
||||||
t.Fatalf("source node shouldn't be set in new graph")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the source the source node, this should insert the node into the
|
|
||||||
// database in a special way indicating it's the source node.
|
|
||||||
if err := graph.SetSourceNode(testNode); err != nil {
|
|
||||||
t.Fatalf("unable to set source node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Retrieve the source node from the database, it should exactly match
|
|
||||||
// the one we set above.
|
|
||||||
sourceNode, err := graph.SourceNode()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to fetch source node: %v", err)
|
|
||||||
}
|
|
||||||
if err := compareNodes(testNode, sourceNode); err != nil {
|
|
||||||
t.Fatalf("nodes don't match: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestEdgeInsertionDeletion(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
db, cleanUp, err := makeTestDB()
|
|
||||||
defer cleanUp()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test database: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
graph := db.ChannelGraph()
|
|
||||||
|
|
||||||
// We'd like to test the insertion/deletion of edges, so we create two
|
|
||||||
// vertexes to connect.
|
|
||||||
node1, err := createTestVertex(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create test node: %v", err)
|
|
||||||
}
|
|
||||||
node2, err := createTestVertex(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create test node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// In addition to the fake vertexes we create some fake channel
|
|
||||||
// identifiers.
|
|
||||||
chanID := uint64(prand.Int63())
|
|
||||||
outpoint := wire.OutPoint{
|
|
||||||
Hash: rev,
|
|
||||||
Index: 9,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add the new edge to the database, this should proceed without any
|
|
||||||
// errors.
|
|
||||||
node1Pub, err := node1.PubKey()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to generate node key: %v", err)
|
|
||||||
}
|
|
||||||
node2Pub, err := node2.PubKey()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to generate node key: %v", err)
|
|
||||||
}
|
|
||||||
edgeInfo := ChannelEdgeInfo{
|
|
||||||
ChannelID: chanID,
|
|
||||||
ChainHash: key,
|
|
||||||
AuthProof: &ChannelAuthProof{
|
|
||||||
NodeSig1Bytes: testSig.Serialize(),
|
|
||||||
NodeSig2Bytes: testSig.Serialize(),
|
|
||||||
BitcoinSig1Bytes: testSig.Serialize(),
|
|
||||||
BitcoinSig2Bytes: testSig.Serialize(),
|
|
||||||
},
|
|
||||||
ChannelPoint: outpoint,
|
|
||||||
Capacity: 9000,
|
|
||||||
}
|
|
||||||
copy(edgeInfo.NodeKey1Bytes[:], node1Pub.SerializeCompressed())
|
|
||||||
copy(edgeInfo.NodeKey2Bytes[:], node2Pub.SerializeCompressed())
|
|
||||||
copy(edgeInfo.BitcoinKey1Bytes[:], node1Pub.SerializeCompressed())
|
|
||||||
copy(edgeInfo.BitcoinKey2Bytes[:], node2Pub.SerializeCompressed())
|
|
||||||
|
|
||||||
if err := graph.AddChannelEdge(&edgeInfo); err != nil {
|
|
||||||
t.Fatalf("unable to create channel edge: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure that both policies are returned as unknown (nil).
|
|
||||||
_, e1, e2, err := graph.FetchChannelEdgesByID(chanID)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to fetch channel edge")
|
|
||||||
}
|
|
||||||
if e1 != nil || e2 != nil {
|
|
||||||
t.Fatalf("channel edges not unknown")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Next, attempt to delete the edge from the database, again this
|
|
||||||
// should proceed without any issues.
|
|
||||||
if err := graph.DeleteChannelEdges(chanID); err != nil {
|
|
||||||
t.Fatalf("unable to delete edge: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure that any query attempts to lookup the delete channel edge are
|
|
||||||
// properly deleted.
|
|
||||||
if _, _, _, err := graph.FetchChannelEdgesByOutpoint(&outpoint); err == nil {
|
|
||||||
t.Fatalf("channel edge not deleted")
|
|
||||||
}
|
|
||||||
if _, _, _, err := graph.FetchChannelEdgesByID(chanID); err == nil {
|
|
||||||
t.Fatalf("channel edge not deleted")
|
|
||||||
}
|
|
||||||
isZombie, _, _ := graph.IsZombieEdge(chanID)
|
|
||||||
if !isZombie {
|
|
||||||
t.Fatal("channel edge not marked as zombie")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Finally, attempt to delete a (now) non-existent edge within the
|
|
||||||
// database, this should result in an error.
|
|
||||||
err = graph.DeleteChannelEdges(chanID)
|
|
||||||
if err != ErrEdgeNotFound {
|
|
||||||
t.Fatalf("deleting a non-existent edge should fail!")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func createEdge(height, txIndex uint32, txPosition uint16, outPointIndex uint32,
|
|
||||||
node1, node2 *LightningNode) (ChannelEdgeInfo, lnwire.ShortChannelID) {
|
|
||||||
|
|
||||||
shortChanID := lnwire.ShortChannelID{
|
|
||||||
BlockHeight: height,
|
|
||||||
TxIndex: txIndex,
|
|
||||||
TxPosition: txPosition,
|
|
||||||
}
|
|
||||||
outpoint := wire.OutPoint{
|
|
||||||
Hash: rev,
|
|
||||||
Index: outPointIndex,
|
|
||||||
}
|
|
||||||
|
|
||||||
node1Pub, _ := node1.PubKey()
|
|
||||||
node2Pub, _ := node2.PubKey()
|
|
||||||
edgeInfo := ChannelEdgeInfo{
|
|
||||||
ChannelID: shortChanID.ToUint64(),
|
|
||||||
ChainHash: key,
|
|
||||||
AuthProof: &ChannelAuthProof{
|
|
||||||
NodeSig1Bytes: testSig.Serialize(),
|
|
||||||
NodeSig2Bytes: testSig.Serialize(),
|
|
||||||
BitcoinSig1Bytes: testSig.Serialize(),
|
|
||||||
BitcoinSig2Bytes: testSig.Serialize(),
|
|
||||||
},
|
|
||||||
ChannelPoint: outpoint,
|
|
||||||
Capacity: 9000,
|
|
||||||
}
|
|
||||||
|
|
||||||
copy(edgeInfo.NodeKey1Bytes[:], node1Pub.SerializeCompressed())
|
|
||||||
copy(edgeInfo.NodeKey2Bytes[:], node2Pub.SerializeCompressed())
|
|
||||||
copy(edgeInfo.BitcoinKey1Bytes[:], node1Pub.SerializeCompressed())
|
|
||||||
copy(edgeInfo.BitcoinKey2Bytes[:], node2Pub.SerializeCompressed())
|
|
||||||
|
|
||||||
return edgeInfo, shortChanID
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestDisconnectBlockAtHeight checks that the pruned state of the channel
|
|
||||||
// database is what we expect after calling DisconnectBlockAtHeight.
|
|
||||||
func TestDisconnectBlockAtHeight(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
db, cleanUp, err := makeTestDB()
|
|
||||||
defer cleanUp()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test database: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
graph := db.ChannelGraph()
|
|
||||||
sourceNode, err := createTestVertex(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create source node: %v", err)
|
|
||||||
}
|
|
||||||
if err := graph.SetSourceNode(sourceNode); err != nil {
|
|
||||||
t.Fatalf("unable to set source node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// We'd like to test the insertion/deletion of edges, so we create two
|
|
||||||
// vertexes to connect.
|
|
||||||
node1, err := createTestVertex(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create test node: %v", err)
|
|
||||||
}
|
|
||||||
node2, err := createTestVertex(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create test node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// In addition to the fake vertexes we create some fake channel
|
|
||||||
// identifiers.
|
|
||||||
var spendOutputs []*wire.OutPoint
|
|
||||||
var blockHash chainhash.Hash
|
|
||||||
copy(blockHash[:], bytes.Repeat([]byte{1}, 32))
|
|
||||||
|
|
||||||
// Prune the graph a few times to make sure we have entries in the
|
|
||||||
// prune log.
|
|
||||||
_, err = graph.PruneGraph(spendOutputs, &blockHash, 155)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to prune graph: %v", err)
|
|
||||||
}
|
|
||||||
var blockHash2 chainhash.Hash
|
|
||||||
copy(blockHash2[:], bytes.Repeat([]byte{2}, 32))
|
|
||||||
|
|
||||||
_, err = graph.PruneGraph(spendOutputs, &blockHash2, 156)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to prune graph: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// We'll create 3 almost identical edges, so first create a helper
|
|
||||||
// method containing all logic for doing so.
|
|
||||||
|
|
||||||
// Create an edge which has its block height at 156.
|
|
||||||
height := uint32(156)
|
|
||||||
edgeInfo, _ := createEdge(height, 0, 0, 0, node1, node2)
|
|
||||||
|
|
||||||
// Create an edge with block height 157. We give it
|
|
||||||
// maximum values for tx index and position, to make
|
|
||||||
// sure our database range scan get edges from the
|
|
||||||
// entire range.
|
|
||||||
edgeInfo2, _ := createEdge(
|
|
||||||
height+1, math.MaxUint32&0x00ffffff, math.MaxUint16, 1,
|
|
||||||
node1, node2,
|
|
||||||
)
|
|
||||||
|
|
||||||
// Create a third edge, this with a block height of 155.
|
|
||||||
edgeInfo3, _ := createEdge(height-1, 0, 0, 2, node1, node2)
|
|
||||||
|
|
||||||
// Now add all these new edges to the database.
|
|
||||||
if err := graph.AddChannelEdge(&edgeInfo); err != nil {
|
|
||||||
t.Fatalf("unable to create channel edge: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := graph.AddChannelEdge(&edgeInfo2); err != nil {
|
|
||||||
t.Fatalf("unable to create channel edge: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := graph.AddChannelEdge(&edgeInfo3); err != nil {
|
|
||||||
t.Fatalf("unable to create channel edge: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Call DisconnectBlockAtHeight, which should prune every channel
|
|
||||||
// that has a funding height of 'height' or greater.
|
|
||||||
removed, err := graph.DisconnectBlockAtHeight(uint32(height))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to prune %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// The two edges should have been removed.
|
|
||||||
if len(removed) != 2 {
|
|
||||||
t.Fatalf("expected two edges to be removed from graph, "+
|
|
||||||
"only %d were", len(removed))
|
|
||||||
}
|
|
||||||
if removed[0].ChannelID != edgeInfo.ChannelID {
|
|
||||||
t.Fatalf("expected edge to be removed from graph")
|
|
||||||
}
|
|
||||||
if removed[1].ChannelID != edgeInfo2.ChannelID {
|
|
||||||
t.Fatalf("expected edge to be removed from graph")
|
|
||||||
}
|
|
||||||
|
|
||||||
// The two first edges should be removed from the db.
|
|
||||||
_, _, has, isZombie, err := graph.HasChannelEdge(edgeInfo.ChannelID)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to query for edge: %v", err)
|
|
||||||
}
|
|
||||||
if has {
|
|
||||||
t.Fatalf("edge1 was not pruned from the graph")
|
|
||||||
}
|
|
||||||
if isZombie {
|
|
||||||
t.Fatal("reorged edge1 should not be marked as zombie")
|
|
||||||
}
|
|
||||||
_, _, has, isZombie, err = graph.HasChannelEdge(edgeInfo2.ChannelID)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to query for edge: %v", err)
|
|
||||||
}
|
|
||||||
if has {
|
|
||||||
t.Fatalf("edge2 was not pruned from the graph")
|
|
||||||
}
|
|
||||||
if isZombie {
|
|
||||||
t.Fatal("reorged edge2 should not be marked as zombie")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Edge 3 should not be removed.
|
|
||||||
_, _, has, isZombie, err = graph.HasChannelEdge(edgeInfo3.ChannelID)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to query for edge: %v", err)
|
|
||||||
}
|
|
||||||
if !has {
|
|
||||||
t.Fatalf("edge3 was pruned from the graph")
|
|
||||||
}
|
|
||||||
if isZombie {
|
|
||||||
t.Fatal("edge3 was marked as zombie")
|
|
||||||
}
|
|
||||||
|
|
||||||
// PruneTip should be set to the blockHash we specified for the block
|
|
||||||
// at height 155.
|
|
||||||
hash, h, err := graph.PruneTip()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to get prune tip: %v", err)
|
|
||||||
}
|
|
||||||
if !blockHash.IsEqual(hash) {
|
|
||||||
t.Fatalf("expected best block to be %x, was %x", blockHash, hash)
|
|
||||||
}
|
|
||||||
if h != height-1 {
|
|
||||||
t.Fatalf("expected best block height to be %d, was %d", height-1, h)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func assertEdgeInfoEqual(t *testing.T, e1 *ChannelEdgeInfo,
|
|
||||||
e2 *ChannelEdgeInfo) {
|
|
||||||
|
|
||||||
if e1.ChannelID != e2.ChannelID {
|
|
||||||
t.Fatalf("chan id's don't match: %v vs %v", e1.ChannelID,
|
|
||||||
e2.ChannelID)
|
|
||||||
}
|
|
||||||
|
|
||||||
if e1.ChainHash != e2.ChainHash {
|
|
||||||
t.Fatalf("chain hashes don't match: %v vs %v", e1.ChainHash,
|
|
||||||
e2.ChainHash)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !bytes.Equal(e1.NodeKey1Bytes[:], e2.NodeKey1Bytes[:]) {
|
|
||||||
t.Fatalf("nodekey1 doesn't match")
|
|
||||||
}
|
|
||||||
if !bytes.Equal(e1.NodeKey2Bytes[:], e2.NodeKey2Bytes[:]) {
|
|
||||||
t.Fatalf("nodekey2 doesn't match")
|
|
||||||
}
|
|
||||||
if !bytes.Equal(e1.BitcoinKey1Bytes[:], e2.BitcoinKey1Bytes[:]) {
|
|
||||||
t.Fatalf("bitcoinkey1 doesn't match")
|
|
||||||
}
|
|
||||||
if !bytes.Equal(e1.BitcoinKey2Bytes[:], e2.BitcoinKey2Bytes[:]) {
|
|
||||||
t.Fatalf("bitcoinkey2 doesn't match")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !bytes.Equal(e1.Features, e2.Features) {
|
|
||||||
t.Fatalf("features doesn't match: %x vs %x", e1.Features,
|
|
||||||
e2.Features)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !bytes.Equal(e1.AuthProof.NodeSig1Bytes, e2.AuthProof.NodeSig1Bytes) {
|
|
||||||
t.Fatalf("nodesig1 doesn't match: %v vs %v",
|
|
||||||
spew.Sdump(e1.AuthProof.NodeSig1Bytes),
|
|
||||||
spew.Sdump(e2.AuthProof.NodeSig1Bytes))
|
|
||||||
}
|
|
||||||
if !bytes.Equal(e1.AuthProof.NodeSig2Bytes, e2.AuthProof.NodeSig2Bytes) {
|
|
||||||
t.Fatalf("nodesig2 doesn't match")
|
|
||||||
}
|
|
||||||
if !bytes.Equal(e1.AuthProof.BitcoinSig1Bytes, e2.AuthProof.BitcoinSig1Bytes) {
|
|
||||||
t.Fatalf("bitcoinsig1 doesn't match")
|
|
||||||
}
|
|
||||||
if !bytes.Equal(e1.AuthProof.BitcoinSig2Bytes, e2.AuthProof.BitcoinSig2Bytes) {
|
|
||||||
t.Fatalf("bitcoinsig2 doesn't match")
|
|
||||||
}
|
|
||||||
|
|
||||||
if e1.ChannelPoint != e2.ChannelPoint {
|
|
||||||
t.Fatalf("channel point match: %v vs %v", e1.ChannelPoint,
|
|
||||||
e2.ChannelPoint)
|
|
||||||
}
|
|
||||||
|
|
||||||
if e1.Capacity != e2.Capacity {
|
|
||||||
t.Fatalf("capacity doesn't match: %v vs %v", e1.Capacity,
|
|
||||||
e2.Capacity)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !bytes.Equal(e1.ExtraOpaqueData, e2.ExtraOpaqueData) {
|
|
||||||
t.Fatalf("extra data doesn't match: %v vs %v",
|
|
||||||
e2.ExtraOpaqueData, e2.ExtraOpaqueData)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func createChannelEdge(db *DB, node1, node2 *LightningNode) (*ChannelEdgeInfo,
|
|
||||||
*ChannelEdgePolicy, *ChannelEdgePolicy) {
|
|
||||||
|
|
||||||
var (
|
|
||||||
firstNode *LightningNode
|
|
||||||
secondNode *LightningNode
|
|
||||||
)
|
|
||||||
if bytes.Compare(node1.PubKeyBytes[:], node2.PubKeyBytes[:]) == -1 {
|
|
||||||
firstNode = node1
|
|
||||||
secondNode = node2
|
|
||||||
} else {
|
|
||||||
firstNode = node2
|
|
||||||
secondNode = node1
|
|
||||||
}
|
|
||||||
|
|
||||||
// In addition to the fake vertexes we create some fake channel
|
|
||||||
// identifiers.
|
|
||||||
chanID := uint64(prand.Int63())
|
|
||||||
outpoint := wire.OutPoint{
|
|
||||||
Hash: rev,
|
|
||||||
Index: 9,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add the new edge to the database, this should proceed without any
|
|
||||||
// errors.
|
|
||||||
edgeInfo := &ChannelEdgeInfo{
|
|
||||||
ChannelID: chanID,
|
|
||||||
ChainHash: key,
|
|
||||||
AuthProof: &ChannelAuthProof{
|
|
||||||
NodeSig1Bytes: testSig.Serialize(),
|
|
||||||
NodeSig2Bytes: testSig.Serialize(),
|
|
||||||
BitcoinSig1Bytes: testSig.Serialize(),
|
|
||||||
BitcoinSig2Bytes: testSig.Serialize(),
|
|
||||||
},
|
|
||||||
ChannelPoint: outpoint,
|
|
||||||
Capacity: 1000,
|
|
||||||
ExtraOpaqueData: []byte("new unknown feature"),
|
|
||||||
}
|
|
||||||
copy(edgeInfo.NodeKey1Bytes[:], firstNode.PubKeyBytes[:])
|
|
||||||
copy(edgeInfo.NodeKey2Bytes[:], secondNode.PubKeyBytes[:])
|
|
||||||
copy(edgeInfo.BitcoinKey1Bytes[:], firstNode.PubKeyBytes[:])
|
|
||||||
copy(edgeInfo.BitcoinKey2Bytes[:], secondNode.PubKeyBytes[:])
|
|
||||||
|
|
||||||
edge1 := &ChannelEdgePolicy{
|
|
||||||
SigBytes: testSig.Serialize(),
|
|
||||||
ChannelID: chanID,
|
|
||||||
LastUpdate: time.Unix(433453, 0),
|
|
||||||
MessageFlags: 1,
|
|
||||||
ChannelFlags: 0,
|
|
||||||
TimeLockDelta: 99,
|
|
||||||
MinHTLC: 2342135,
|
|
||||||
MaxHTLC: 13928598,
|
|
||||||
FeeBaseMSat: 4352345,
|
|
||||||
FeeProportionalMillionths: 3452352,
|
|
||||||
Node: secondNode,
|
|
||||||
ExtraOpaqueData: []byte("new unknown feature2"),
|
|
||||||
db: db,
|
|
||||||
}
|
|
||||||
edge2 := &ChannelEdgePolicy{
|
|
||||||
SigBytes: testSig.Serialize(),
|
|
||||||
ChannelID: chanID,
|
|
||||||
LastUpdate: time.Unix(124234, 0),
|
|
||||||
MessageFlags: 1,
|
|
||||||
ChannelFlags: 1,
|
|
||||||
TimeLockDelta: 99,
|
|
||||||
MinHTLC: 2342135,
|
|
||||||
MaxHTLC: 13928598,
|
|
||||||
FeeBaseMSat: 4352345,
|
|
||||||
FeeProportionalMillionths: 90392423,
|
|
||||||
Node: firstNode,
|
|
||||||
ExtraOpaqueData: []byte("new unknown feature1"),
|
|
||||||
db: db,
|
|
||||||
}
|
|
||||||
|
|
||||||
return edgeInfo, edge1, edge2
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestEdgeInfoUpdates(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
db, cleanUp, err := makeTestDB()
|
|
||||||
defer cleanUp()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test database: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
graph := db.ChannelGraph()
|
|
||||||
|
|
||||||
// We'd like to test the update of edges inserted into the database, so
|
|
||||||
// we create two vertexes to connect.
|
|
||||||
node1, err := createTestVertex(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create test node: %v", err)
|
|
||||||
}
|
|
||||||
if err := graph.AddLightningNode(node1); err != nil {
|
|
||||||
t.Fatalf("unable to add node: %v", err)
|
|
||||||
}
|
|
||||||
node2, err := createTestVertex(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create test node: %v", err)
|
|
||||||
}
|
|
||||||
if err := graph.AddLightningNode(node2); err != nil {
|
|
||||||
t.Fatalf("unable to add node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create an edge and add it to the db.
|
|
||||||
edgeInfo, edge1, edge2 := createChannelEdge(db, node1, node2)
|
|
||||||
|
|
||||||
// Make sure inserting the policy at this point, before the edge info
|
|
||||||
// is added, will fail.
|
|
||||||
if err := graph.UpdateEdgePolicy(edge1); err != ErrEdgeNotFound {
|
|
||||||
t.Fatalf("expected ErrEdgeNotFound, got: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add the edge info.
|
|
||||||
if err := graph.AddChannelEdge(edgeInfo); err != nil {
|
|
||||||
t.Fatalf("unable to create channel edge: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
chanID := edgeInfo.ChannelID
|
|
||||||
outpoint := edgeInfo.ChannelPoint
|
|
||||||
|
|
||||||
// Next, insert both edge policies into the database, they should both
|
|
||||||
// be inserted without any issues.
|
|
||||||
if err := graph.UpdateEdgePolicy(edge1); err != nil {
|
|
||||||
t.Fatalf("unable to update edge: %v", err)
|
|
||||||
}
|
|
||||||
if err := graph.UpdateEdgePolicy(edge2); err != nil {
|
|
||||||
t.Fatalf("unable to update edge: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for existence of the edge within the database, it should be
|
|
||||||
// found.
|
|
||||||
_, _, found, isZombie, err := graph.HasChannelEdge(chanID)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to query for edge: %v", err)
|
|
||||||
}
|
|
||||||
if !found {
|
|
||||||
t.Fatalf("graph should have of inserted edge")
|
|
||||||
}
|
|
||||||
if isZombie {
|
|
||||||
t.Fatal("live edge should not be marked as zombie")
|
|
||||||
}
|
|
||||||
|
|
||||||
// We should also be able to retrieve the channelID only knowing the
|
|
||||||
// channel point of the channel.
|
|
||||||
dbChanID, err := graph.ChannelID(&outpoint)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to retrieve channel ID: %v", err)
|
|
||||||
}
|
|
||||||
if dbChanID != chanID {
|
|
||||||
t.Fatalf("chan ID's mismatch, expected %v got %v", dbChanID,
|
|
||||||
chanID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// With the edges inserted, perform some queries to ensure that they've
|
|
||||||
// been inserted properly.
|
|
||||||
dbEdgeInfo, dbEdge1, dbEdge2, err := graph.FetchChannelEdgesByID(chanID)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to fetch channel by ID: %v", err)
|
|
||||||
}
|
|
||||||
if err := compareEdgePolicies(dbEdge1, edge1); err != nil {
|
|
||||||
t.Fatalf("edge doesn't match: %v", err)
|
|
||||||
}
|
|
||||||
if err := compareEdgePolicies(dbEdge2, edge2); err != nil {
|
|
||||||
t.Fatalf("edge doesn't match: %v", err)
|
|
||||||
}
|
|
||||||
assertEdgeInfoEqual(t, dbEdgeInfo, edgeInfo)
|
|
||||||
|
|
||||||
// Next, attempt to query the channel edges according to the outpoint
|
|
||||||
// of the channel.
|
|
||||||
dbEdgeInfo, dbEdge1, dbEdge2, err = graph.FetchChannelEdgesByOutpoint(&outpoint)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to fetch channel by ID: %v", err)
|
|
||||||
}
|
|
||||||
if err := compareEdgePolicies(dbEdge1, edge1); err != nil {
|
|
||||||
t.Fatalf("edge doesn't match: %v", err)
|
|
||||||
}
|
|
||||||
if err := compareEdgePolicies(dbEdge2, edge2); err != nil {
|
|
||||||
t.Fatalf("edge doesn't match: %v", err)
|
|
||||||
}
|
|
||||||
assertEdgeInfoEqual(t, dbEdgeInfo, edgeInfo)
|
|
||||||
}
|
|
||||||
|
|
||||||
func randEdgePolicy(chanID uint64, op wire.OutPoint, db *DB) *ChannelEdgePolicy {
|
|
||||||
update := prand.Int63()
|
|
||||||
|
|
||||||
return newEdgePolicy(chanID, op, db, update)
|
|
||||||
}
|
|
||||||
|
|
||||||
func newEdgePolicy(chanID uint64, op wire.OutPoint, db *DB,
|
|
||||||
updateTime int64) *ChannelEdgePolicy {
|
|
||||||
|
|
||||||
return &ChannelEdgePolicy{
|
|
||||||
ChannelID: chanID,
|
|
||||||
LastUpdate: time.Unix(updateTime, 0),
|
|
||||||
MessageFlags: 1,
|
|
||||||
ChannelFlags: 0,
|
|
||||||
TimeLockDelta: uint16(prand.Int63()),
|
|
||||||
MinHTLC: lnwire.MilliSatoshi(prand.Int63()),
|
|
||||||
MaxHTLC: lnwire.MilliSatoshi(prand.Int63()),
|
|
||||||
FeeBaseMSat: lnwire.MilliSatoshi(prand.Int63()),
|
|
||||||
FeeProportionalMillionths: lnwire.MilliSatoshi(prand.Int63()),
|
|
||||||
db: db,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGraphTraversal(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
db, cleanUp, err := makeTestDB()
|
|
||||||
defer cleanUp()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test database: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
graph := db.ChannelGraph()
|
|
||||||
|
|
||||||
// We'd like to test some of the graph traversal capabilities within
|
|
||||||
// the DB, so we'll create a series of fake nodes to insert into the
|
|
||||||
// graph.
|
|
||||||
const numNodes = 20
|
|
||||||
nodes := make([]*LightningNode, numNodes)
|
|
||||||
nodeIndex := map[string]struct{}{}
|
|
||||||
for i := 0; i < numNodes; i++ {
|
|
||||||
node, err := createTestVertex(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
nodes[i] = node
|
|
||||||
nodeIndex[node.Alias] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add each of the nodes into the graph, they should be inserted
|
|
||||||
// without error.
|
|
||||||
for _, node := range nodes {
|
|
||||||
if err := graph.AddLightningNode(node); err != nil {
|
|
||||||
t.Fatalf("unable to add node: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Iterate over each node as returned by the graph, if all nodes are
|
|
||||||
// reached, then the map created above should be empty.
|
|
||||||
err = graph.ForEachNode(nil, func(_ *bbolt.Tx, node *LightningNode) error {
|
|
||||||
delete(nodeIndex, node.Alias)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("for each failure: %v", err)
|
|
||||||
}
|
|
||||||
if len(nodeIndex) != 0 {
|
|
||||||
t.Fatalf("all nodes not reached within ForEach")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Determine which node is "smaller", we'll need this in order to
|
|
||||||
// properly create the edges for the graph.
|
|
||||||
var firstNode, secondNode *LightningNode
|
|
||||||
if bytes.Compare(nodes[0].PubKeyBytes[:], nodes[1].PubKeyBytes[:]) == -1 {
|
|
||||||
firstNode = nodes[0]
|
|
||||||
secondNode = nodes[1]
|
|
||||||
} else {
|
|
||||||
firstNode = nodes[0]
|
|
||||||
secondNode = nodes[1]
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create 5 channels between the first two nodes we generated above.
|
|
||||||
const numChannels = 5
|
|
||||||
chanIndex := map[uint64]struct{}{}
|
|
||||||
for i := 0; i < numChannels; i++ {
|
|
||||||
txHash := sha256.Sum256([]byte{byte(i)})
|
|
||||||
chanID := uint64(i + 1)
|
|
||||||
op := wire.OutPoint{
|
|
||||||
Hash: txHash,
|
|
||||||
Index: 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
edgeInfo := ChannelEdgeInfo{
|
|
||||||
ChannelID: chanID,
|
|
||||||
ChainHash: key,
|
|
||||||
AuthProof: &ChannelAuthProof{
|
|
||||||
NodeSig1Bytes: testSig.Serialize(),
|
|
||||||
NodeSig2Bytes: testSig.Serialize(),
|
|
||||||
BitcoinSig1Bytes: testSig.Serialize(),
|
|
||||||
BitcoinSig2Bytes: testSig.Serialize(),
|
|
||||||
},
|
|
||||||
ChannelPoint: op,
|
|
||||||
Capacity: 1000,
|
|
||||||
}
|
|
||||||
copy(edgeInfo.NodeKey1Bytes[:], nodes[0].PubKeyBytes[:])
|
|
||||||
copy(edgeInfo.NodeKey2Bytes[:], nodes[1].PubKeyBytes[:])
|
|
||||||
copy(edgeInfo.BitcoinKey1Bytes[:], nodes[0].PubKeyBytes[:])
|
|
||||||
copy(edgeInfo.BitcoinKey2Bytes[:], nodes[1].PubKeyBytes[:])
|
|
||||||
err := graph.AddChannelEdge(&edgeInfo)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to add node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create and add an edge with random data that points from
|
|
||||||
// node1 -> node2.
|
|
||||||
edge := randEdgePolicy(chanID, op, db)
|
|
||||||
edge.ChannelFlags = 0
|
|
||||||
edge.Node = secondNode
|
|
||||||
edge.SigBytes = testSig.Serialize()
|
|
||||||
if err := graph.UpdateEdgePolicy(edge); err != nil {
|
|
||||||
t.Fatalf("unable to update edge: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create another random edge that points from node2 -> node1
|
|
||||||
// this time.
|
|
||||||
edge = randEdgePolicy(chanID, op, db)
|
|
||||||
edge.ChannelFlags = 1
|
|
||||||
edge.Node = firstNode
|
|
||||||
edge.SigBytes = testSig.Serialize()
|
|
||||||
if err := graph.UpdateEdgePolicy(edge); err != nil {
|
|
||||||
t.Fatalf("unable to update edge: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
chanIndex[chanID] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Iterate through all the known channels within the graph DB, once
|
|
||||||
// again if the map is empty that indicates that all edges have
|
|
||||||
// properly been reached.
|
|
||||||
err = graph.ForEachChannel(func(ei *ChannelEdgeInfo, _ *ChannelEdgePolicy,
|
|
||||||
_ *ChannelEdgePolicy) error {
|
|
||||||
|
|
||||||
delete(chanIndex, ei.ChannelID)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("for each failure: %v", err)
|
|
||||||
}
|
|
||||||
if len(chanIndex) != 0 {
|
|
||||||
t.Fatalf("all edges not reached within ForEach")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Finally, we want to test the ability to iterate over all the
|
|
||||||
// outgoing channels for a particular node.
|
|
||||||
numNodeChans := 0
|
|
||||||
err = firstNode.ForEachChannel(nil, func(_ *bbolt.Tx, _ *ChannelEdgeInfo,
|
|
||||||
outEdge, inEdge *ChannelEdgePolicy) error {
|
|
||||||
|
|
||||||
// All channels between first and second node should have fully
|
|
||||||
// (both sides) specified policies.
|
|
||||||
if inEdge == nil || outEdge == nil {
|
|
||||||
return fmt.Errorf("channel policy not present")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Each should indicate that it's outgoing (pointed
|
|
||||||
// towards the second node).
|
|
||||||
if !bytes.Equal(outEdge.Node.PubKeyBytes[:], secondNode.PubKeyBytes[:]) {
|
|
||||||
return fmt.Errorf("wrong outgoing edge")
|
|
||||||
}
|
|
||||||
|
|
||||||
// The incoming edge should also indicate that it's pointing to
|
|
||||||
// the origin node.
|
|
||||||
if !bytes.Equal(inEdge.Node.PubKeyBytes[:], firstNode.PubKeyBytes[:]) {
|
|
||||||
return fmt.Errorf("wrong outgoing edge")
|
|
||||||
}
|
|
||||||
|
|
||||||
numNodeChans++
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("for each failure: %v", err)
|
|
||||||
}
|
|
||||||
if numNodeChans != numChannels {
|
|
||||||
t.Fatalf("all edges for node not reached within ForEach: "+
|
|
||||||
"expected %v, got %v", numChannels, numNodeChans)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func assertPruneTip(t *testing.T, graph *ChannelGraph, blockHash *chainhash.Hash,
|
|
||||||
blockHeight uint32) {
|
|
||||||
|
|
||||||
pruneHash, pruneHeight, err := graph.PruneTip()
|
|
||||||
if err != nil {
|
|
||||||
_, _, line, _ := runtime.Caller(1)
|
|
||||||
t.Fatalf("line %v: unable to fetch prune tip: %v", line, err)
|
|
||||||
}
|
|
||||||
if !bytes.Equal(blockHash[:], pruneHash[:]) {
|
|
||||||
_, _, line, _ := runtime.Caller(1)
|
|
||||||
t.Fatalf("line: %v, prune tips don't match, expected %x got %x",
|
|
||||||
line, blockHash, pruneHash)
|
|
||||||
}
|
|
||||||
if pruneHeight != blockHeight {
|
|
||||||
_, _, line, _ := runtime.Caller(1)
|
|
||||||
t.Fatalf("line %v: prune heights don't match, expected %v "+
|
|
||||||
"got %v", line, blockHeight, pruneHeight)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func assertNumChans(t *testing.T, graph *ChannelGraph, n int) {
|
|
||||||
numChans := 0
|
|
||||||
if err := graph.ForEachChannel(func(*ChannelEdgeInfo, *ChannelEdgePolicy,
|
|
||||||
*ChannelEdgePolicy) error {
|
|
||||||
|
|
||||||
numChans++
|
|
||||||
return nil
|
|
||||||
}); err != nil {
|
|
||||||
_, _, line, _ := runtime.Caller(1)
|
|
||||||
t.Fatalf("line %v: unable to scan channels: %v", line, err)
|
|
||||||
}
|
|
||||||
if numChans != n {
|
|
||||||
_, _, line, _ := runtime.Caller(1)
|
|
||||||
t.Fatalf("line %v: expected %v chans instead have %v", line,
|
|
||||||
n, numChans)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func assertNumNodes(t *testing.T, graph *ChannelGraph, n int) {
|
|
||||||
numNodes := 0
|
|
||||||
err := graph.ForEachNode(nil, func(_ *bbolt.Tx, _ *LightningNode) error {
|
|
||||||
numNodes++
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
_, _, line, _ := runtime.Caller(1)
|
|
||||||
t.Fatalf("line %v: unable to scan nodes: %v", line, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if numNodes != n {
|
|
||||||
_, _, line, _ := runtime.Caller(1)
|
|
||||||
t.Fatalf("line %v: expected %v nodes, got %v", line, n, numNodes)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func assertChanViewEqual(t *testing.T, a []EdgePoint, b []EdgePoint) {
|
|
||||||
if len(a) != len(b) {
|
|
||||||
_, _, line, _ := runtime.Caller(1)
|
|
||||||
t.Fatalf("line %v: chan views don't match", line)
|
|
||||||
}
|
|
||||||
|
|
||||||
chanViewSet := make(map[wire.OutPoint]struct{})
|
|
||||||
for _, op := range a {
|
|
||||||
chanViewSet[op.OutPoint] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, op := range b {
|
|
||||||
if _, ok := chanViewSet[op.OutPoint]; !ok {
|
|
||||||
_, _, line, _ := runtime.Caller(1)
|
|
||||||
t.Fatalf("line %v: chanPoint(%v) not found in first "+
|
|
||||||
"view", line, op)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func assertChanViewEqualChanPoints(t *testing.T, a []EdgePoint, b []*wire.OutPoint) {
|
|
||||||
if len(a) != len(b) {
|
|
||||||
_, _, line, _ := runtime.Caller(1)
|
|
||||||
t.Fatalf("line %v: chan views don't match", line)
|
|
||||||
}
|
|
||||||
|
|
||||||
chanViewSet := make(map[wire.OutPoint]struct{})
|
|
||||||
for _, op := range a {
|
|
||||||
chanViewSet[op.OutPoint] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, op := range b {
|
|
||||||
if _, ok := chanViewSet[*op]; !ok {
|
|
||||||
_, _, line, _ := runtime.Caller(1)
|
|
||||||
t.Fatalf("line %v: chanPoint(%v) not found in first "+
|
|
||||||
"view", line, op)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGraphPruning(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
db, cleanUp, err := makeTestDB()
|
|
||||||
defer cleanUp()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test database: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
graph := db.ChannelGraph()
|
|
||||||
sourceNode, err := createTestVertex(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create source node: %v", err)
|
|
||||||
}
|
|
||||||
if err := graph.SetSourceNode(sourceNode); err != nil {
|
|
||||||
t.Fatalf("unable to set source node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// As initial set up for the test, we'll create a graph with 5 vertexes
|
|
||||||
// and enough edges to create a fully connected graph. The graph will
|
|
||||||
// be rather simple, representing a straight line.
|
|
||||||
const numNodes = 5
|
|
||||||
graphNodes := make([]*LightningNode, numNodes)
|
|
||||||
for i := 0; i < numNodes; i++ {
|
|
||||||
node, err := createTestVertex(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := graph.AddLightningNode(node); err != nil {
|
|
||||||
t.Fatalf("unable to add node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
graphNodes[i] = node
|
|
||||||
}
|
|
||||||
|
|
||||||
// With the vertexes created, we'll next create a series of channels
|
|
||||||
// between them.
|
|
||||||
channelPoints := make([]*wire.OutPoint, 0, numNodes-1)
|
|
||||||
edgePoints := make([]EdgePoint, 0, numNodes-1)
|
|
||||||
for i := 0; i < numNodes-1; i++ {
|
|
||||||
txHash := sha256.Sum256([]byte{byte(i)})
|
|
||||||
chanID := uint64(i + 1)
|
|
||||||
op := wire.OutPoint{
|
|
||||||
Hash: txHash,
|
|
||||||
Index: 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
channelPoints = append(channelPoints, &op)
|
|
||||||
|
|
||||||
edgeInfo := ChannelEdgeInfo{
|
|
||||||
ChannelID: chanID,
|
|
||||||
ChainHash: key,
|
|
||||||
AuthProof: &ChannelAuthProof{
|
|
||||||
NodeSig1Bytes: testSig.Serialize(),
|
|
||||||
NodeSig2Bytes: testSig.Serialize(),
|
|
||||||
BitcoinSig1Bytes: testSig.Serialize(),
|
|
||||||
BitcoinSig2Bytes: testSig.Serialize(),
|
|
||||||
},
|
|
||||||
ChannelPoint: op,
|
|
||||||
Capacity: 1000,
|
|
||||||
}
|
|
||||||
copy(edgeInfo.NodeKey1Bytes[:], graphNodes[i].PubKeyBytes[:])
|
|
||||||
copy(edgeInfo.NodeKey2Bytes[:], graphNodes[i+1].PubKeyBytes[:])
|
|
||||||
copy(edgeInfo.BitcoinKey1Bytes[:], graphNodes[i].PubKeyBytes[:])
|
|
||||||
copy(edgeInfo.BitcoinKey2Bytes[:], graphNodes[i+1].PubKeyBytes[:])
|
|
||||||
if err := graph.AddChannelEdge(&edgeInfo); err != nil {
|
|
||||||
t.Fatalf("unable to add node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
pkScript, err := genMultiSigP2WSH(
|
|
||||||
edgeInfo.BitcoinKey1Bytes[:], edgeInfo.BitcoinKey2Bytes[:],
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to gen multi-sig p2wsh: %v", err)
|
|
||||||
}
|
|
||||||
edgePoints = append(edgePoints, EdgePoint{
|
|
||||||
FundingPkScript: pkScript,
|
|
||||||
OutPoint: op,
|
|
||||||
})
|
|
||||||
|
|
||||||
// Create and add an edge with random data that points from
|
|
||||||
// node_i -> node_i+1
|
|
||||||
edge := randEdgePolicy(chanID, op, db)
|
|
||||||
edge.ChannelFlags = 0
|
|
||||||
edge.Node = graphNodes[i]
|
|
||||||
edge.SigBytes = testSig.Serialize()
|
|
||||||
if err := graph.UpdateEdgePolicy(edge); err != nil {
|
|
||||||
t.Fatalf("unable to update edge: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create another random edge that points from node_i+1 ->
|
|
||||||
// node_i this time.
|
|
||||||
edge = randEdgePolicy(chanID, op, db)
|
|
||||||
edge.ChannelFlags = 1
|
|
||||||
edge.Node = graphNodes[i]
|
|
||||||
edge.SigBytes = testSig.Serialize()
|
|
||||||
if err := graph.UpdateEdgePolicy(edge); err != nil {
|
|
||||||
t.Fatalf("unable to update edge: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// With all the channel points added, we'll consult the graph to ensure
|
|
||||||
// it has the same channel view as the one we just constructed.
|
|
||||||
channelView, err := graph.ChannelView()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to get graph channel view: %v", err)
|
|
||||||
}
|
|
||||||
assertChanViewEqual(t, channelView, edgePoints)
|
|
||||||
|
|
||||||
// Now with our test graph created, we can test the pruning
|
|
||||||
// capabilities of the channel graph.
|
|
||||||
|
|
||||||
// First we create a mock block that ends up closing the first two
|
|
||||||
// channels.
|
|
||||||
var blockHash chainhash.Hash
|
|
||||||
copy(blockHash[:], bytes.Repeat([]byte{1}, 32))
|
|
||||||
blockHeight := uint32(1)
|
|
||||||
block := channelPoints[:2]
|
|
||||||
prunedChans, err := graph.PruneGraph(block, &blockHash, blockHeight)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to prune graph: %v", err)
|
|
||||||
}
|
|
||||||
if len(prunedChans) != 2 {
|
|
||||||
t.Fatalf("incorrect number of channels pruned: "+
|
|
||||||
"expected %v, got %v", 2, prunedChans)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now ensure that the prune tip has been updated.
|
|
||||||
assertPruneTip(t, graph, &blockHash, blockHeight)
|
|
||||||
|
|
||||||
// Count up the number of channels known within the graph, only 2
|
|
||||||
// should be remaining.
|
|
||||||
assertNumChans(t, graph, 2)
|
|
||||||
|
|
||||||
// Those channels should also be missing from the channel view.
|
|
||||||
channelView, err = graph.ChannelView()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to get graph channel view: %v", err)
|
|
||||||
}
|
|
||||||
assertChanViewEqualChanPoints(t, channelView, channelPoints[2:])
|
|
||||||
|
|
||||||
// Next we'll create a block that doesn't close any channels within the
|
|
||||||
// graph to test the negative error case.
|
|
||||||
fakeHash := sha256.Sum256([]byte("test prune"))
|
|
||||||
nonChannel := &wire.OutPoint{
|
|
||||||
Hash: fakeHash,
|
|
||||||
Index: 9,
|
|
||||||
}
|
|
||||||
blockHash = sha256.Sum256(blockHash[:])
|
|
||||||
blockHeight = 2
|
|
||||||
prunedChans, err = graph.PruneGraph(
|
|
||||||
[]*wire.OutPoint{nonChannel}, &blockHash, blockHeight,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to prune graph: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// No channels should have been detected as pruned.
|
|
||||||
if len(prunedChans) != 0 {
|
|
||||||
t.Fatalf("channels were pruned but shouldn't have been")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Once again, the prune tip should have been updated. We should still
|
|
||||||
// see both channels and their participants, along with the source node.
|
|
||||||
assertPruneTip(t, graph, &blockHash, blockHeight)
|
|
||||||
assertNumChans(t, graph, 2)
|
|
||||||
assertNumNodes(t, graph, 4)
|
|
||||||
|
|
||||||
// Finally, create a block that prunes the remainder of the channels
|
|
||||||
// from the graph.
|
|
||||||
blockHash = sha256.Sum256(blockHash[:])
|
|
||||||
blockHeight = 3
|
|
||||||
prunedChans, err = graph.PruneGraph(
|
|
||||||
channelPoints[2:], &blockHash, blockHeight,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to prune graph: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// The remainder of the channels should have been pruned from the
|
|
||||||
// graph.
|
|
||||||
if len(prunedChans) != 2 {
|
|
||||||
t.Fatalf("incorrect number of channels pruned: "+
|
|
||||||
"expected %v, got %v", 2, len(prunedChans))
|
|
||||||
}
|
|
||||||
|
|
||||||
// The prune tip should be updated, no channels should be found, and
|
|
||||||
// only the source node should remain within the current graph.
|
|
||||||
assertPruneTip(t, graph, &blockHash, blockHeight)
|
|
||||||
assertNumChans(t, graph, 0)
|
|
||||||
assertNumNodes(t, graph, 1)
|
|
||||||
|
|
||||||
// Finally, the channel view at this point in the graph should now be
|
|
||||||
// completely empty. Those channels should also be missing from the
|
|
||||||
// channel view.
|
|
||||||
channelView, err = graph.ChannelView()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to get graph channel view: %v", err)
|
|
||||||
}
|
|
||||||
if len(channelView) != 0 {
|
|
||||||
t.Fatalf("channel view should be empty, instead have: %v",
|
|
||||||
channelView)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestHighestChanID tests that we're able to properly retrieve the highest
|
|
||||||
// known channel ID in the database.
|
|
||||||
func TestHighestChanID(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
db, cleanUp, err := makeTestDB()
|
|
||||||
defer cleanUp()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test database: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
graph := db.ChannelGraph()
|
|
||||||
|
|
||||||
// If we don't yet have any channels in the database, then we should
|
|
||||||
// get a channel ID of zero if we ask for the highest channel ID.
|
|
||||||
bestID, err := graph.HighestChanID()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to get highest ID: %v", err)
|
|
||||||
}
|
|
||||||
if bestID != 0 {
|
|
||||||
t.Fatalf("best ID w/ no chan should be zero, is instead: %v",
|
|
||||||
bestID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Next, we'll insert two channels into the database, with each channel
|
|
||||||
// connecting the same two nodes.
|
|
||||||
node1, err := createTestVertex(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create test node: %v", err)
|
|
||||||
}
|
|
||||||
node2, err := createTestVertex(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create test node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// The first channel with be at height 10, while the other will be at
|
|
||||||
// height 100.
|
|
||||||
edge1, _ := createEdge(10, 0, 0, 0, node1, node2)
|
|
||||||
edge2, chanID2 := createEdge(100, 0, 0, 0, node1, node2)
|
|
||||||
|
|
||||||
if err := graph.AddChannelEdge(&edge1); err != nil {
|
|
||||||
t.Fatalf("unable to create channel edge: %v", err)
|
|
||||||
}
|
|
||||||
if err := graph.AddChannelEdge(&edge2); err != nil {
|
|
||||||
t.Fatalf("unable to create channel edge: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now that the edges has been inserted, we'll query for the highest
|
|
||||||
// known channel ID in the database.
|
|
||||||
bestID, err = graph.HighestChanID()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to get highest ID: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if bestID != chanID2.ToUint64() {
|
|
||||||
t.Fatalf("expected %v got %v for best chan ID: ",
|
|
||||||
chanID2.ToUint64(), bestID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we add another edge, then the current best chan ID should be
|
|
||||||
// updated as well.
|
|
||||||
edge3, chanID3 := createEdge(1000, 0, 0, 0, node1, node2)
|
|
||||||
if err := graph.AddChannelEdge(&edge3); err != nil {
|
|
||||||
t.Fatalf("unable to create channel edge: %v", err)
|
|
||||||
}
|
|
||||||
bestID, err = graph.HighestChanID()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to get highest ID: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if bestID != chanID3.ToUint64() {
|
|
||||||
t.Fatalf("expected %v got %v for best chan ID: ",
|
|
||||||
chanID3.ToUint64(), bestID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestChanUpdatesInHorizon tests the we're able to properly retrieve all known
|
|
||||||
// channel updates within a specific time horizon. It also tests that upon
|
|
||||||
// insertion of a new edge, the edge update index is updated properly.
|
|
||||||
func TestChanUpdatesInHorizon(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
db, cleanUp, err := makeTestDB()
|
|
||||||
defer cleanUp()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test database: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
graph := db.ChannelGraph()
|
|
||||||
|
|
||||||
// If we issue an arbitrary query before any channel updates are
|
|
||||||
// inserted in the database, we should get zero results.
|
|
||||||
chanUpdates, err := graph.ChanUpdatesInHorizon(
|
|
||||||
time.Unix(999, 0), time.Unix(9999, 0),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to updates for updates: %v", err)
|
|
||||||
}
|
|
||||||
if len(chanUpdates) != 0 {
|
|
||||||
t.Fatalf("expected 0 chan updates, instead got %v",
|
|
||||||
len(chanUpdates))
|
|
||||||
}
|
|
||||||
|
|
||||||
// We'll start by creating two nodes which will seed our test graph.
|
|
||||||
node1, err := createTestVertex(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create test node: %v", err)
|
|
||||||
}
|
|
||||||
if err := graph.AddLightningNode(node1); err != nil {
|
|
||||||
t.Fatalf("unable to add node: %v", err)
|
|
||||||
}
|
|
||||||
node2, err := createTestVertex(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create test node: %v", err)
|
|
||||||
}
|
|
||||||
if err := graph.AddLightningNode(node2); err != nil {
|
|
||||||
t.Fatalf("unable to add node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// We'll now create 10 channels between the two nodes, with update
|
|
||||||
// times 10 seconds after each other.
|
|
||||||
const numChans = 10
|
|
||||||
startTime := time.Unix(1234, 0)
|
|
||||||
endTime := startTime
|
|
||||||
edges := make([]ChannelEdge, 0, numChans)
|
|
||||||
for i := 0; i < numChans; i++ {
|
|
||||||
txHash := sha256.Sum256([]byte{byte(i)})
|
|
||||||
op := wire.OutPoint{
|
|
||||||
Hash: txHash,
|
|
||||||
Index: 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
channel, chanID := createEdge(
|
|
||||||
uint32(i*10), 0, 0, 0, node1, node2,
|
|
||||||
)
|
|
||||||
|
|
||||||
if err := graph.AddChannelEdge(&channel); err != nil {
|
|
||||||
t.Fatalf("unable to create channel edge: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
edge1UpdateTime := endTime
|
|
||||||
edge2UpdateTime := edge1UpdateTime.Add(time.Second)
|
|
||||||
endTime = endTime.Add(time.Second * 10)
|
|
||||||
|
|
||||||
edge1 := newEdgePolicy(
|
|
||||||
chanID.ToUint64(), op, db, edge1UpdateTime.Unix(),
|
|
||||||
)
|
|
||||||
edge1.ChannelFlags = 0
|
|
||||||
edge1.Node = node2
|
|
||||||
edge1.SigBytes = testSig.Serialize()
|
|
||||||
if err := graph.UpdateEdgePolicy(edge1); err != nil {
|
|
||||||
t.Fatalf("unable to update edge: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
edge2 := newEdgePolicy(
|
|
||||||
chanID.ToUint64(), op, db, edge2UpdateTime.Unix(),
|
|
||||||
)
|
|
||||||
edge2.ChannelFlags = 1
|
|
||||||
edge2.Node = node1
|
|
||||||
edge2.SigBytes = testSig.Serialize()
|
|
||||||
if err := graph.UpdateEdgePolicy(edge2); err != nil {
|
|
||||||
t.Fatalf("unable to update edge: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
edges = append(edges, ChannelEdge{
|
|
||||||
Info: &channel,
|
|
||||||
Policy1: edge1,
|
|
||||||
Policy2: edge2,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// With our channels loaded, we'll now start our series of queries.
|
|
||||||
queryCases := []struct {
|
|
||||||
start time.Time
|
|
||||||
end time.Time
|
|
||||||
|
|
||||||
resp []ChannelEdge
|
|
||||||
}{
|
|
||||||
// If we query for a time range that's strictly below our set
|
|
||||||
// of updates, then we'll get an empty result back.
|
|
||||||
{
|
|
||||||
start: time.Unix(100, 0),
|
|
||||||
end: time.Unix(200, 0),
|
|
||||||
},
|
|
||||||
|
|
||||||
// If we query for a time range that's well beyond our set of
|
|
||||||
// updates, we should get an empty set of results back.
|
|
||||||
{
|
|
||||||
start: time.Unix(99999, 0),
|
|
||||||
end: time.Unix(999999, 0),
|
|
||||||
},
|
|
||||||
|
|
||||||
// If we query for the start time, and 10 seconds directly
|
|
||||||
// after it, we should only get a single update, that first
|
|
||||||
// one.
|
|
||||||
{
|
|
||||||
start: time.Unix(1234, 0),
|
|
||||||
end: startTime.Add(time.Second * 10),
|
|
||||||
|
|
||||||
resp: []ChannelEdge{edges[0]},
|
|
||||||
},
|
|
||||||
|
|
||||||
// If we add 10 seconds past the first update, and then
|
|
||||||
// subtract 10 from the last update, then we should only get
|
|
||||||
// the 8 edges in the middle.
|
|
||||||
{
|
|
||||||
start: startTime.Add(time.Second * 10),
|
|
||||||
end: endTime.Add(-time.Second * 10),
|
|
||||||
|
|
||||||
resp: edges[1:9],
|
|
||||||
},
|
|
||||||
|
|
||||||
// If we use the start and end time as is, we should get the
|
|
||||||
// entire range.
|
|
||||||
{
|
|
||||||
start: startTime,
|
|
||||||
end: endTime,
|
|
||||||
|
|
||||||
resp: edges,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, queryCase := range queryCases {
|
|
||||||
resp, err := graph.ChanUpdatesInHorizon(
|
|
||||||
queryCase.start, queryCase.end,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to query for updates: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(resp) != len(queryCase.resp) {
|
|
||||||
t.Fatalf("expected %v chans, got %v chans",
|
|
||||||
len(queryCase.resp), len(resp))
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < len(resp); i++ {
|
|
||||||
chanExp := queryCase.resp[i]
|
|
||||||
chanRet := resp[i]
|
|
||||||
|
|
||||||
assertEdgeInfoEqual(t, chanExp.Info, chanRet.Info)
|
|
||||||
|
|
||||||
err := compareEdgePolicies(chanExp.Policy1, chanRet.Policy1)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
compareEdgePolicies(chanExp.Policy2, chanRet.Policy2)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestNodeUpdatesInHorizon tests that we're able to properly scan and retrieve
|
|
||||||
// the most recent node updates within a particular time horizon.
|
|
||||||
func TestNodeUpdatesInHorizon(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
db, cleanUp, err := makeTestDB()
|
|
||||||
defer cleanUp()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test database: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
graph := db.ChannelGraph()
|
|
||||||
|
|
||||||
startTime := time.Unix(1234, 0)
|
|
||||||
endTime := startTime
|
|
||||||
|
|
||||||
// If we issue an arbitrary query before we insert any nodes into the
|
|
||||||
// database, then we shouldn't get any results back.
|
|
||||||
nodeUpdates, err := graph.NodeUpdatesInHorizon(
|
|
||||||
time.Unix(999, 0), time.Unix(9999, 0),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to query for node updates: %v", err)
|
|
||||||
}
|
|
||||||
if len(nodeUpdates) != 0 {
|
|
||||||
t.Fatalf("expected 0 node updates, instead got %v",
|
|
||||||
len(nodeUpdates))
|
|
||||||
}
|
|
||||||
|
|
||||||
// We'll create 10 node announcements, each with an update timestamp 10
|
|
||||||
// seconds after the other.
|
|
||||||
const numNodes = 10
|
|
||||||
nodeAnns := make([]LightningNode, 0, numNodes)
|
|
||||||
for i := 0; i < numNodes; i++ {
|
|
||||||
nodeAnn, err := createTestVertex(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create test vertex: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// The node ann will use the current end time as its last
|
|
||||||
// update them, then we'll add 10 seconds in order to create
|
|
||||||
// the proper update time for the next node announcement.
|
|
||||||
updateTime := endTime
|
|
||||||
endTime = updateTime.Add(time.Second * 10)
|
|
||||||
|
|
||||||
nodeAnn.LastUpdate = updateTime
|
|
||||||
|
|
||||||
nodeAnns = append(nodeAnns, *nodeAnn)
|
|
||||||
|
|
||||||
if err := graph.AddLightningNode(nodeAnn); err != nil {
|
|
||||||
t.Fatalf("unable to add lightning node: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
queryCases := []struct {
|
|
||||||
start time.Time
|
|
||||||
end time.Time
|
|
||||||
|
|
||||||
resp []LightningNode
|
|
||||||
}{
|
|
||||||
// If we query for a time range that's strictly below our set
|
|
||||||
// of updates, then we'll get an empty result back.
|
|
||||||
{
|
|
||||||
start: time.Unix(100, 0),
|
|
||||||
end: time.Unix(200, 0),
|
|
||||||
},
|
|
||||||
|
|
||||||
// If we query for a time range that's well beyond our set of
|
|
||||||
// updates, we should get an empty set of results back.
|
|
||||||
{
|
|
||||||
start: time.Unix(99999, 0),
|
|
||||||
end: time.Unix(999999, 0),
|
|
||||||
},
|
|
||||||
|
|
||||||
// If we skip he first time epoch with out start time, then we
|
|
||||||
// should get back every now but the first.
|
|
||||||
{
|
|
||||||
start: startTime.Add(time.Second * 10),
|
|
||||||
end: endTime,
|
|
||||||
|
|
||||||
resp: nodeAnns[1:],
|
|
||||||
},
|
|
||||||
|
|
||||||
// If we query for the range as is, we should get all 10
|
|
||||||
// announcements back.
|
|
||||||
{
|
|
||||||
start: startTime,
|
|
||||||
end: endTime,
|
|
||||||
|
|
||||||
resp: nodeAnns,
|
|
||||||
},
|
|
||||||
|
|
||||||
// If we reduce the ending time by 10 seconds, then we should
|
|
||||||
// get all but the last node we inserted.
|
|
||||||
{
|
|
||||||
start: startTime,
|
|
||||||
end: endTime.Add(-time.Second * 10),
|
|
||||||
|
|
||||||
resp: nodeAnns[:9],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, queryCase := range queryCases {
|
|
||||||
resp, err := graph.NodeUpdatesInHorizon(queryCase.start, queryCase.end)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to query for nodes: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(resp) != len(queryCase.resp) {
|
|
||||||
t.Fatalf("expected %v nodes, got %v nodes",
|
|
||||||
len(queryCase.resp), len(resp))
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < len(resp); i++ {
|
|
||||||
err := compareNodes(&queryCase.resp[i], &resp[i])
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestFilterKnownChanIDs tests that we're able to properly perform the set
|
|
||||||
// differences of an incoming set of channel ID's, and those that we already
|
|
||||||
// know of on disk.
|
|
||||||
func TestFilterKnownChanIDs(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
db, cleanUp, err := makeTestDB()
|
|
||||||
defer cleanUp()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test database: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
graph := db.ChannelGraph()
|
|
||||||
|
|
||||||
// If we try to filter out a set of channel ID's before we even know of
|
|
||||||
// any channels, then we should get the entire set back.
|
|
||||||
preChanIDs := []uint64{1, 2, 3, 4}
|
|
||||||
filteredIDs, err := graph.FilterKnownChanIDs(preChanIDs)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to filter chan IDs: %v", err)
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(preChanIDs, filteredIDs) {
|
|
||||||
t.Fatalf("chan IDs shouldn't have been filtered!")
|
|
||||||
}
|
|
||||||
|
|
||||||
// We'll start by creating two nodes which will seed our test graph.
|
|
||||||
node1, err := createTestVertex(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create test node: %v", err)
|
|
||||||
}
|
|
||||||
if err := graph.AddLightningNode(node1); err != nil {
|
|
||||||
t.Fatalf("unable to add node: %v", err)
|
|
||||||
}
|
|
||||||
node2, err := createTestVertex(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create test node: %v", err)
|
|
||||||
}
|
|
||||||
if err := graph.AddLightningNode(node2); err != nil {
|
|
||||||
t.Fatalf("unable to add node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Next, we'll add 5 channel ID's to the graph, each of them having a
|
|
||||||
// block height 10 blocks after the previous.
|
|
||||||
const numChans = 5
|
|
||||||
chanIDs := make([]uint64, 0, numChans)
|
|
||||||
for i := 0; i < numChans; i++ {
|
|
||||||
channel, chanID := createEdge(
|
|
||||||
uint32(i*10), 0, 0, 0, node1, node2,
|
|
||||||
)
|
|
||||||
|
|
||||||
if err := graph.AddChannelEdge(&channel); err != nil {
|
|
||||||
t.Fatalf("unable to create channel edge: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
chanIDs = append(chanIDs, chanID.ToUint64())
|
|
||||||
}
|
|
||||||
|
|
||||||
const numZombies = 5
|
|
||||||
zombieIDs := make([]uint64, 0, numZombies)
|
|
||||||
for i := 0; i < numZombies; i++ {
|
|
||||||
channel, chanID := createEdge(
|
|
||||||
uint32(i*10+1), 0, 0, 0, node1, node2,
|
|
||||||
)
|
|
||||||
if err := graph.AddChannelEdge(&channel); err != nil {
|
|
||||||
t.Fatalf("unable to create channel edge: %v", err)
|
|
||||||
}
|
|
||||||
err := graph.DeleteChannelEdges(channel.ChannelID)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to mark edge zombie: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
zombieIDs = append(zombieIDs, chanID.ToUint64())
|
|
||||||
}
|
|
||||||
|
|
||||||
queryCases := []struct {
|
|
||||||
queryIDs []uint64
|
|
||||||
|
|
||||||
resp []uint64
|
|
||||||
}{
|
|
||||||
// If we attempt to filter out all chanIDs we know of, the
|
|
||||||
// response should be the empty set.
|
|
||||||
{
|
|
||||||
queryIDs: chanIDs,
|
|
||||||
},
|
|
||||||
// If we attempt to filter out all zombies that we know of, the
|
|
||||||
// response should be the empty set.
|
|
||||||
{
|
|
||||||
queryIDs: zombieIDs,
|
|
||||||
},
|
|
||||||
|
|
||||||
// If we query for a set of ID's that we didn't insert, we
|
|
||||||
// should get the same set back.
|
|
||||||
{
|
|
||||||
queryIDs: []uint64{99, 100},
|
|
||||||
resp: []uint64{99, 100},
|
|
||||||
},
|
|
||||||
|
|
||||||
// If we query for a super-set of our the chan ID's inserted,
|
|
||||||
// we should only get those new chanIDs back.
|
|
||||||
{
|
|
||||||
queryIDs: append(chanIDs, []uint64{99, 101}...),
|
|
||||||
resp: []uint64{99, 101},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, queryCase := range queryCases {
|
|
||||||
resp, err := graph.FilterKnownChanIDs(queryCase.queryIDs)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to filter chan IDs: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !reflect.DeepEqual(resp, queryCase.resp) {
|
|
||||||
t.Fatalf("expected %v, got %v", spew.Sdump(queryCase.resp),
|
|
||||||
spew.Sdump(resp))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestFilterChannelRange tests that we're able to properly retrieve the full
|
|
||||||
// set of short channel ID's for a given block range.
|
|
||||||
func TestFilterChannelRange(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
db, cleanUp, err := makeTestDB()
|
|
||||||
defer cleanUp()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test database: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
graph := db.ChannelGraph()
|
|
||||||
|
|
||||||
// We'll first populate our graph with two nodes. All channels created
|
|
||||||
// below will be made between these two nodes.
|
|
||||||
node1, err := createTestVertex(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create test node: %v", err)
|
|
||||||
}
|
|
||||||
if err := graph.AddLightningNode(node1); err != nil {
|
|
||||||
t.Fatalf("unable to add node: %v", err)
|
|
||||||
}
|
|
||||||
node2, err := createTestVertex(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create test node: %v", err)
|
|
||||||
}
|
|
||||||
if err := graph.AddLightningNode(node2); err != nil {
|
|
||||||
t.Fatalf("unable to add node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we try to filter a channel range before we have any channels
|
|
||||||
// inserted, we should get an empty slice of results.
|
|
||||||
resp, err := graph.FilterChannelRange(10, 100)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to filter channels: %v", err)
|
|
||||||
}
|
|
||||||
if len(resp) != 0 {
|
|
||||||
t.Fatalf("expected zero chans, instead got %v", len(resp))
|
|
||||||
}
|
|
||||||
|
|
||||||
// To start, we'll create a set of channels, each mined in a block 10
|
|
||||||
// blocks after the prior one.
|
|
||||||
startHeight := uint32(100)
|
|
||||||
endHeight := startHeight
|
|
||||||
const numChans = 10
|
|
||||||
chanIDs := make([]uint64, 0, numChans)
|
|
||||||
for i := 0; i < numChans; i++ {
|
|
||||||
chanHeight := endHeight
|
|
||||||
channel, chanID := createEdge(
|
|
||||||
uint32(chanHeight), uint32(i+1), 0, 0, node1, node2,
|
|
||||||
)
|
|
||||||
|
|
||||||
if err := graph.AddChannelEdge(&channel); err != nil {
|
|
||||||
t.Fatalf("unable to create channel edge: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
chanIDs = append(chanIDs, chanID.ToUint64())
|
|
||||||
|
|
||||||
endHeight += 10
|
|
||||||
}
|
|
||||||
|
|
||||||
// With our channels inserted, we'll construct a series of queries that
|
|
||||||
// we'll execute below in order to exercise the features of the
|
|
||||||
// FilterKnownChanIDs method.
|
|
||||||
queryCases := []struct {
|
|
||||||
startHeight uint32
|
|
||||||
endHeight uint32
|
|
||||||
|
|
||||||
resp []uint64
|
|
||||||
}{
|
|
||||||
// If we query for the entire range, then we should get the same
|
|
||||||
// set of short channel IDs back.
|
|
||||||
{
|
|
||||||
startHeight: startHeight,
|
|
||||||
endHeight: endHeight,
|
|
||||||
|
|
||||||
resp: chanIDs,
|
|
||||||
},
|
|
||||||
|
|
||||||
// If we query for a range of channels right before our range, we
|
|
||||||
// shouldn't get any results back.
|
|
||||||
{
|
|
||||||
startHeight: 0,
|
|
||||||
endHeight: 10,
|
|
||||||
},
|
|
||||||
|
|
||||||
// If we only query for the last height (range wise), we should
|
|
||||||
// only get that last channel.
|
|
||||||
{
|
|
||||||
startHeight: endHeight - 10,
|
|
||||||
endHeight: endHeight - 10,
|
|
||||||
|
|
||||||
resp: chanIDs[9:],
|
|
||||||
},
|
|
||||||
|
|
||||||
// If we query for just the first height, we should only get a
|
|
||||||
// single channel back (the first one).
|
|
||||||
{
|
|
||||||
startHeight: startHeight,
|
|
||||||
endHeight: startHeight,
|
|
||||||
|
|
||||||
resp: chanIDs[:1],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for i, queryCase := range queryCases {
|
|
||||||
resp, err := graph.FilterChannelRange(
|
|
||||||
queryCase.startHeight, queryCase.endHeight,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to issue range query: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !reflect.DeepEqual(resp, queryCase.resp) {
|
|
||||||
t.Fatalf("case #%v: expected %v, got %v", i,
|
|
||||||
queryCase.resp, resp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestFetchChanInfos tests that we're able to properly retrieve the full set
|
|
||||||
// of ChannelEdge structs for a given set of short channel ID's.
|
|
||||||
func TestFetchChanInfos(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
db, cleanUp, err := makeTestDB()
|
|
||||||
defer cleanUp()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test database: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
graph := db.ChannelGraph()
|
|
||||||
|
|
||||||
// We'll first populate our graph with two nodes. All channels created
|
|
||||||
// below will be made between these two nodes.
|
|
||||||
node1, err := createTestVertex(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create test node: %v", err)
|
|
||||||
}
|
|
||||||
if err := graph.AddLightningNode(node1); err != nil {
|
|
||||||
t.Fatalf("unable to add node: %v", err)
|
|
||||||
}
|
|
||||||
node2, err := createTestVertex(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create test node: %v", err)
|
|
||||||
}
|
|
||||||
if err := graph.AddLightningNode(node2); err != nil {
|
|
||||||
t.Fatalf("unable to add node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// We'll make 5 test channels, ensuring we keep track of which channel
|
|
||||||
// ID corresponds to a particular ChannelEdge.
|
|
||||||
const numChans = 5
|
|
||||||
startTime := time.Unix(1234, 0)
|
|
||||||
endTime := startTime
|
|
||||||
edges := make([]ChannelEdge, 0, numChans)
|
|
||||||
edgeQuery := make([]uint64, 0, numChans)
|
|
||||||
for i := 0; i < numChans; i++ {
|
|
||||||
txHash := sha256.Sum256([]byte{byte(i)})
|
|
||||||
op := wire.OutPoint{
|
|
||||||
Hash: txHash,
|
|
||||||
Index: 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
channel, chanID := createEdge(
|
|
||||||
uint32(i*10), 0, 0, 0, node1, node2,
|
|
||||||
)
|
|
||||||
|
|
||||||
if err := graph.AddChannelEdge(&channel); err != nil {
|
|
||||||
t.Fatalf("unable to create channel edge: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
updateTime := endTime
|
|
||||||
endTime = updateTime.Add(time.Second * 10)
|
|
||||||
|
|
||||||
edge1 := newEdgePolicy(
|
|
||||||
chanID.ToUint64(), op, db, updateTime.Unix(),
|
|
||||||
)
|
|
||||||
edge1.ChannelFlags = 0
|
|
||||||
edge1.Node = node2
|
|
||||||
edge1.SigBytes = testSig.Serialize()
|
|
||||||
if err := graph.UpdateEdgePolicy(edge1); err != nil {
|
|
||||||
t.Fatalf("unable to update edge: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
edge2 := newEdgePolicy(
|
|
||||||
chanID.ToUint64(), op, db, updateTime.Unix(),
|
|
||||||
)
|
|
||||||
edge2.ChannelFlags = 1
|
|
||||||
edge2.Node = node1
|
|
||||||
edge2.SigBytes = testSig.Serialize()
|
|
||||||
if err := graph.UpdateEdgePolicy(edge2); err != nil {
|
|
||||||
t.Fatalf("unable to update edge: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
edges = append(edges, ChannelEdge{
|
|
||||||
Info: &channel,
|
|
||||||
Policy1: edge1,
|
|
||||||
Policy2: edge2,
|
|
||||||
})
|
|
||||||
|
|
||||||
edgeQuery = append(edgeQuery, chanID.ToUint64())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add an additional edge that does not exist. The query should skip
|
|
||||||
// this channel and return only infos for the edges that exist.
|
|
||||||
edgeQuery = append(edgeQuery, 500)
|
|
||||||
|
|
||||||
// Add an another edge to the query that has been marked as a zombie
|
|
||||||
// edge. The query should also skip this channel.
|
|
||||||
zombieChan, zombieChanID := createEdge(
|
|
||||||
666, 0, 0, 0, node1, node2,
|
|
||||||
)
|
|
||||||
if err := graph.AddChannelEdge(&zombieChan); err != nil {
|
|
||||||
t.Fatalf("unable to create channel edge: %v", err)
|
|
||||||
}
|
|
||||||
err = graph.DeleteChannelEdges(zombieChan.ChannelID)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to delete and mark edge zombie: %v", err)
|
|
||||||
}
|
|
||||||
edgeQuery = append(edgeQuery, zombieChanID.ToUint64())
|
|
||||||
|
|
||||||
// We'll now attempt to query for the range of channel ID's we just
|
|
||||||
// inserted into the database. We should get the exact same set of
|
|
||||||
// edges back.
|
|
||||||
resp, err := graph.FetchChanInfos(edgeQuery)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to fetch chan edges: %v", err)
|
|
||||||
}
|
|
||||||
if len(resp) != len(edges) {
|
|
||||||
t.Fatalf("expected %v edges, instead got %v", len(edges),
|
|
||||||
len(resp))
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < len(resp); i++ {
|
|
||||||
err := compareEdgePolicies(resp[i].Policy1, edges[i].Policy1)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("edge doesn't match: %v", err)
|
|
||||||
}
|
|
||||||
err = compareEdgePolicies(resp[i].Policy2, edges[i].Policy2)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("edge doesn't match: %v", err)
|
|
||||||
}
|
|
||||||
assertEdgeInfoEqual(t, resp[i].Info, edges[i].Info)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestIncompleteChannelPolicies tests that a channel that only has a policy
|
|
||||||
// specified on one end is properly returned in ForEachChannel calls from
|
|
||||||
// both sides.
|
|
||||||
func TestIncompleteChannelPolicies(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
db, cleanUp, err := makeTestDB()
|
|
||||||
defer cleanUp()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test database: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
graph := db.ChannelGraph()
|
|
||||||
|
|
||||||
// Create two nodes.
|
|
||||||
node1, err := createTestVertex(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create test node: %v", err)
|
|
||||||
}
|
|
||||||
if err := graph.AddLightningNode(node1); err != nil {
|
|
||||||
t.Fatalf("unable to add node: %v", err)
|
|
||||||
}
|
|
||||||
node2, err := createTestVertex(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create test node: %v", err)
|
|
||||||
}
|
|
||||||
if err := graph.AddLightningNode(node2); err != nil {
|
|
||||||
t.Fatalf("unable to add node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create channel between nodes.
|
|
||||||
txHash := sha256.Sum256([]byte{0})
|
|
||||||
op := wire.OutPoint{
|
|
||||||
Hash: txHash,
|
|
||||||
Index: 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
channel, chanID := createEdge(
|
|
||||||
uint32(0), 0, 0, 0, node1, node2,
|
|
||||||
)
|
|
||||||
|
|
||||||
if err := graph.AddChannelEdge(&channel); err != nil {
|
|
||||||
t.Fatalf("unable to create channel edge: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure that channel is reported with unknown policies.
|
|
||||||
|
|
||||||
checkPolicies := func(node *LightningNode, expectedIn, expectedOut bool) {
|
|
||||||
calls := 0
|
|
||||||
node.ForEachChannel(nil, func(_ *bbolt.Tx, _ *ChannelEdgeInfo,
|
|
||||||
outEdge, inEdge *ChannelEdgePolicy) error {
|
|
||||||
|
|
||||||
if !expectedOut && outEdge != nil {
|
|
||||||
t.Fatalf("Expected no outgoing policy")
|
|
||||||
}
|
|
||||||
|
|
||||||
if expectedOut && outEdge == nil {
|
|
||||||
t.Fatalf("Expected an outgoing policy")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !expectedIn && inEdge != nil {
|
|
||||||
t.Fatalf("Expected no incoming policy")
|
|
||||||
}
|
|
||||||
|
|
||||||
if expectedIn && inEdge == nil {
|
|
||||||
t.Fatalf("Expected an incoming policy")
|
|
||||||
}
|
|
||||||
|
|
||||||
calls++
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
if calls != 1 {
|
|
||||||
t.Fatalf("Expected only one callback call")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
checkPolicies(node2, false, false)
|
|
||||||
|
|
||||||
// Only create an edge policy for node1 and leave the policy for node2
|
|
||||||
// unknown.
|
|
||||||
updateTime := time.Unix(1234, 0)
|
|
||||||
|
|
||||||
edgePolicy := newEdgePolicy(
|
|
||||||
chanID.ToUint64(), op, db, updateTime.Unix(),
|
|
||||||
)
|
|
||||||
edgePolicy.ChannelFlags = 0
|
|
||||||
edgePolicy.Node = node2
|
|
||||||
edgePolicy.SigBytes = testSig.Serialize()
|
|
||||||
if err := graph.UpdateEdgePolicy(edgePolicy); err != nil {
|
|
||||||
t.Fatalf("unable to update edge: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
checkPolicies(node1, false, true)
|
|
||||||
checkPolicies(node2, true, false)
|
|
||||||
|
|
||||||
// Create second policy and assert that both policies are reported
|
|
||||||
// as present.
|
|
||||||
edgePolicy = newEdgePolicy(
|
|
||||||
chanID.ToUint64(), op, db, updateTime.Unix(),
|
|
||||||
)
|
|
||||||
edgePolicy.ChannelFlags = 1
|
|
||||||
edgePolicy.Node = node1
|
|
||||||
edgePolicy.SigBytes = testSig.Serialize()
|
|
||||||
if err := graph.UpdateEdgePolicy(edgePolicy); err != nil {
|
|
||||||
t.Fatalf("unable to update edge: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
checkPolicies(node1, true, true)
|
|
||||||
checkPolicies(node2, true, true)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestChannelEdgePruningUpdateIndexDeletion tests that once edges are deleted
|
|
||||||
// from the graph, then their entries within the update index are also cleaned
|
|
||||||
// up.
|
|
||||||
func TestChannelEdgePruningUpdateIndexDeletion(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
db, cleanUp, err := makeTestDB()
|
|
||||||
defer cleanUp()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test database: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
graph := db.ChannelGraph()
|
|
||||||
sourceNode, err := createTestVertex(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create source node: %v", err)
|
|
||||||
}
|
|
||||||
if err := graph.SetSourceNode(sourceNode); err != nil {
|
|
||||||
t.Fatalf("unable to set source node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// We'll first populate our graph with two nodes. All channels created
|
|
||||||
// below will be made between these two nodes.
|
|
||||||
node1, err := createTestVertex(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create test node: %v", err)
|
|
||||||
}
|
|
||||||
if err := graph.AddLightningNode(node1); err != nil {
|
|
||||||
t.Fatalf("unable to add node: %v", err)
|
|
||||||
}
|
|
||||||
node2, err := createTestVertex(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create test node: %v", err)
|
|
||||||
}
|
|
||||||
if err := graph.AddLightningNode(node2); err != nil {
|
|
||||||
t.Fatalf("unable to add node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// With the two nodes created, we'll now create a random channel, as
|
|
||||||
// well as two edges in the database with distinct update times.
|
|
||||||
edgeInfo, chanID := createEdge(100, 0, 0, 0, node1, node2)
|
|
||||||
if err := graph.AddChannelEdge(&edgeInfo); err != nil {
|
|
||||||
t.Fatalf("unable to add edge: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
edge1 := randEdgePolicy(chanID.ToUint64(), edgeInfo.ChannelPoint, db)
|
|
||||||
edge1.ChannelFlags = 0
|
|
||||||
edge1.Node = node1
|
|
||||||
edge1.SigBytes = testSig.Serialize()
|
|
||||||
if err := graph.UpdateEdgePolicy(edge1); err != nil {
|
|
||||||
t.Fatalf("unable to update edge: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
edge2 := randEdgePolicy(chanID.ToUint64(), edgeInfo.ChannelPoint, db)
|
|
||||||
edge2.ChannelFlags = 1
|
|
||||||
edge2.Node = node2
|
|
||||||
edge2.SigBytes = testSig.Serialize()
|
|
||||||
if err := graph.UpdateEdgePolicy(edge2); err != nil {
|
|
||||||
t.Fatalf("unable to update edge: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// checkIndexTimestamps is a helper function that checks the edge update
|
|
||||||
// index only includes the given timestamps.
|
|
||||||
checkIndexTimestamps := func(timestamps ...uint64) {
|
|
||||||
timestampSet := make(map[uint64]struct{})
|
|
||||||
for _, t := range timestamps {
|
|
||||||
timestampSet[t] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
err := db.View(func(tx *bbolt.Tx) error {
|
|
||||||
edges := tx.Bucket(edgeBucket)
|
|
||||||
if edges == nil {
|
|
||||||
return ErrGraphNoEdgesFound
|
|
||||||
}
|
|
||||||
edgeUpdateIndex := edges.Bucket(edgeUpdateIndexBucket)
|
|
||||||
if edgeUpdateIndex == nil {
|
|
||||||
return ErrGraphNoEdgesFound
|
|
||||||
}
|
|
||||||
|
|
||||||
numEntries := edgeUpdateIndex.Stats().KeyN
|
|
||||||
expectedEntries := len(timestampSet)
|
|
||||||
if numEntries != expectedEntries {
|
|
||||||
return fmt.Errorf("expected %v entries in the "+
|
|
||||||
"update index, got %v", expectedEntries,
|
|
||||||
numEntries)
|
|
||||||
}
|
|
||||||
|
|
||||||
return edgeUpdateIndex.ForEach(func(k, _ []byte) error {
|
|
||||||
t := byteOrder.Uint64(k[:8])
|
|
||||||
if _, ok := timestampSet[t]; !ok {
|
|
||||||
return fmt.Errorf("found unexpected "+
|
|
||||||
"timestamp "+"%d", t)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// With both edges policies added, we'll make sure to check they exist
|
|
||||||
// within the edge update index.
|
|
||||||
checkIndexTimestamps(
|
|
||||||
uint64(edge1.LastUpdate.Unix()),
|
|
||||||
uint64(edge2.LastUpdate.Unix()),
|
|
||||||
)
|
|
||||||
|
|
||||||
// Now, we'll update the edge policies to ensure the old timestamps are
|
|
||||||
// removed from the update index.
|
|
||||||
edge1.ChannelFlags = 2
|
|
||||||
edge1.LastUpdate = time.Now()
|
|
||||||
if err := graph.UpdateEdgePolicy(edge1); err != nil {
|
|
||||||
t.Fatalf("unable to update edge: %v", err)
|
|
||||||
}
|
|
||||||
edge2.ChannelFlags = 3
|
|
||||||
edge2.LastUpdate = edge1.LastUpdate.Add(time.Hour)
|
|
||||||
if err := graph.UpdateEdgePolicy(edge2); err != nil {
|
|
||||||
t.Fatalf("unable to update edge: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// With the policies updated, we should now be able to find their
|
|
||||||
// updated entries within the update index.
|
|
||||||
checkIndexTimestamps(
|
|
||||||
uint64(edge1.LastUpdate.Unix()),
|
|
||||||
uint64(edge2.LastUpdate.Unix()),
|
|
||||||
)
|
|
||||||
|
|
||||||
// Now we'll prune the graph, removing the edges, and also the update
|
|
||||||
// index entries from the database all together.
|
|
||||||
var blockHash chainhash.Hash
|
|
||||||
copy(blockHash[:], bytes.Repeat([]byte{2}, 32))
|
|
||||||
_, err = graph.PruneGraph(
|
|
||||||
[]*wire.OutPoint{&edgeInfo.ChannelPoint}, &blockHash, 101,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to prune graph: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Finally, we'll check the database state one last time to conclude
|
|
||||||
// that we should no longer be able to locate _any_ entries within the
|
|
||||||
// edge update index.
|
|
||||||
checkIndexTimestamps()
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestPruneGraphNodes tests that unconnected vertexes are pruned via the
|
|
||||||
// PruneSyncState method.
|
|
||||||
func TestPruneGraphNodes(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
db, cleanUp, err := makeTestDB()
|
|
||||||
defer cleanUp()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test database: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// We'll start off by inserting our source node, to ensure that it's
|
|
||||||
// the only node left after we prune the graph.
|
|
||||||
graph := db.ChannelGraph()
|
|
||||||
sourceNode, err := createTestVertex(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create source node: %v", err)
|
|
||||||
}
|
|
||||||
if err := graph.SetSourceNode(sourceNode); err != nil {
|
|
||||||
t.Fatalf("unable to set source node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// With the source node inserted, we'll now add three nodes to the
|
|
||||||
// channel graph, at the end of the scenario, only two of these nodes
|
|
||||||
// should still be in the graph.
|
|
||||||
node1, err := createTestVertex(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create test node: %v", err)
|
|
||||||
}
|
|
||||||
if err := graph.AddLightningNode(node1); err != nil {
|
|
||||||
t.Fatalf("unable to add node: %v", err)
|
|
||||||
}
|
|
||||||
node2, err := createTestVertex(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create test node: %v", err)
|
|
||||||
}
|
|
||||||
if err := graph.AddLightningNode(node2); err != nil {
|
|
||||||
t.Fatalf("unable to add node: %v", err)
|
|
||||||
}
|
|
||||||
node3, err := createTestVertex(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create test node: %v", err)
|
|
||||||
}
|
|
||||||
if err := graph.AddLightningNode(node3); err != nil {
|
|
||||||
t.Fatalf("unable to add node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// We'll now add a new edge to the graph, but only actually advertise
|
|
||||||
// the edge of *one* of the nodes.
|
|
||||||
edgeInfo, chanID := createEdge(100, 0, 0, 0, node1, node2)
|
|
||||||
if err := graph.AddChannelEdge(&edgeInfo); err != nil {
|
|
||||||
t.Fatalf("unable to add edge: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// We'll now insert an advertised edge, but it'll only be the edge that
|
|
||||||
// points from the first to the second node.
|
|
||||||
edge1 := randEdgePolicy(chanID.ToUint64(), edgeInfo.ChannelPoint, db)
|
|
||||||
edge1.ChannelFlags = 0
|
|
||||||
edge1.Node = node1
|
|
||||||
edge1.SigBytes = testSig.Serialize()
|
|
||||||
if err := graph.UpdateEdgePolicy(edge1); err != nil {
|
|
||||||
t.Fatalf("unable to update edge: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// We'll now initiate a around of graph pruning.
|
|
||||||
if err := graph.PruneGraphNodes(); err != nil {
|
|
||||||
t.Fatalf("unable to prune graph nodes: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// At this point, there should be 3 nodes left in the graph still: the
|
|
||||||
// source node (which can't be pruned), and node 1+2. Nodes 1 and two
|
|
||||||
// should still be left in the graph as there's half of an advertised
|
|
||||||
// edge between them.
|
|
||||||
assertNumNodes(t, graph, 3)
|
|
||||||
|
|
||||||
// Finally, we'll ensure that node3, the only fully unconnected node as
|
|
||||||
// properly deleted from the graph and not another node in its place.
|
|
||||||
node3Pub, err := node3.PubKey()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to fetch the pubkey of node3: %v", err)
|
|
||||||
}
|
|
||||||
if _, err := graph.FetchLightningNode(node3Pub); err == nil {
|
|
||||||
t.Fatalf("node 3 should have been deleted!")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestAddChannelEdgeShellNodes tests that when we attempt to add a ChannelEdge
|
|
||||||
// to the graph, one or both of the nodes the edge involves aren't found in the
|
|
||||||
// database, then shell edges are created for each node if needed.
|
|
||||||
func TestAddChannelEdgeShellNodes(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
db, cleanUp, err := makeTestDB()
|
|
||||||
defer cleanUp()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test database: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
graph := db.ChannelGraph()
|
|
||||||
|
|
||||||
// To start, we'll create two nodes, and only add one of them to the
|
|
||||||
// channel graph.
|
|
||||||
node1, err := createTestVertex(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create test node: %v", err)
|
|
||||||
}
|
|
||||||
if err := graph.AddLightningNode(node1); err != nil {
|
|
||||||
t.Fatalf("unable to add node: %v", err)
|
|
||||||
}
|
|
||||||
node2, err := createTestVertex(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create test node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// We'll now create an edge between the two nodes, as a result, node2
|
|
||||||
// should be inserted into the database as a shell node.
|
|
||||||
edgeInfo, _ := createEdge(100, 0, 0, 0, node1, node2)
|
|
||||||
if err := graph.AddChannelEdge(&edgeInfo); err != nil {
|
|
||||||
t.Fatalf("unable to add edge: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
node1Pub, err := node1.PubKey()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to parse node 1 pub: %v", err)
|
|
||||||
}
|
|
||||||
node2Pub, err := node2.PubKey()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to parse node 2 pub: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure that node1 was inserted as a full node, while node2 only has
|
|
||||||
// a shell node present.
|
|
||||||
node1, err = graph.FetchLightningNode(node1Pub)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to fetch node1: %v", err)
|
|
||||||
}
|
|
||||||
if !node1.HaveNodeAnnouncement {
|
|
||||||
t.Fatalf("have shell announcement for node1, shouldn't")
|
|
||||||
}
|
|
||||||
|
|
||||||
node2, err = graph.FetchLightningNode(node2Pub)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to fetch node2: %v", err)
|
|
||||||
}
|
|
||||||
if node2.HaveNodeAnnouncement {
|
|
||||||
t.Fatalf("should have shell announcement for node2, but is full")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestNodePruningUpdateIndexDeletion tests that once a node has been removed
|
|
||||||
// from the channel graph, we also remove the entry from the update index as
|
|
||||||
// well.
|
|
||||||
func TestNodePruningUpdateIndexDeletion(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
db, cleanUp, err := makeTestDB()
|
|
||||||
defer cleanUp()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test database: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
graph := db.ChannelGraph()
|
|
||||||
|
|
||||||
// We'll first populate our graph with a single node that will be
|
|
||||||
// removed shortly.
|
|
||||||
node1, err := createTestVertex(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create test node: %v", err)
|
|
||||||
}
|
|
||||||
if err := graph.AddLightningNode(node1); err != nil {
|
|
||||||
t.Fatalf("unable to add node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// We'll confirm that we can retrieve the node using
|
|
||||||
// NodeUpdatesInHorizon, using a time that's slightly beyond the last
|
|
||||||
// update time of our test node.
|
|
||||||
startTime := time.Unix(9, 0)
|
|
||||||
endTime := node1.LastUpdate.Add(time.Minute)
|
|
||||||
nodesInHorizon, err := graph.NodeUpdatesInHorizon(startTime, endTime)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to fetch nodes in horizon: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// We should only have a single node, and that node should exactly
|
|
||||||
// match the node we just inserted.
|
|
||||||
if len(nodesInHorizon) != 1 {
|
|
||||||
t.Fatalf("should have 1 nodes instead have: %v",
|
|
||||||
len(nodesInHorizon))
|
|
||||||
}
|
|
||||||
if err := compareNodes(node1, &nodesInHorizon[0]); err != nil {
|
|
||||||
t.Fatalf("nodes don't match: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// We'll now delete the node from the graph, this should result in it
|
|
||||||
// being removed from the update index as well.
|
|
||||||
nodePub, _ := node1.PubKey()
|
|
||||||
if err := graph.DeleteLightningNode(nodePub); err != nil {
|
|
||||||
t.Fatalf("unable to delete node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now that the node has been deleted, we'll again query the nodes in
|
|
||||||
// the horizon. This time we should have no nodes at all.
|
|
||||||
nodesInHorizon, err = graph.NodeUpdatesInHorizon(startTime, endTime)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to fetch nodes in horizon: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(nodesInHorizon) != 0 {
|
|
||||||
t.Fatalf("should have zero nodes instead have: %v",
|
|
||||||
len(nodesInHorizon))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestNodeIsPublic ensures that we properly detect nodes that are seen as
|
|
||||||
// public within the network graph.
|
|
||||||
func TestNodeIsPublic(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
// We'll start off the test by creating a small network of 3
|
|
||||||
// participants with the following graph:
|
|
||||||
//
|
|
||||||
// Alice <-> Bob <-> Carol
|
|
||||||
//
|
|
||||||
// We'll need to create a separate database and channel graph for each
|
|
||||||
// participant to replicate real-world scenarios (private edges being in
|
|
||||||
// some graphs but not others, etc.).
|
|
||||||
aliceDB, cleanUp, err := makeTestDB()
|
|
||||||
defer cleanUp()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test database: %v", err)
|
|
||||||
}
|
|
||||||
aliceNode, err := createTestVertex(aliceDB)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create test node: %v", err)
|
|
||||||
}
|
|
||||||
aliceGraph := aliceDB.ChannelGraph()
|
|
||||||
if err := aliceGraph.SetSourceNode(aliceNode); err != nil {
|
|
||||||
t.Fatalf("unable to set source node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
bobDB, cleanUp, err := makeTestDB()
|
|
||||||
defer cleanUp()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test database: %v", err)
|
|
||||||
}
|
|
||||||
bobNode, err := createTestVertex(bobDB)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create test node: %v", err)
|
|
||||||
}
|
|
||||||
bobGraph := bobDB.ChannelGraph()
|
|
||||||
if err := bobGraph.SetSourceNode(bobNode); err != nil {
|
|
||||||
t.Fatalf("unable to set source node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
carolDB, cleanUp, err := makeTestDB()
|
|
||||||
defer cleanUp()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test database: %v", err)
|
|
||||||
}
|
|
||||||
carolNode, err := createTestVertex(carolDB)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create test node: %v", err)
|
|
||||||
}
|
|
||||||
carolGraph := carolDB.ChannelGraph()
|
|
||||||
if err := carolGraph.SetSourceNode(carolNode); err != nil {
|
|
||||||
t.Fatalf("unable to set source node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
aliceBobEdge, _ := createEdge(10, 0, 0, 0, aliceNode, bobNode)
|
|
||||||
bobCarolEdge, _ := createEdge(10, 1, 0, 1, bobNode, carolNode)
|
|
||||||
|
|
||||||
// After creating all of our nodes and edges, we'll add them to each
|
|
||||||
// participant's graph.
|
|
||||||
nodes := []*LightningNode{aliceNode, bobNode, carolNode}
|
|
||||||
edges := []*ChannelEdgeInfo{&aliceBobEdge, &bobCarolEdge}
|
|
||||||
dbs := []*DB{aliceDB, bobDB, carolDB}
|
|
||||||
graphs := []*ChannelGraph{aliceGraph, bobGraph, carolGraph}
|
|
||||||
for i, graph := range graphs {
|
|
||||||
for _, node := range nodes {
|
|
||||||
node.db = dbs[i]
|
|
||||||
if err := graph.AddLightningNode(node); err != nil {
|
|
||||||
t.Fatalf("unable to add node: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, edge := range edges {
|
|
||||||
edge.db = dbs[i]
|
|
||||||
if err := graph.AddChannelEdge(edge); err != nil {
|
|
||||||
t.Fatalf("unable to add edge: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// checkNodes is a helper closure that will be used to assert that the
|
|
||||||
// given nodes are seen as public/private within the given graphs.
|
|
||||||
checkNodes := func(nodes []*LightningNode, graphs []*ChannelGraph,
|
|
||||||
public bool) {
|
|
||||||
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
for _, node := range nodes {
|
|
||||||
for _, graph := range graphs {
|
|
||||||
isPublic, err := graph.IsPublicNode(node.PubKeyBytes)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to determine if pivot "+
|
|
||||||
"is public: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case isPublic && !public:
|
|
||||||
t.Fatalf("expected %x to be private",
|
|
||||||
node.PubKeyBytes)
|
|
||||||
case !isPublic && public:
|
|
||||||
t.Fatalf("expected %x to be public",
|
|
||||||
node.PubKeyBytes)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Due to the way the edges were set up above, we'll make sure each node
|
|
||||||
// can correctly determine that every other node is public.
|
|
||||||
checkNodes(nodes, graphs, true)
|
|
||||||
|
|
||||||
// Now, we'll remove the edge between Alice and Bob from everyone's
|
|
||||||
// graph. This will make Alice be seen as a private node as it no longer
|
|
||||||
// has any advertised edges.
|
|
||||||
for _, graph := range graphs {
|
|
||||||
err := graph.DeleteChannelEdges(aliceBobEdge.ChannelID)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to remove edge: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
checkNodes(
|
|
||||||
[]*LightningNode{aliceNode},
|
|
||||||
[]*ChannelGraph{bobGraph, carolGraph},
|
|
||||||
false,
|
|
||||||
)
|
|
||||||
|
|
||||||
// We'll also make the edge between Bob and Carol private. Within Bob's
|
|
||||||
// and Carol's graph, the edge will exist, but it will not have a proof
|
|
||||||
// that allows it to be advertised. Within Alice's graph, we'll
|
|
||||||
// completely remove the edge as it is not possible for her to know of
|
|
||||||
// it without it being advertised.
|
|
||||||
for i, graph := range graphs {
|
|
||||||
err := graph.DeleteChannelEdges(bobCarolEdge.ChannelID)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to remove edge: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if graph == aliceGraph {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
bobCarolEdge.AuthProof = nil
|
|
||||||
bobCarolEdge.db = dbs[i]
|
|
||||||
if err := graph.AddChannelEdge(&bobCarolEdge); err != nil {
|
|
||||||
t.Fatalf("unable to add edge: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// With the modifications above, Bob should now be seen as a private
|
|
||||||
// node from both Alice's and Carol's perspective.
|
|
||||||
checkNodes(
|
|
||||||
[]*LightningNode{bobNode},
|
|
||||||
[]*ChannelGraph{aliceGraph, carolGraph},
|
|
||||||
false,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestDisabledChannelIDs ensures that the disabled channels within the
|
|
||||||
// disabledEdgePolicyBucket are managed properly and the list returned from
|
|
||||||
// DisabledChannelIDs is correct.
|
|
||||||
func TestDisabledChannelIDs(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
db, cleanUp, err := makeTestDB()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test database: %v", err)
|
|
||||||
}
|
|
||||||
defer cleanUp()
|
|
||||||
|
|
||||||
graph := db.ChannelGraph()
|
|
||||||
|
|
||||||
// Create first node and add it to the graph.
|
|
||||||
node1, err := createTestVertex(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create test node: %v", err)
|
|
||||||
}
|
|
||||||
if err := graph.AddLightningNode(node1); err != nil {
|
|
||||||
t.Fatalf("unable to add node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create second node and add it to the graph.
|
|
||||||
node2, err := createTestVertex(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create test node: %v", err)
|
|
||||||
}
|
|
||||||
if err := graph.AddLightningNode(node2); err != nil {
|
|
||||||
t.Fatalf("unable to add node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Adding a new channel edge to the graph.
|
|
||||||
edgeInfo, edge1, edge2 := createChannelEdge(db, node1, node2)
|
|
||||||
if err := graph.AddLightningNode(node2); err != nil {
|
|
||||||
t.Fatalf("unable to add node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := graph.AddChannelEdge(edgeInfo); err != nil {
|
|
||||||
t.Fatalf("unable to create channel edge: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure no disabled channels exist in the bucket on start.
|
|
||||||
disabledChanIds, err := graph.DisabledChannelIDs()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to get disabled channel ids: %v", err)
|
|
||||||
}
|
|
||||||
if len(disabledChanIds) > 0 {
|
|
||||||
t.Fatalf("expected empty disabled channels, got %v disabled channels",
|
|
||||||
len(disabledChanIds))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add one disabled policy and ensure the channel is still not in the
|
|
||||||
// disabled list.
|
|
||||||
edge1.ChannelFlags |= lnwire.ChanUpdateDisabled
|
|
||||||
if err := graph.UpdateEdgePolicy(edge1); err != nil {
|
|
||||||
t.Fatalf("unable to update edge: %v", err)
|
|
||||||
}
|
|
||||||
disabledChanIds, err = graph.DisabledChannelIDs()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to get disabled channel ids: %v", err)
|
|
||||||
}
|
|
||||||
if len(disabledChanIds) > 0 {
|
|
||||||
t.Fatalf("expected empty disabled channels, got %v disabled channels",
|
|
||||||
len(disabledChanIds))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add second disabled policy and ensure the channel is now in the
|
|
||||||
// disabled list.
|
|
||||||
edge2.ChannelFlags |= lnwire.ChanUpdateDisabled
|
|
||||||
if err := graph.UpdateEdgePolicy(edge2); err != nil {
|
|
||||||
t.Fatalf("unable to update edge: %v", err)
|
|
||||||
}
|
|
||||||
disabledChanIds, err = graph.DisabledChannelIDs()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to get disabled channel ids: %v", err)
|
|
||||||
}
|
|
||||||
if len(disabledChanIds) != 1 || disabledChanIds[0] != edgeInfo.ChannelID {
|
|
||||||
t.Fatalf("expected disabled channel with id %v, "+
|
|
||||||
"got %v", edgeInfo.ChannelID, disabledChanIds)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete the channel edge and ensure it is removed from the disabled list.
|
|
||||||
if err = graph.DeleteChannelEdges(edgeInfo.ChannelID); err != nil {
|
|
||||||
t.Fatalf("unable to delete channel edge: %v", err)
|
|
||||||
}
|
|
||||||
disabledChanIds, err = graph.DisabledChannelIDs()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to get disabled channel ids: %v", err)
|
|
||||||
}
|
|
||||||
if len(disabledChanIds) > 0 {
|
|
||||||
t.Fatalf("expected empty disabled channels, got %v disabled channels",
|
|
||||||
len(disabledChanIds))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestEdgePolicyMissingMaxHtcl tests that if we find a ChannelEdgePolicy in
|
|
||||||
// the DB that indicates that it should support the htlc_maximum_value_msat
|
|
||||||
// field, but it is not part of the opaque data, then we'll handle it as it is
|
|
||||||
// unknown. It also checks that we are correctly able to overwrite it when we
|
|
||||||
// receive the proper update.
|
|
||||||
func TestEdgePolicyMissingMaxHtcl(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
db, cleanUp, err := makeTestDB()
|
|
||||||
defer cleanUp()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test database: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
graph := db.ChannelGraph()
|
|
||||||
|
|
||||||
// We'd like to test the update of edges inserted into the database, so
|
|
||||||
// we create two vertexes to connect.
|
|
||||||
node1, err := createTestVertex(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create test node: %v", err)
|
|
||||||
}
|
|
||||||
if err := graph.AddLightningNode(node1); err != nil {
|
|
||||||
t.Fatalf("unable to add node: %v", err)
|
|
||||||
}
|
|
||||||
node2, err := createTestVertex(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create test node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
edgeInfo, edge1, edge2 := createChannelEdge(db, node1, node2)
|
|
||||||
if err := graph.AddLightningNode(node2); err != nil {
|
|
||||||
t.Fatalf("unable to add node: %v", err)
|
|
||||||
}
|
|
||||||
if err := graph.AddChannelEdge(edgeInfo); err != nil {
|
|
||||||
t.Fatalf("unable to create channel edge: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
chanID := edgeInfo.ChannelID
|
|
||||||
from := edge2.Node.PubKeyBytes[:]
|
|
||||||
to := edge1.Node.PubKeyBytes[:]
|
|
||||||
|
|
||||||
// We'll remove the no max_htlc field from the first edge policy, and
|
|
||||||
// all other opaque data, and serialize it.
|
|
||||||
edge1.MessageFlags = 0
|
|
||||||
edge1.ExtraOpaqueData = nil
|
|
||||||
|
|
||||||
var b bytes.Buffer
|
|
||||||
err = serializeChanEdgePolicy(&b, edge1, to)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to serialize policy")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the max_htlc field. The extra bytes added to the serialization
|
|
||||||
// will be the opaque data containing the serialized field.
|
|
||||||
edge1.MessageFlags = lnwire.ChanUpdateOptionMaxHtlc
|
|
||||||
edge1.MaxHTLC = 13928598
|
|
||||||
var b2 bytes.Buffer
|
|
||||||
err = serializeChanEdgePolicy(&b2, edge1, to)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to serialize policy")
|
|
||||||
}
|
|
||||||
|
|
||||||
withMaxHtlc := b2.Bytes()
|
|
||||||
|
|
||||||
// Remove the opaque data from the serialization.
|
|
||||||
stripped := withMaxHtlc[:len(b.Bytes())]
|
|
||||||
|
|
||||||
// Attempting to deserialize these bytes should return an error.
|
|
||||||
r := bytes.NewReader(stripped)
|
|
||||||
err = db.View(func(tx *bbolt.Tx) error {
|
|
||||||
nodes := tx.Bucket(nodeBucket)
|
|
||||||
if nodes == nil {
|
|
||||||
return ErrGraphNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = deserializeChanEdgePolicy(r, nodes)
|
|
||||||
if err != ErrEdgePolicyOptionalFieldNotFound {
|
|
||||||
t.Fatalf("expected "+
|
|
||||||
"ErrEdgePolicyOptionalFieldNotFound, got %v",
|
|
||||||
err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("error reading db: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Put the stripped bytes in the DB.
|
|
||||||
err = db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
edges := tx.Bucket(edgeBucket)
|
|
||||||
if edges == nil {
|
|
||||||
return ErrEdgeNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
edgeIndex := edges.Bucket(edgeIndexBucket)
|
|
||||||
if edgeIndex == nil {
|
|
||||||
return ErrEdgeNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
var edgeKey [33 + 8]byte
|
|
||||||
copy(edgeKey[:], from)
|
|
||||||
byteOrder.PutUint64(edgeKey[33:], edge1.ChannelID)
|
|
||||||
|
|
||||||
var scratch [8]byte
|
|
||||||
var indexKey [8 + 8]byte
|
|
||||||
copy(indexKey[:], scratch[:])
|
|
||||||
byteOrder.PutUint64(indexKey[8:], edge1.ChannelID)
|
|
||||||
|
|
||||||
updateIndex, err := edges.CreateBucketIfNotExists(edgeUpdateIndexBucket)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := updateIndex.Put(indexKey[:], nil); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return edges.Put(edgeKey[:], stripped)
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("error writing db: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// And add the second, unmodified edge.
|
|
||||||
if err := graph.UpdateEdgePolicy(edge2); err != nil {
|
|
||||||
t.Fatalf("unable to update edge: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Attempt to fetch the edge and policies from the DB. Since the policy
|
|
||||||
// we added is invalid according to the new format, it should be as we
|
|
||||||
// are not aware of the policy (indicated by the policy returned being
|
|
||||||
// nil)
|
|
||||||
dbEdgeInfo, dbEdge1, dbEdge2, err := graph.FetchChannelEdgesByID(chanID)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to fetch channel by ID: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// The first edge should have a nil-policy returned
|
|
||||||
if dbEdge1 != nil {
|
|
||||||
t.Fatalf("expected db edge to be nil")
|
|
||||||
}
|
|
||||||
if err := compareEdgePolicies(dbEdge2, edge2); err != nil {
|
|
||||||
t.Fatalf("edge doesn't match: %v", err)
|
|
||||||
}
|
|
||||||
assertEdgeInfoEqual(t, dbEdgeInfo, edgeInfo)
|
|
||||||
|
|
||||||
// Now add the original, unmodified edge policy, and make sure the edge
|
|
||||||
// policies then become fully populated.
|
|
||||||
if err := graph.UpdateEdgePolicy(edge1); err != nil {
|
|
||||||
t.Fatalf("unable to update edge: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
dbEdgeInfo, dbEdge1, dbEdge2, err = graph.FetchChannelEdgesByID(chanID)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to fetch channel by ID: %v", err)
|
|
||||||
}
|
|
||||||
if err := compareEdgePolicies(dbEdge1, edge1); err != nil {
|
|
||||||
t.Fatalf("edge doesn't match: %v", err)
|
|
||||||
}
|
|
||||||
if err := compareEdgePolicies(dbEdge2, edge2); err != nil {
|
|
||||||
t.Fatalf("edge doesn't match: %v", err)
|
|
||||||
}
|
|
||||||
assertEdgeInfoEqual(t, dbEdgeInfo, edgeInfo)
|
|
||||||
}
|
|
||||||
|
|
||||||
// assertNumZombies queries the provided ChannelGraph for NumZombies, and
|
|
||||||
// asserts that the returned number is equal to expZombies.
|
|
||||||
func assertNumZombies(t *testing.T, graph *ChannelGraph, expZombies uint64) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
numZombies, err := graph.NumZombies()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to query number of zombies: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if numZombies != expZombies {
|
|
||||||
t.Fatalf("expected %d zombies, found %d",
|
|
||||||
expZombies, numZombies)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestGraphZombieIndex ensures that we can mark edges correctly as zombie/live.
|
|
||||||
func TestGraphZombieIndex(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
// We'll start by creating our test graph along with a test edge.
|
|
||||||
db, cleanUp, err := makeTestDB()
|
|
||||||
defer cleanUp()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create test database: %v", err)
|
|
||||||
}
|
|
||||||
graph := db.ChannelGraph()
|
|
||||||
|
|
||||||
node1, err := createTestVertex(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create test vertex: %v", err)
|
|
||||||
}
|
|
||||||
node2, err := createTestVertex(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create test vertex: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Swap the nodes if the second's pubkey is smaller than the first.
|
|
||||||
// Without this, the comparisons at the end will fail probabilistically.
|
|
||||||
if bytes.Compare(node2.PubKeyBytes[:], node1.PubKeyBytes[:]) < 0 {
|
|
||||||
node1, node2 = node2, node1
|
|
||||||
}
|
|
||||||
|
|
||||||
edge, _, _ := createChannelEdge(db, node1, node2)
|
|
||||||
if err := graph.AddChannelEdge(edge); err != nil {
|
|
||||||
t.Fatalf("unable to create channel edge: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Since the edge is known the graph and it isn't a zombie, IsZombieEdge
|
|
||||||
// should not report the channel as a zombie.
|
|
||||||
isZombie, _, _ := graph.IsZombieEdge(edge.ChannelID)
|
|
||||||
if isZombie {
|
|
||||||
t.Fatal("expected edge to not be marked as zombie")
|
|
||||||
}
|
|
||||||
assertNumZombies(t, graph, 0)
|
|
||||||
|
|
||||||
// If we delete the edge and mark it as a zombie, then we should expect
|
|
||||||
// to see it within the index.
|
|
||||||
err = graph.DeleteChannelEdges(edge.ChannelID)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to mark edge as zombie: %v", err)
|
|
||||||
}
|
|
||||||
isZombie, pubKey1, pubKey2 := graph.IsZombieEdge(edge.ChannelID)
|
|
||||||
if !isZombie {
|
|
||||||
t.Fatal("expected edge to be marked as zombie")
|
|
||||||
}
|
|
||||||
if pubKey1 != node1.PubKeyBytes {
|
|
||||||
t.Fatalf("expected pubKey1 %x, got %x", node1.PubKeyBytes,
|
|
||||||
pubKey1)
|
|
||||||
}
|
|
||||||
if pubKey2 != node2.PubKeyBytes {
|
|
||||||
t.Fatalf("expected pubKey2 %x, got %x", node2.PubKeyBytes,
|
|
||||||
pubKey2)
|
|
||||||
}
|
|
||||||
assertNumZombies(t, graph, 1)
|
|
||||||
|
|
||||||
// Similarly, if we mark the same edge as live, we should no longer see
|
|
||||||
// it within the index.
|
|
||||||
if err := graph.MarkEdgeLive(edge.ChannelID); err != nil {
|
|
||||||
t.Fatalf("unable to mark edge as live: %v", err)
|
|
||||||
}
|
|
||||||
isZombie, _, _ = graph.IsZombieEdge(edge.ChannelID)
|
|
||||||
if isZombie {
|
|
||||||
t.Fatal("expected edge to not be marked as zombie")
|
|
||||||
}
|
|
||||||
assertNumZombies(t, graph, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// compareNodes is used to compare two LightningNodes while excluding the
|
|
||||||
// Features struct, which cannot be compared as the semantics for reserializing
|
|
||||||
// the featuresMap have not been defined.
|
|
||||||
func compareNodes(a, b *LightningNode) error {
|
|
||||||
if a.LastUpdate != b.LastUpdate {
|
|
||||||
return fmt.Errorf("node LastUpdate doesn't match: expected %v, \n"+
|
|
||||||
"got %v", a.LastUpdate, b.LastUpdate)
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(a.Addresses, b.Addresses) {
|
|
||||||
return fmt.Errorf("Addresses doesn't match: expected %#v, \n "+
|
|
||||||
"got %#v", a.Addresses, b.Addresses)
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(a.PubKeyBytes, b.PubKeyBytes) {
|
|
||||||
return fmt.Errorf("PubKey doesn't match: expected %#v, \n "+
|
|
||||||
"got %#v", a.PubKeyBytes, b.PubKeyBytes)
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(a.Color, b.Color) {
|
|
||||||
return fmt.Errorf("Color doesn't match: expected %#v, \n "+
|
|
||||||
"got %#v", a.Color, b.Color)
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(a.Alias, b.Alias) {
|
|
||||||
return fmt.Errorf("Alias doesn't match: expected %#v, \n "+
|
|
||||||
"got %#v", a.Alias, b.Alias)
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(a.db, b.db) {
|
|
||||||
return fmt.Errorf("db doesn't match: expected %#v, \n "+
|
|
||||||
"got %#v", a.db, b.db)
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(a.HaveNodeAnnouncement, b.HaveNodeAnnouncement) {
|
|
||||||
return fmt.Errorf("HaveNodeAnnouncement doesn't match: expected %#v, \n "+
|
|
||||||
"got %#v", a.HaveNodeAnnouncement, b.HaveNodeAnnouncement)
|
|
||||||
}
|
|
||||||
if !bytes.Equal(a.ExtraOpaqueData, b.ExtraOpaqueData) {
|
|
||||||
return fmt.Errorf("extra data doesn't match: %v vs %v",
|
|
||||||
a.ExtraOpaqueData, b.ExtraOpaqueData)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// compareEdgePolicies is used to compare two ChannelEdgePolices using
|
|
||||||
// compareNodes, so as to exclude comparisons of the Nodes' Features struct.
|
|
||||||
func compareEdgePolicies(a, b *ChannelEdgePolicy) error {
|
|
||||||
if a.ChannelID != b.ChannelID {
|
|
||||||
return fmt.Errorf("ChannelID doesn't match: expected %v, "+
|
|
||||||
"got %v", a.ChannelID, b.ChannelID)
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(a.LastUpdate, b.LastUpdate) {
|
|
||||||
return fmt.Errorf("edge LastUpdate doesn't match: expected %#v, \n "+
|
|
||||||
"got %#v", a.LastUpdate, b.LastUpdate)
|
|
||||||
}
|
|
||||||
if a.MessageFlags != b.MessageFlags {
|
|
||||||
return fmt.Errorf("MessageFlags doesn't match: expected %v, "+
|
|
||||||
"got %v", a.MessageFlags, b.MessageFlags)
|
|
||||||
}
|
|
||||||
if a.ChannelFlags != b.ChannelFlags {
|
|
||||||
return fmt.Errorf("ChannelFlags doesn't match: expected %v, "+
|
|
||||||
"got %v", a.ChannelFlags, b.ChannelFlags)
|
|
||||||
}
|
|
||||||
if a.TimeLockDelta != b.TimeLockDelta {
|
|
||||||
return fmt.Errorf("TimeLockDelta doesn't match: expected %v, "+
|
|
||||||
"got %v", a.TimeLockDelta, b.TimeLockDelta)
|
|
||||||
}
|
|
||||||
if a.MinHTLC != b.MinHTLC {
|
|
||||||
return fmt.Errorf("MinHTLC doesn't match: expected %v, "+
|
|
||||||
"got %v", a.MinHTLC, b.MinHTLC)
|
|
||||||
}
|
|
||||||
if a.MaxHTLC != b.MaxHTLC {
|
|
||||||
return fmt.Errorf("MaxHTLC doesn't match: expected %v, "+
|
|
||||||
"got %v", a.MaxHTLC, b.MaxHTLC)
|
|
||||||
}
|
|
||||||
if a.FeeBaseMSat != b.FeeBaseMSat {
|
|
||||||
return fmt.Errorf("FeeBaseMSat doesn't match: expected %v, "+
|
|
||||||
"got %v", a.FeeBaseMSat, b.FeeBaseMSat)
|
|
||||||
}
|
|
||||||
if a.FeeProportionalMillionths != b.FeeProportionalMillionths {
|
|
||||||
return fmt.Errorf("FeeProportionalMillionths doesn't match: "+
|
|
||||||
"expected %v, got %v", a.FeeProportionalMillionths,
|
|
||||||
b.FeeProportionalMillionths)
|
|
||||||
}
|
|
||||||
if !bytes.Equal(a.ExtraOpaqueData, b.ExtraOpaqueData) {
|
|
||||||
return fmt.Errorf("extra data doesn't match: %v vs %v",
|
|
||||||
a.ExtraOpaqueData, b.ExtraOpaqueData)
|
|
||||||
}
|
|
||||||
if err := compareNodes(a.Node, b.Node); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(a.db, b.db) {
|
|
||||||
return fmt.Errorf("db doesn't match: expected %#v, \n "+
|
|
||||||
"got %#v", a.db, b.db)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestLightningNodeSigVerifcation checks that we can use the LightningNode's
|
|
||||||
// pubkey to verify signatures.
|
|
||||||
func TestLightningNodeSigVerification(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
// Create some dummy data to sign.
|
|
||||||
var data [32]byte
|
|
||||||
if _, err := prand.Read(data[:]); err != nil {
|
|
||||||
t.Fatalf("unable to read prand: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create private key and sign the data with it.
|
|
||||||
priv, err := btcec.NewPrivateKey(btcec.S256())
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to crete priv key: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
sign, err := priv.Sign(data[:])
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to sign: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sanity check that the signature checks out.
|
|
||||||
if !sign.Verify(data[:], priv.PubKey()) {
|
|
||||||
t.Fatalf("signature doesn't check out")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a LightningNode from the same private key.
|
|
||||||
db, cleanUp, err := makeTestDB()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test database: %v", err)
|
|
||||||
}
|
|
||||||
defer cleanUp()
|
|
||||||
|
|
||||||
node, err := createLightningNode(db, priv)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// And finally check that we can verify the same signature from the
|
|
||||||
// pubkey returned from the lightning node.
|
|
||||||
nodePub, err := node.PubKey()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to get pubkey: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !sign.Verify(data[:], nodePub) {
|
|
||||||
t.Fatalf("unable to verify sig")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestComputeFee tests fee calculation based on both in- and outgoing amt.
|
|
||||||
func TestComputeFee(t *testing.T) {
|
|
||||||
var (
|
|
||||||
policy = ChannelEdgePolicy{
|
|
||||||
FeeBaseMSat: 10000,
|
|
||||||
FeeProportionalMillionths: 30000,
|
|
||||||
}
|
|
||||||
outgoingAmt = lnwire.MilliSatoshi(1000000)
|
|
||||||
expectedFee = lnwire.MilliSatoshi(40000)
|
|
||||||
)
|
|
||||||
|
|
||||||
fee := policy.ComputeFee(outgoingAmt)
|
|
||||||
if fee != expectedFee {
|
|
||||||
t.Fatalf("expected fee %v, got %v", expectedFee, fee)
|
|
||||||
}
|
|
||||||
|
|
||||||
fwdFee := policy.ComputeFeeFromIncoming(outgoingAmt + fee)
|
|
||||||
if fwdFee != expectedFee {
|
|
||||||
t.Fatalf("expected fee %v, but got %v", fee, fwdFee)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -1,694 +0,0 @@
|
|||||||
package migration_01_to_11
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/davecgh/go-spew/spew"
|
|
||||||
"github.com/lightningnetwork/lnd/lnwire"
|
|
||||||
)
|
|
||||||
|
|
||||||
func randInvoice(value lnwire.MilliSatoshi) (*Invoice, error) {
|
|
||||||
var pre [32]byte
|
|
||||||
if _, err := rand.Read(pre[:]); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
i := &Invoice{
|
|
||||||
// Use single second precision to avoid false positive test
|
|
||||||
// failures due to the monotonic time component.
|
|
||||||
CreationDate: time.Unix(time.Now().Unix(), 0),
|
|
||||||
Terms: ContractTerm{
|
|
||||||
PaymentPreimage: pre,
|
|
||||||
Value: value,
|
|
||||||
},
|
|
||||||
Htlcs: map[CircuitKey]*InvoiceHTLC{},
|
|
||||||
Expiry: 4000,
|
|
||||||
}
|
|
||||||
i.Memo = []byte("memo")
|
|
||||||
i.Receipt = []byte("receipt")
|
|
||||||
|
|
||||||
// Create a random byte slice of MaxPaymentRequestSize bytes to be used
|
|
||||||
// as a dummy paymentrequest, and determine if it should be set based
|
|
||||||
// on one of the random bytes.
|
|
||||||
var r [MaxPaymentRequestSize]byte
|
|
||||||
if _, err := rand.Read(r[:]); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if r[0]&1 == 0 {
|
|
||||||
i.PaymentRequest = r[:]
|
|
||||||
} else {
|
|
||||||
i.PaymentRequest = []byte("")
|
|
||||||
}
|
|
||||||
|
|
||||||
return i, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestInvoiceWorkflow(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
db, cleanUp, err := makeTestDB()
|
|
||||||
defer cleanUp()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test db: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a fake invoice which we'll use several times in the tests
|
|
||||||
// below.
|
|
||||||
fakeInvoice := &Invoice{
|
|
||||||
// Use single second precision to avoid false positive test
|
|
||||||
// failures due to the monotonic time component.
|
|
||||||
CreationDate: time.Unix(time.Now().Unix(), 0),
|
|
||||||
Htlcs: map[CircuitKey]*InvoiceHTLC{},
|
|
||||||
}
|
|
||||||
fakeInvoice.Memo = []byte("memo")
|
|
||||||
fakeInvoice.Receipt = []byte("receipt")
|
|
||||||
fakeInvoice.PaymentRequest = []byte("")
|
|
||||||
copy(fakeInvoice.Terms.PaymentPreimage[:], rev[:])
|
|
||||||
fakeInvoice.Terms.Value = lnwire.NewMSatFromSatoshis(10000)
|
|
||||||
|
|
||||||
paymentHash := fakeInvoice.Terms.PaymentPreimage.Hash()
|
|
||||||
|
|
||||||
// Add the invoice to the database, this should succeed as there aren't
|
|
||||||
// any existing invoices within the database with the same payment
|
|
||||||
// hash.
|
|
||||||
if _, err := db.AddInvoice(fakeInvoice, paymentHash); err != nil {
|
|
||||||
t.Fatalf("unable to find invoice: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Attempt to retrieve the invoice which was just added to the
|
|
||||||
// database. It should be found, and the invoice returned should be
|
|
||||||
// identical to the one created above.
|
|
||||||
dbInvoice, err := db.LookupInvoice(paymentHash)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to find invoice: %v", err)
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(*fakeInvoice, dbInvoice) {
|
|
||||||
t.Fatalf("invoice fetched from db doesn't match original %v vs %v",
|
|
||||||
spew.Sdump(fakeInvoice), spew.Sdump(dbInvoice))
|
|
||||||
}
|
|
||||||
|
|
||||||
// The add index of the invoice retrieved from the database should now
|
|
||||||
// be fully populated. As this is the first index written to the DB,
|
|
||||||
// the addIndex should be 1.
|
|
||||||
if dbInvoice.AddIndex != 1 {
|
|
||||||
t.Fatalf("wrong add index: expected %v, got %v", 1,
|
|
||||||
dbInvoice.AddIndex)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Settle the invoice, the version retrieved from the database should
|
|
||||||
// now have the settled bit toggle to true and a non-default
|
|
||||||
// SettledDate
|
|
||||||
payAmt := fakeInvoice.Terms.Value * 2
|
|
||||||
_, err = db.UpdateInvoice(paymentHash, getUpdateInvoice(payAmt))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to settle invoice: %v", err)
|
|
||||||
}
|
|
||||||
dbInvoice2, err := db.LookupInvoice(paymentHash)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to fetch invoice: %v", err)
|
|
||||||
}
|
|
||||||
if dbInvoice2.Terms.State != ContractSettled {
|
|
||||||
t.Fatalf("invoice should now be settled but isn't")
|
|
||||||
}
|
|
||||||
if dbInvoice2.SettleDate.IsZero() {
|
|
||||||
t.Fatalf("invoice should have non-zero SettledDate but isn't")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Our 2x payment should be reflected, and also the settle index of 1
|
|
||||||
// should also have been committed for this index.
|
|
||||||
if dbInvoice2.AmtPaid != payAmt {
|
|
||||||
t.Fatalf("wrong amt paid: expected %v, got %v", payAmt,
|
|
||||||
dbInvoice2.AmtPaid)
|
|
||||||
}
|
|
||||||
if dbInvoice2.SettleIndex != 1 {
|
|
||||||
t.Fatalf("wrong settle index: expected %v, got %v", 1,
|
|
||||||
dbInvoice2.SettleIndex)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Attempt to insert generated above again, this should fail as
|
|
||||||
// duplicates are rejected by the processing logic.
|
|
||||||
if _, err := db.AddInvoice(fakeInvoice, paymentHash); err != ErrDuplicateInvoice {
|
|
||||||
t.Fatalf("invoice insertion should fail due to duplication, "+
|
|
||||||
"instead %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Attempt to look up a non-existent invoice, this should also fail but
|
|
||||||
// with a "not found" error.
|
|
||||||
var fakeHash [32]byte
|
|
||||||
if _, err := db.LookupInvoice(fakeHash); err != ErrInvoiceNotFound {
|
|
||||||
t.Fatalf("lookup should have failed, instead %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add 10 random invoices.
|
|
||||||
const numInvoices = 10
|
|
||||||
amt := lnwire.NewMSatFromSatoshis(1000)
|
|
||||||
invoices := make([]*Invoice, numInvoices+1)
|
|
||||||
invoices[0] = &dbInvoice2
|
|
||||||
for i := 1; i < len(invoices)-1; i++ {
|
|
||||||
invoice, err := randInvoice(amt)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create invoice: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
hash := invoice.Terms.PaymentPreimage.Hash()
|
|
||||||
if _, err := db.AddInvoice(invoice, hash); err != nil {
|
|
||||||
t.Fatalf("unable to add invoice %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
invoices[i] = invoice
|
|
||||||
}
|
|
||||||
|
|
||||||
// Perform a scan to collect all the active invoices.
|
|
||||||
dbInvoices, err := db.FetchAllInvoices(false)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to fetch all invoices: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// The retrieve list of invoices should be identical as since we're
|
|
||||||
// using big endian, the invoices should be retrieved in ascending
|
|
||||||
// order (and the primary key should be incremented with each
|
|
||||||
// insertion).
|
|
||||||
for i := 0; i < len(invoices)-1; i++ {
|
|
||||||
if !reflect.DeepEqual(*invoices[i], dbInvoices[i]) {
|
|
||||||
t.Fatalf("retrieved invoices don't match %v vs %v",
|
|
||||||
spew.Sdump(invoices[i]),
|
|
||||||
spew.Sdump(dbInvoices[i]))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestInvoiceTimeSeries tests that newly added invoices invoices, as well as
|
|
||||||
// settled invoices are added to the database are properly placed in the add
|
|
||||||
// add or settle index which serves as an event time series.
|
|
||||||
func TestInvoiceAddTimeSeries(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
db, cleanUp, err := makeTestDB()
|
|
||||||
defer cleanUp()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test db: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// We'll start off by creating 20 random invoices, and inserting them
|
|
||||||
// into the database.
|
|
||||||
const numInvoices = 20
|
|
||||||
amt := lnwire.NewMSatFromSatoshis(1000)
|
|
||||||
invoices := make([]Invoice, numInvoices)
|
|
||||||
for i := 0; i < len(invoices); i++ {
|
|
||||||
invoice, err := randInvoice(amt)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create invoice: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
paymentHash := invoice.Terms.PaymentPreimage.Hash()
|
|
||||||
|
|
||||||
if _, err := db.AddInvoice(invoice, paymentHash); err != nil {
|
|
||||||
t.Fatalf("unable to add invoice %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
invoices[i] = *invoice
|
|
||||||
}
|
|
||||||
|
|
||||||
// With the invoices constructed, we'll now create a series of queries
|
|
||||||
// that we'll use to assert expected return values of
|
|
||||||
// InvoicesAddedSince.
|
|
||||||
addQueries := []struct {
|
|
||||||
sinceAddIndex uint64
|
|
||||||
|
|
||||||
resp []Invoice
|
|
||||||
}{
|
|
||||||
// If we specify a value of zero, we shouldn't get any invoices
|
|
||||||
// back.
|
|
||||||
{
|
|
||||||
sinceAddIndex: 0,
|
|
||||||
},
|
|
||||||
|
|
||||||
// If we specify a value well beyond the number of inserted
|
|
||||||
// invoices, we shouldn't get any invoices back.
|
|
||||||
{
|
|
||||||
sinceAddIndex: 99999999,
|
|
||||||
},
|
|
||||||
|
|
||||||
// Using an index of 1 should result in all values, but the
|
|
||||||
// first one being returned.
|
|
||||||
{
|
|
||||||
sinceAddIndex: 1,
|
|
||||||
resp: invoices[1:],
|
|
||||||
},
|
|
||||||
|
|
||||||
// If we use an index of 10, then we should retrieve the
|
|
||||||
// reaming 10 invoices.
|
|
||||||
{
|
|
||||||
sinceAddIndex: 10,
|
|
||||||
resp: invoices[10:],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, query := range addQueries {
|
|
||||||
resp, err := db.InvoicesAddedSince(query.sinceAddIndex)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to query: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !reflect.DeepEqual(query.resp, resp) {
|
|
||||||
t.Fatalf("test #%v: expected %v, got %v", i,
|
|
||||||
spew.Sdump(query.resp), spew.Sdump(resp))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// We'll now only settle the latter half of each of those invoices.
|
|
||||||
for i := 10; i < len(invoices); i++ {
|
|
||||||
invoice := &invoices[i]
|
|
||||||
|
|
||||||
paymentHash := invoice.Terms.PaymentPreimage.Hash()
|
|
||||||
|
|
||||||
_, err := db.UpdateInvoice(
|
|
||||||
paymentHash, getUpdateInvoice(0),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to settle invoice: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
invoices, err = db.FetchAllInvoices(false)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to fetch invoices: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// We'll slice off the first 10 invoices, as we only settled the last
|
|
||||||
// 10.
|
|
||||||
invoices = invoices[10:]
|
|
||||||
|
|
||||||
// We'll now prepare an additional set of queries to ensure the settle
|
|
||||||
// time series has properly been maintained in the database.
|
|
||||||
settleQueries := []struct {
|
|
||||||
sinceSettleIndex uint64
|
|
||||||
|
|
||||||
resp []Invoice
|
|
||||||
}{
|
|
||||||
// If we specify a value of zero, we shouldn't get any settled
|
|
||||||
// invoices back.
|
|
||||||
{
|
|
||||||
sinceSettleIndex: 0,
|
|
||||||
},
|
|
||||||
|
|
||||||
// If we specify a value well beyond the number of settled
|
|
||||||
// invoices, we shouldn't get any invoices back.
|
|
||||||
{
|
|
||||||
sinceSettleIndex: 99999999,
|
|
||||||
},
|
|
||||||
|
|
||||||
// Using an index of 1 should result in the final 10 invoices
|
|
||||||
// being returned, as we only settled those.
|
|
||||||
{
|
|
||||||
sinceSettleIndex: 1,
|
|
||||||
resp: invoices[1:],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, query := range settleQueries {
|
|
||||||
resp, err := db.InvoicesSettledSince(query.sinceSettleIndex)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to query: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !reflect.DeepEqual(query.resp, resp) {
|
|
||||||
t.Fatalf("test #%v: expected %v, got %v", i,
|
|
||||||
spew.Sdump(query.resp), spew.Sdump(resp))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestDuplicateSettleInvoice tests that if we add a new invoice and settle it
|
|
||||||
// twice, then the second time we also receive the invoice that we settled as a
|
|
||||||
// return argument.
|
|
||||||
func TestDuplicateSettleInvoice(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
db, cleanUp, err := makeTestDB()
|
|
||||||
defer cleanUp()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test db: %v", err)
|
|
||||||
}
|
|
||||||
db.now = func() time.Time { return time.Unix(1, 0) }
|
|
||||||
|
|
||||||
// We'll start out by creating an invoice and writing it to the DB.
|
|
||||||
amt := lnwire.NewMSatFromSatoshis(1000)
|
|
||||||
invoice, err := randInvoice(amt)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create invoice: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
payHash := invoice.Terms.PaymentPreimage.Hash()
|
|
||||||
|
|
||||||
if _, err := db.AddInvoice(invoice, payHash); err != nil {
|
|
||||||
t.Fatalf("unable to add invoice %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// With the invoice in the DB, we'll now attempt to settle the invoice.
|
|
||||||
dbInvoice, err := db.UpdateInvoice(
|
|
||||||
payHash, getUpdateInvoice(amt),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to settle invoice: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// We'll update what we expect the settle invoice to be so that our
|
|
||||||
// comparison below has the correct assumption.
|
|
||||||
invoice.SettleIndex = 1
|
|
||||||
invoice.Terms.State = ContractSettled
|
|
||||||
invoice.AmtPaid = amt
|
|
||||||
invoice.SettleDate = dbInvoice.SettleDate
|
|
||||||
invoice.Htlcs = map[CircuitKey]*InvoiceHTLC{
|
|
||||||
{}: {
|
|
||||||
Amt: amt,
|
|
||||||
AcceptTime: time.Unix(1, 0),
|
|
||||||
ResolveTime: time.Unix(1, 0),
|
|
||||||
State: HtlcStateSettled,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// We should get back the exact same invoice that we just inserted.
|
|
||||||
if !reflect.DeepEqual(dbInvoice, invoice) {
|
|
||||||
t.Fatalf("wrong invoice after settle, expected %v got %v",
|
|
||||||
spew.Sdump(invoice), spew.Sdump(dbInvoice))
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we try to settle the invoice again, then we should get the very
|
|
||||||
// same invoice back, but with an error this time.
|
|
||||||
dbInvoice, err = db.UpdateInvoice(
|
|
||||||
payHash, getUpdateInvoice(amt),
|
|
||||||
)
|
|
||||||
if err != ErrInvoiceAlreadySettled {
|
|
||||||
t.Fatalf("expected ErrInvoiceAlreadySettled")
|
|
||||||
}
|
|
||||||
|
|
||||||
if dbInvoice == nil {
|
|
||||||
t.Fatalf("invoice from db is nil after settle!")
|
|
||||||
}
|
|
||||||
|
|
||||||
invoice.SettleDate = dbInvoice.SettleDate
|
|
||||||
if !reflect.DeepEqual(dbInvoice, invoice) {
|
|
||||||
t.Fatalf("wrong invoice after second settle, expected %v got %v",
|
|
||||||
spew.Sdump(invoice), spew.Sdump(dbInvoice))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestQueryInvoices ensures that we can properly query the invoice database for
|
|
||||||
// invoices using different types of queries.
|
|
||||||
func TestQueryInvoices(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
db, cleanUp, err := makeTestDB()
|
|
||||||
defer cleanUp()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test db: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// To begin the test, we'll add 50 invoices to the database. We'll
|
|
||||||
// assume that the index of the invoice within the database is the same
|
|
||||||
// as the amount of the invoice itself.
|
|
||||||
const numInvoices = 50
|
|
||||||
for i := lnwire.MilliSatoshi(1); i <= numInvoices; i++ {
|
|
||||||
invoice, err := randInvoice(i)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create invoice: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
paymentHash := invoice.Terms.PaymentPreimage.Hash()
|
|
||||||
|
|
||||||
if _, err := db.AddInvoice(invoice, paymentHash); err != nil {
|
|
||||||
t.Fatalf("unable to add invoice: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// We'll only settle half of all invoices created.
|
|
||||||
if i%2 == 0 {
|
|
||||||
_, err := db.UpdateInvoice(
|
|
||||||
paymentHash, getUpdateInvoice(i),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to settle invoice: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// We'll then retrieve the set of all invoices and pending invoices.
|
|
||||||
// This will serve useful when comparing the expected responses of the
|
|
||||||
// query with the actual ones.
|
|
||||||
invoices, err := db.FetchAllInvoices(false)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to retrieve invoices: %v", err)
|
|
||||||
}
|
|
||||||
pendingInvoices, err := db.FetchAllInvoices(true)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to retrieve pending invoices: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// The test will consist of several queries along with their respective
|
|
||||||
// expected response. Each query response should match its expected one.
|
|
||||||
testCases := []struct {
|
|
||||||
query InvoiceQuery
|
|
||||||
expected []Invoice
|
|
||||||
}{
|
|
||||||
// Fetch all invoices with a single query.
|
|
||||||
{
|
|
||||||
query: InvoiceQuery{
|
|
||||||
NumMaxInvoices: numInvoices,
|
|
||||||
},
|
|
||||||
expected: invoices,
|
|
||||||
},
|
|
||||||
// Fetch all invoices with a single query, reversed.
|
|
||||||
{
|
|
||||||
query: InvoiceQuery{
|
|
||||||
Reversed: true,
|
|
||||||
NumMaxInvoices: numInvoices,
|
|
||||||
},
|
|
||||||
expected: invoices,
|
|
||||||
},
|
|
||||||
// Fetch the first 25 invoices.
|
|
||||||
{
|
|
||||||
query: InvoiceQuery{
|
|
||||||
NumMaxInvoices: numInvoices / 2,
|
|
||||||
},
|
|
||||||
expected: invoices[:numInvoices/2],
|
|
||||||
},
|
|
||||||
// Fetch the first 10 invoices, but this time iterating
|
|
||||||
// backwards.
|
|
||||||
{
|
|
||||||
query: InvoiceQuery{
|
|
||||||
IndexOffset: 11,
|
|
||||||
Reversed: true,
|
|
||||||
NumMaxInvoices: numInvoices,
|
|
||||||
},
|
|
||||||
expected: invoices[:10],
|
|
||||||
},
|
|
||||||
// Fetch the last 40 invoices.
|
|
||||||
{
|
|
||||||
query: InvoiceQuery{
|
|
||||||
IndexOffset: 10,
|
|
||||||
NumMaxInvoices: numInvoices,
|
|
||||||
},
|
|
||||||
expected: invoices[10:],
|
|
||||||
},
|
|
||||||
// Fetch all but the first invoice.
|
|
||||||
{
|
|
||||||
query: InvoiceQuery{
|
|
||||||
IndexOffset: 1,
|
|
||||||
NumMaxInvoices: numInvoices,
|
|
||||||
},
|
|
||||||
expected: invoices[1:],
|
|
||||||
},
|
|
||||||
// Fetch one invoice, reversed, with index offset 3. This
|
|
||||||
// should give us the second invoice in the array.
|
|
||||||
{
|
|
||||||
query: InvoiceQuery{
|
|
||||||
IndexOffset: 3,
|
|
||||||
Reversed: true,
|
|
||||||
NumMaxInvoices: 1,
|
|
||||||
},
|
|
||||||
expected: invoices[1:2],
|
|
||||||
},
|
|
||||||
// Same as above, at index 2.
|
|
||||||
{
|
|
||||||
query: InvoiceQuery{
|
|
||||||
IndexOffset: 2,
|
|
||||||
Reversed: true,
|
|
||||||
NumMaxInvoices: 1,
|
|
||||||
},
|
|
||||||
expected: invoices[0:1],
|
|
||||||
},
|
|
||||||
// Fetch one invoice, at index 1, reversed. Since invoice#1 is
|
|
||||||
// the very first, there won't be any left in a reverse search,
|
|
||||||
// so we expect no invoices to be returned.
|
|
||||||
{
|
|
||||||
query: InvoiceQuery{
|
|
||||||
IndexOffset: 1,
|
|
||||||
Reversed: true,
|
|
||||||
NumMaxInvoices: 1,
|
|
||||||
},
|
|
||||||
expected: nil,
|
|
||||||
},
|
|
||||||
// Same as above, but don't restrict the number of invoices to
|
|
||||||
// 1.
|
|
||||||
{
|
|
||||||
query: InvoiceQuery{
|
|
||||||
IndexOffset: 1,
|
|
||||||
Reversed: true,
|
|
||||||
NumMaxInvoices: numInvoices,
|
|
||||||
},
|
|
||||||
expected: nil,
|
|
||||||
},
|
|
||||||
// Fetch one invoice, reversed, with no offset set. We expect
|
|
||||||
// the last invoice in the response.
|
|
||||||
{
|
|
||||||
query: InvoiceQuery{
|
|
||||||
Reversed: true,
|
|
||||||
NumMaxInvoices: 1,
|
|
||||||
},
|
|
||||||
expected: invoices[numInvoices-1:],
|
|
||||||
},
|
|
||||||
// Fetch one invoice, reversed, the offset set at numInvoices+1.
|
|
||||||
// We expect this to return the last invoice.
|
|
||||||
{
|
|
||||||
query: InvoiceQuery{
|
|
||||||
IndexOffset: numInvoices + 1,
|
|
||||||
Reversed: true,
|
|
||||||
NumMaxInvoices: 1,
|
|
||||||
},
|
|
||||||
expected: invoices[numInvoices-1:],
|
|
||||||
},
|
|
||||||
// Same as above, at offset numInvoices.
|
|
||||||
{
|
|
||||||
query: InvoiceQuery{
|
|
||||||
IndexOffset: numInvoices,
|
|
||||||
Reversed: true,
|
|
||||||
NumMaxInvoices: 1,
|
|
||||||
},
|
|
||||||
expected: invoices[numInvoices-2 : numInvoices-1],
|
|
||||||
},
|
|
||||||
// Fetch one invoice, at no offset (same as offset 0). We
|
|
||||||
// expect the first invoice only in the response.
|
|
||||||
{
|
|
||||||
query: InvoiceQuery{
|
|
||||||
NumMaxInvoices: 1,
|
|
||||||
},
|
|
||||||
expected: invoices[:1],
|
|
||||||
},
|
|
||||||
// Same as above, at offset 1.
|
|
||||||
{
|
|
||||||
query: InvoiceQuery{
|
|
||||||
IndexOffset: 1,
|
|
||||||
NumMaxInvoices: 1,
|
|
||||||
},
|
|
||||||
expected: invoices[1:2],
|
|
||||||
},
|
|
||||||
// Same as above, at offset 2.
|
|
||||||
{
|
|
||||||
query: InvoiceQuery{
|
|
||||||
IndexOffset: 2,
|
|
||||||
NumMaxInvoices: 1,
|
|
||||||
},
|
|
||||||
expected: invoices[2:3],
|
|
||||||
},
|
|
||||||
// Same as above, at offset numInvoices-1. Expect the last
|
|
||||||
// invoice to be returned.
|
|
||||||
{
|
|
||||||
query: InvoiceQuery{
|
|
||||||
IndexOffset: numInvoices - 1,
|
|
||||||
NumMaxInvoices: 1,
|
|
||||||
},
|
|
||||||
expected: invoices[numInvoices-1:],
|
|
||||||
},
|
|
||||||
// Same as above, at offset numInvoices. No invoices should be
|
|
||||||
// returned, as there are no invoices after this offset.
|
|
||||||
{
|
|
||||||
query: InvoiceQuery{
|
|
||||||
IndexOffset: numInvoices,
|
|
||||||
NumMaxInvoices: 1,
|
|
||||||
},
|
|
||||||
expected: nil,
|
|
||||||
},
|
|
||||||
// Fetch all pending invoices with a single query.
|
|
||||||
{
|
|
||||||
query: InvoiceQuery{
|
|
||||||
PendingOnly: true,
|
|
||||||
NumMaxInvoices: numInvoices,
|
|
||||||
},
|
|
||||||
expected: pendingInvoices,
|
|
||||||
},
|
|
||||||
// Fetch the first 12 pending invoices.
|
|
||||||
{
|
|
||||||
query: InvoiceQuery{
|
|
||||||
PendingOnly: true,
|
|
||||||
NumMaxInvoices: numInvoices / 4,
|
|
||||||
},
|
|
||||||
expected: pendingInvoices[:len(pendingInvoices)/2],
|
|
||||||
},
|
|
||||||
// Fetch the first 5 pending invoices, but this time iterating
|
|
||||||
// backwards.
|
|
||||||
{
|
|
||||||
query: InvoiceQuery{
|
|
||||||
IndexOffset: 10,
|
|
||||||
PendingOnly: true,
|
|
||||||
Reversed: true,
|
|
||||||
NumMaxInvoices: numInvoices,
|
|
||||||
},
|
|
||||||
// Since we seek to the invoice with index 10 and
|
|
||||||
// iterate backwards, there should only be 5 pending
|
|
||||||
// invoices before it as every other invoice within the
|
|
||||||
// index is settled.
|
|
||||||
expected: pendingInvoices[:5],
|
|
||||||
},
|
|
||||||
// Fetch the last 15 invoices.
|
|
||||||
{
|
|
||||||
query: InvoiceQuery{
|
|
||||||
IndexOffset: 20,
|
|
||||||
PendingOnly: true,
|
|
||||||
NumMaxInvoices: numInvoices,
|
|
||||||
},
|
|
||||||
// Since we seek to the invoice with index 20, there are
|
|
||||||
// 30 invoices left. From these 30, only 15 of them are
|
|
||||||
// still pending.
|
|
||||||
expected: pendingInvoices[len(pendingInvoices)-15:],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, testCase := range testCases {
|
|
||||||
response, err := db.QueryInvoices(testCase.query)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to query invoice database: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !reflect.DeepEqual(response.Invoices, testCase.expected) {
|
|
||||||
t.Fatalf("test #%d: query returned incorrect set of "+
|
|
||||||
"invoices: expcted %v, got %v", i,
|
|
||||||
spew.Sdump(response.Invoices),
|
|
||||||
spew.Sdump(testCase.expected))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// getUpdateInvoice returns an invoice update callback that, when called,
|
|
||||||
// settles the invoice with the given amount.
|
|
||||||
func getUpdateInvoice(amt lnwire.MilliSatoshi) InvoiceUpdateCallback {
|
|
||||||
return func(invoice *Invoice) (*InvoiceUpdateDesc, error) {
|
|
||||||
if invoice.Terms.State == ContractSettled {
|
|
||||||
return nil, ErrInvoiceAlreadySettled
|
|
||||||
}
|
|
||||||
|
|
||||||
update := &InvoiceUpdateDesc{
|
|
||||||
Preimage: invoice.Terms.PaymentPreimage,
|
|
||||||
State: ContractSettled,
|
|
||||||
Htlcs: map[CircuitKey]*HtlcAcceptDesc{
|
|
||||||
{}: {
|
|
||||||
Amt: amt,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
return update, nil
|
|
||||||
}
|
|
||||||
}
|
|
@ -3,7 +3,6 @@ package migration_01_to_11
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"time"
|
"time"
|
||||||
@ -16,9 +15,6 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// UnknownPreimage is an all-zeroes preimage that indicates that the
|
|
||||||
// preimage for this invoice is not yet known.
|
|
||||||
UnknownPreimage lntypes.Preimage
|
|
||||||
|
|
||||||
// invoiceBucket is the name of the bucket within the database that
|
// invoiceBucket is the name of the bucket within the database that
|
||||||
// stores all data related to invoices no matter their final state.
|
// stores all data related to invoices no matter their final state.
|
||||||
@ -26,23 +22,6 @@ var (
|
|||||||
// which is a monotonically increasing uint32.
|
// which is a monotonically increasing uint32.
|
||||||
invoiceBucket = []byte("invoices")
|
invoiceBucket = []byte("invoices")
|
||||||
|
|
||||||
// paymentHashIndexBucket is the name of the sub-bucket within the
|
|
||||||
// invoiceBucket which indexes all invoices by their payment hash. The
|
|
||||||
// payment hash is the sha256 of the invoice's payment preimage. This
|
|
||||||
// index is used to detect duplicates, and also to provide a fast path
|
|
||||||
// for looking up incoming HTLCs to determine if we're able to settle
|
|
||||||
// them fully.
|
|
||||||
//
|
|
||||||
// maps: payHash => invoiceKey
|
|
||||||
invoiceIndexBucket = []byte("paymenthashes")
|
|
||||||
|
|
||||||
// numInvoicesKey is the name of key which houses the auto-incrementing
|
|
||||||
// invoice ID which is essentially used as a primary key. With each
|
|
||||||
// invoice inserted, the primary key is incremented by one. This key is
|
|
||||||
// stored within the invoiceIndexBucket. Within the invoiceBucket
|
|
||||||
// invoices are uniquely identified by the invoice ID.
|
|
||||||
numInvoicesKey = []byte("nik")
|
|
||||||
|
|
||||||
// addIndexBucket is an index bucket that we'll use to create a
|
// addIndexBucket is an index bucket that we'll use to create a
|
||||||
// monotonically increasing set of add indexes. Each time we add a new
|
// monotonically increasing set of add indexes. Each time we add a new
|
||||||
// invoice, this sequence number will be incremented and then populated
|
// invoice, this sequence number will be incremented and then populated
|
||||||
@ -62,21 +41,6 @@ var (
|
|||||||
//
|
//
|
||||||
// settleIndexNo => invoiceKey
|
// settleIndexNo => invoiceKey
|
||||||
settleIndexBucket = []byte("invoice-settle-index")
|
settleIndexBucket = []byte("invoice-settle-index")
|
||||||
|
|
||||||
// ErrInvoiceAlreadySettled is returned when the invoice is already
|
|
||||||
// settled.
|
|
||||||
ErrInvoiceAlreadySettled = errors.New("invoice already settled")
|
|
||||||
|
|
||||||
// ErrInvoiceAlreadyCanceled is returned when the invoice is already
|
|
||||||
// canceled.
|
|
||||||
ErrInvoiceAlreadyCanceled = errors.New("invoice already canceled")
|
|
||||||
|
|
||||||
// ErrInvoiceAlreadyAccepted is returned when the invoice is already
|
|
||||||
// accepted.
|
|
||||||
ErrInvoiceAlreadyAccepted = errors.New("invoice already accepted")
|
|
||||||
|
|
||||||
// ErrInvoiceStillOpen is returned when the invoice is still open.
|
|
||||||
ErrInvoiceStillOpen = errors.New("invoice still open")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -237,18 +201,6 @@ type Invoice struct {
|
|||||||
// HtlcState defines the states an htlc paying to an invoice can be in.
|
// HtlcState defines the states an htlc paying to an invoice can be in.
|
||||||
type HtlcState uint8
|
type HtlcState uint8
|
||||||
|
|
||||||
const (
|
|
||||||
// HtlcStateAccepted indicates the htlc is locked-in, but not resolved.
|
|
||||||
HtlcStateAccepted HtlcState = iota
|
|
||||||
|
|
||||||
// HtlcStateCanceled indicates the htlc is canceled back to the
|
|
||||||
// sender.
|
|
||||||
HtlcStateCanceled
|
|
||||||
|
|
||||||
// HtlcStateSettled indicates the htlc is settled.
|
|
||||||
HtlcStateSettled
|
|
||||||
)
|
|
||||||
|
|
||||||
// InvoiceHTLC contains details about an htlc paying to this invoice.
|
// InvoiceHTLC contains details about an htlc paying to this invoice.
|
||||||
type InvoiceHTLC struct {
|
type InvoiceHTLC struct {
|
||||||
// Amt is the amount that is carried by this htlc.
|
// Amt is the amount that is carried by this htlc.
|
||||||
@ -276,37 +228,6 @@ type InvoiceHTLC struct {
|
|||||||
State HtlcState
|
State HtlcState
|
||||||
}
|
}
|
||||||
|
|
||||||
// HtlcAcceptDesc describes the details of a newly accepted htlc.
|
|
||||||
type HtlcAcceptDesc struct {
|
|
||||||
// AcceptHeight is the block height at which this htlc was accepted.
|
|
||||||
AcceptHeight int32
|
|
||||||
|
|
||||||
// Amt is the amount that is carried by this htlc.
|
|
||||||
Amt lnwire.MilliSatoshi
|
|
||||||
|
|
||||||
// Expiry is the expiry height of this htlc.
|
|
||||||
Expiry uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
// InvoiceUpdateDesc describes the changes that should be applied to the
|
|
||||||
// invoice.
|
|
||||||
type InvoiceUpdateDesc struct {
|
|
||||||
// State is the new state that this invoice should progress to.
|
|
||||||
State ContractState
|
|
||||||
|
|
||||||
// Htlcs describes the changes that need to be made to the invoice htlcs
|
|
||||||
// in the database. Htlc map entries with their value set should be
|
|
||||||
// added. If the map value is nil, the htlc should be canceled.
|
|
||||||
Htlcs map[CircuitKey]*HtlcAcceptDesc
|
|
||||||
|
|
||||||
// Preimage must be set to the preimage when state is settled.
|
|
||||||
Preimage lntypes.Preimage
|
|
||||||
}
|
|
||||||
|
|
||||||
// InvoiceUpdateCallback is a callback used in the db transaction to update the
|
|
||||||
// invoice.
|
|
||||||
type InvoiceUpdateCallback = func(invoice *Invoice) (*InvoiceUpdateDesc, error)
|
|
||||||
|
|
||||||
func validateInvoice(i *Invoice) error {
|
func validateInvoice(i *Invoice) error {
|
||||||
if len(i.Memo) > MaxMemoSize {
|
if len(i.Memo) > MaxMemoSize {
|
||||||
return fmt.Errorf("max length a memo is %v, and invoice "+
|
return fmt.Errorf("max length a memo is %v, and invoice "+
|
||||||
@ -325,186 +246,6 @@ func validateInvoice(i *Invoice) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddInvoice inserts the targeted invoice into the database. If the invoice has
|
|
||||||
// *any* payment hashes which already exists within the database, then the
|
|
||||||
// insertion will be aborted and rejected due to the strict policy banning any
|
|
||||||
// duplicate payment hashes. A side effect of this function is that it sets
|
|
||||||
// AddIndex on newInvoice.
|
|
||||||
func (d *DB) AddInvoice(newInvoice *Invoice, paymentHash lntypes.Hash) (
|
|
||||||
uint64, error) {
|
|
||||||
|
|
||||||
if err := validateInvoice(newInvoice); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var invoiceAddIndex uint64
|
|
||||||
err := d.Update(func(tx *bbolt.Tx) error {
|
|
||||||
invoices, err := tx.CreateBucketIfNotExists(invoiceBucket)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
invoiceIndex, err := invoices.CreateBucketIfNotExists(
|
|
||||||
invoiceIndexBucket,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
addIndex, err := invoices.CreateBucketIfNotExists(
|
|
||||||
addIndexBucket,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure that an invoice an identical payment hash doesn't
|
|
||||||
// already exist within the index.
|
|
||||||
if invoiceIndex.Get(paymentHash[:]) != nil {
|
|
||||||
return ErrDuplicateInvoice
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the current running payment ID counter hasn't yet been
|
|
||||||
// created, then create it now.
|
|
||||||
var invoiceNum uint32
|
|
||||||
invoiceCounter := invoiceIndex.Get(numInvoicesKey)
|
|
||||||
if invoiceCounter == nil {
|
|
||||||
var scratch [4]byte
|
|
||||||
byteOrder.PutUint32(scratch[:], invoiceNum)
|
|
||||||
err := invoiceIndex.Put(numInvoicesKey, scratch[:])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
invoiceNum = byteOrder.Uint32(invoiceCounter)
|
|
||||||
}
|
|
||||||
|
|
||||||
newIndex, err := putInvoice(
|
|
||||||
invoices, invoiceIndex, addIndex, newInvoice, invoiceNum,
|
|
||||||
paymentHash,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
invoiceAddIndex = newIndex
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return invoiceAddIndex, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// InvoicesAddedSince can be used by callers to seek into the event time series
|
|
||||||
// of all the invoices added in the database. The specified sinceAddIndex
|
|
||||||
// should be the highest add index that the caller knows of. This method will
|
|
||||||
// return all invoices with an add index greater than the specified
|
|
||||||
// sinceAddIndex.
|
|
||||||
//
|
|
||||||
// NOTE: The index starts from 1, as a result. We enforce that specifying a
|
|
||||||
// value below the starting index value is a noop.
|
|
||||||
func (d *DB) InvoicesAddedSince(sinceAddIndex uint64) ([]Invoice, error) {
|
|
||||||
var newInvoices []Invoice
|
|
||||||
|
|
||||||
// If an index of zero was specified, then in order to maintain
|
|
||||||
// backwards compat, we won't send out any new invoices.
|
|
||||||
if sinceAddIndex == 0 {
|
|
||||||
return newInvoices, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var startIndex [8]byte
|
|
||||||
byteOrder.PutUint64(startIndex[:], sinceAddIndex)
|
|
||||||
|
|
||||||
err := d.DB.View(func(tx *bbolt.Tx) error {
|
|
||||||
invoices := tx.Bucket(invoiceBucket)
|
|
||||||
if invoices == nil {
|
|
||||||
return ErrNoInvoicesCreated
|
|
||||||
}
|
|
||||||
|
|
||||||
addIndex := invoices.Bucket(addIndexBucket)
|
|
||||||
if addIndex == nil {
|
|
||||||
return ErrNoInvoicesCreated
|
|
||||||
}
|
|
||||||
|
|
||||||
// We'll now run through each entry in the add index starting
|
|
||||||
// at our starting index. We'll continue until we reach the
|
|
||||||
// very end of the current key space.
|
|
||||||
invoiceCursor := addIndex.Cursor()
|
|
||||||
|
|
||||||
// We'll seek to the starting index, then manually advance the
|
|
||||||
// cursor in order to skip the entry with the since add index.
|
|
||||||
invoiceCursor.Seek(startIndex[:])
|
|
||||||
addSeqNo, invoiceKey := invoiceCursor.Next()
|
|
||||||
|
|
||||||
for ; addSeqNo != nil && bytes.Compare(addSeqNo, startIndex[:]) > 0; addSeqNo, invoiceKey = invoiceCursor.Next() {
|
|
||||||
|
|
||||||
// For each key found, we'll look up the actual
|
|
||||||
// invoice, then accumulate it into our return value.
|
|
||||||
invoice, err := fetchInvoice(invoiceKey, invoices)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
newInvoices = append(newInvoices, invoice)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
switch {
|
|
||||||
// If no invoices have been created, then we'll return the empty set of
|
|
||||||
// invoices.
|
|
||||||
case err == ErrNoInvoicesCreated:
|
|
||||||
|
|
||||||
case err != nil:
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return newInvoices, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// LookupInvoice attempts to look up an invoice according to its 32 byte
|
|
||||||
// payment hash. If an invoice which can settle the HTLC identified by the
|
|
||||||
// passed payment hash isn't found, then an error is returned. Otherwise, the
|
|
||||||
// full invoice is returned. Before setting the incoming HTLC, the values
|
|
||||||
// SHOULD be checked to ensure the payer meets the agreed upon contractual
|
|
||||||
// terms of the payment.
|
|
||||||
func (d *DB) LookupInvoice(paymentHash [32]byte) (Invoice, error) {
|
|
||||||
var invoice Invoice
|
|
||||||
err := d.View(func(tx *bbolt.Tx) error {
|
|
||||||
invoices := tx.Bucket(invoiceBucket)
|
|
||||||
if invoices == nil {
|
|
||||||
return ErrNoInvoicesCreated
|
|
||||||
}
|
|
||||||
invoiceIndex := invoices.Bucket(invoiceIndexBucket)
|
|
||||||
if invoiceIndex == nil {
|
|
||||||
return ErrNoInvoicesCreated
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check the invoice index to see if an invoice paying to this
|
|
||||||
// hash exists within the DB.
|
|
||||||
invoiceNum := invoiceIndex.Get(paymentHash[:])
|
|
||||||
if invoiceNum == nil {
|
|
||||||
return ErrInvoiceNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
// An invoice matching the payment hash has been found, so
|
|
||||||
// retrieve the record of the invoice itself.
|
|
||||||
i, err := fetchInvoice(invoiceNum, invoices)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
invoice = i
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return invoice, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return invoice, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// FetchAllInvoices returns all invoices currently stored within the database.
|
// FetchAllInvoices returns all invoices currently stored within the database.
|
||||||
// If the pendingOnly param is true, then only unsettled invoices will be
|
// If the pendingOnly param is true, then only unsettled invoices will be
|
||||||
// returned, skipping all invoices that are fully settled.
|
// returned, skipping all invoices that are fully settled.
|
||||||
@ -549,343 +290,6 @@ func (d *DB) FetchAllInvoices(pendingOnly bool) ([]Invoice, error) {
|
|||||||
return invoices, nil
|
return invoices, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// InvoiceQuery represents a query to the invoice database. The query allows a
|
|
||||||
// caller to retrieve all invoices starting from a particular add index and
|
|
||||||
// limit the number of results returned.
|
|
||||||
type InvoiceQuery struct {
|
|
||||||
// IndexOffset is the offset within the add indices to start at. This
|
|
||||||
// can be used to start the response at a particular invoice.
|
|
||||||
IndexOffset uint64
|
|
||||||
|
|
||||||
// NumMaxInvoices is the maximum number of invoices that should be
|
|
||||||
// starting from the add index.
|
|
||||||
NumMaxInvoices uint64
|
|
||||||
|
|
||||||
// PendingOnly, if set, returns unsettled invoices starting from the
|
|
||||||
// add index.
|
|
||||||
PendingOnly bool
|
|
||||||
|
|
||||||
// Reversed, if set, indicates that the invoices returned should start
|
|
||||||
// from the IndexOffset and go backwards.
|
|
||||||
Reversed bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// InvoiceSlice is the response to a invoice query. It includes the original
|
|
||||||
// query, the set of invoices that match the query, and an integer which
|
|
||||||
// represents the offset index of the last item in the set of returned invoices.
|
|
||||||
// This integer allows callers to resume their query using this offset in the
|
|
||||||
// event that the query's response exceeds the maximum number of returnable
|
|
||||||
// invoices.
|
|
||||||
type InvoiceSlice struct {
|
|
||||||
InvoiceQuery
|
|
||||||
|
|
||||||
// Invoices is the set of invoices that matched the query above.
|
|
||||||
Invoices []Invoice
|
|
||||||
|
|
||||||
// FirstIndexOffset is the index of the first element in the set of
|
|
||||||
// returned Invoices above. Callers can use this to resume their query
|
|
||||||
// in the event that the slice has too many events to fit into a single
|
|
||||||
// response.
|
|
||||||
FirstIndexOffset uint64
|
|
||||||
|
|
||||||
// LastIndexOffset is the index of the last element in the set of
|
|
||||||
// returned Invoices above. Callers can use this to resume their query
|
|
||||||
// in the event that the slice has too many events to fit into a single
|
|
||||||
// response.
|
|
||||||
LastIndexOffset uint64
|
|
||||||
}
|
|
||||||
|
|
||||||
// QueryInvoices allows a caller to query the invoice database for invoices
|
|
||||||
// within the specified add index range.
|
|
||||||
func (d *DB) QueryInvoices(q InvoiceQuery) (InvoiceSlice, error) {
|
|
||||||
resp := InvoiceSlice{
|
|
||||||
InvoiceQuery: q,
|
|
||||||
}
|
|
||||||
|
|
||||||
err := d.View(func(tx *bbolt.Tx) error {
|
|
||||||
// If the bucket wasn't found, then there aren't any invoices
|
|
||||||
// within the database yet, so we can simply exit.
|
|
||||||
invoices := tx.Bucket(invoiceBucket)
|
|
||||||
if invoices == nil {
|
|
||||||
return ErrNoInvoicesCreated
|
|
||||||
}
|
|
||||||
invoiceAddIndex := invoices.Bucket(addIndexBucket)
|
|
||||||
if invoiceAddIndex == nil {
|
|
||||||
return ErrNoInvoicesCreated
|
|
||||||
}
|
|
||||||
|
|
||||||
// keyForIndex is a helper closure that retrieves the invoice
|
|
||||||
// key for the given add index of an invoice.
|
|
||||||
keyForIndex := func(c *bbolt.Cursor, index uint64) []byte {
|
|
||||||
var keyIndex [8]byte
|
|
||||||
byteOrder.PutUint64(keyIndex[:], index)
|
|
||||||
_, invoiceKey := c.Seek(keyIndex[:])
|
|
||||||
return invoiceKey
|
|
||||||
}
|
|
||||||
|
|
||||||
// nextKey is a helper closure to determine what the next
|
|
||||||
// invoice key is when iterating over the invoice add index.
|
|
||||||
nextKey := func(c *bbolt.Cursor) ([]byte, []byte) {
|
|
||||||
if q.Reversed {
|
|
||||||
return c.Prev()
|
|
||||||
}
|
|
||||||
return c.Next()
|
|
||||||
}
|
|
||||||
|
|
||||||
// We'll be using a cursor to seek into the database and return
|
|
||||||
// a slice of invoices. We'll need to determine where to start
|
|
||||||
// our cursor depending on the parameters set within the query.
|
|
||||||
c := invoiceAddIndex.Cursor()
|
|
||||||
invoiceKey := keyForIndex(c, q.IndexOffset+1)
|
|
||||||
|
|
||||||
// If the query is specifying reverse iteration, then we must
|
|
||||||
// handle a few offset cases.
|
|
||||||
if q.Reversed {
|
|
||||||
switch q.IndexOffset {
|
|
||||||
|
|
||||||
// This indicates the default case, where no offset was
|
|
||||||
// specified. In that case we just start from the last
|
|
||||||
// invoice.
|
|
||||||
case 0:
|
|
||||||
_, invoiceKey = c.Last()
|
|
||||||
|
|
||||||
// This indicates the offset being set to the very
|
|
||||||
// first invoice. Since there are no invoices before
|
|
||||||
// this offset, and the direction is reversed, we can
|
|
||||||
// return without adding any invoices to the response.
|
|
||||||
case 1:
|
|
||||||
return nil
|
|
||||||
|
|
||||||
// Otherwise we start iteration at the invoice prior to
|
|
||||||
// the offset.
|
|
||||||
default:
|
|
||||||
invoiceKey = keyForIndex(c, q.IndexOffset-1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we know that a set of invoices exists, then we'll begin
|
|
||||||
// our seek through the bucket in order to satisfy the query.
|
|
||||||
// We'll continue until either we reach the end of the range, or
|
|
||||||
// reach our max number of invoices.
|
|
||||||
for ; invoiceKey != nil; _, invoiceKey = nextKey(c) {
|
|
||||||
// If our current return payload exceeds the max number
|
|
||||||
// of invoices, then we'll exit now.
|
|
||||||
if uint64(len(resp.Invoices)) >= q.NumMaxInvoices {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
invoice, err := fetchInvoice(invoiceKey, invoices)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Skip any settled invoices if the caller is only
|
|
||||||
// interested in unsettled.
|
|
||||||
if q.PendingOnly &&
|
|
||||||
invoice.Terms.State == ContractSettled {
|
|
||||||
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// At this point, we've exhausted the offset, so we'll
|
|
||||||
// begin collecting invoices found within the range.
|
|
||||||
resp.Invoices = append(resp.Invoices, invoice)
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we iterated through the add index in reverse order, then
|
|
||||||
// we'll need to reverse the slice of invoices to return them in
|
|
||||||
// forward order.
|
|
||||||
if q.Reversed {
|
|
||||||
numInvoices := len(resp.Invoices)
|
|
||||||
for i := 0; i < numInvoices/2; i++ {
|
|
||||||
opposite := numInvoices - i - 1
|
|
||||||
resp.Invoices[i], resp.Invoices[opposite] =
|
|
||||||
resp.Invoices[opposite], resp.Invoices[i]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil && err != ErrNoInvoicesCreated {
|
|
||||||
return resp, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Finally, record the indexes of the first and last invoices returned
|
|
||||||
// so that the caller can resume from this point later on.
|
|
||||||
if len(resp.Invoices) > 0 {
|
|
||||||
resp.FirstIndexOffset = resp.Invoices[0].AddIndex
|
|
||||||
resp.LastIndexOffset = resp.Invoices[len(resp.Invoices)-1].AddIndex
|
|
||||||
}
|
|
||||||
|
|
||||||
return resp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateInvoice attempts to update an invoice corresponding to the passed
|
|
||||||
// payment hash. If an invoice matching the passed payment hash doesn't exist
|
|
||||||
// within the database, then the action will fail with a "not found" error.
|
|
||||||
//
|
|
||||||
// The update is performed inside the same database transaction that fetches the
|
|
||||||
// invoice and is therefore atomic. The fields to update are controlled by the
|
|
||||||
// supplied callback.
|
|
||||||
func (d *DB) UpdateInvoice(paymentHash lntypes.Hash,
|
|
||||||
callback InvoiceUpdateCallback) (*Invoice, error) {
|
|
||||||
|
|
||||||
var updatedInvoice *Invoice
|
|
||||||
err := d.Update(func(tx *bbolt.Tx) error {
|
|
||||||
invoices, err := tx.CreateBucketIfNotExists(invoiceBucket)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
invoiceIndex, err := invoices.CreateBucketIfNotExists(
|
|
||||||
invoiceIndexBucket,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
settleIndex, err := invoices.CreateBucketIfNotExists(
|
|
||||||
settleIndexBucket,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check the invoice index to see if an invoice paying to this
|
|
||||||
// hash exists within the DB.
|
|
||||||
invoiceNum := invoiceIndex.Get(paymentHash[:])
|
|
||||||
if invoiceNum == nil {
|
|
||||||
return ErrInvoiceNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
updatedInvoice, err = d.updateInvoice(
|
|
||||||
paymentHash, invoices, settleIndex, invoiceNum,
|
|
||||||
callback,
|
|
||||||
)
|
|
||||||
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
|
|
||||||
return updatedInvoice, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// InvoicesSettledSince can be used by callers to catch up any settled invoices
|
|
||||||
// they missed within the settled invoice time series. We'll return all known
|
|
||||||
// settled invoice that have a settle index higher than the passed
|
|
||||||
// sinceSettleIndex.
|
|
||||||
//
|
|
||||||
// NOTE: The index starts from 1, as a result. We enforce that specifying a
|
|
||||||
// value below the starting index value is a noop.
|
|
||||||
func (d *DB) InvoicesSettledSince(sinceSettleIndex uint64) ([]Invoice, error) {
|
|
||||||
var settledInvoices []Invoice
|
|
||||||
|
|
||||||
// If an index of zero was specified, then in order to maintain
|
|
||||||
// backwards compat, we won't send out any new invoices.
|
|
||||||
if sinceSettleIndex == 0 {
|
|
||||||
return settledInvoices, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var startIndex [8]byte
|
|
||||||
byteOrder.PutUint64(startIndex[:], sinceSettleIndex)
|
|
||||||
|
|
||||||
err := d.DB.View(func(tx *bbolt.Tx) error {
|
|
||||||
invoices := tx.Bucket(invoiceBucket)
|
|
||||||
if invoices == nil {
|
|
||||||
return ErrNoInvoicesCreated
|
|
||||||
}
|
|
||||||
|
|
||||||
settleIndex := invoices.Bucket(settleIndexBucket)
|
|
||||||
if settleIndex == nil {
|
|
||||||
return ErrNoInvoicesCreated
|
|
||||||
}
|
|
||||||
|
|
||||||
// We'll now run through each entry in the add index starting
|
|
||||||
// at our starting index. We'll continue until we reach the
|
|
||||||
// very end of the current key space.
|
|
||||||
invoiceCursor := settleIndex.Cursor()
|
|
||||||
|
|
||||||
// We'll seek to the starting index, then manually advance the
|
|
||||||
// cursor in order to skip the entry with the since add index.
|
|
||||||
invoiceCursor.Seek(startIndex[:])
|
|
||||||
seqNo, invoiceKey := invoiceCursor.Next()
|
|
||||||
|
|
||||||
for ; seqNo != nil && bytes.Compare(seqNo, startIndex[:]) > 0; seqNo, invoiceKey = invoiceCursor.Next() {
|
|
||||||
|
|
||||||
// For each key found, we'll look up the actual
|
|
||||||
// invoice, then accumulate it into our return value.
|
|
||||||
invoice, err := fetchInvoice(invoiceKey, invoices)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
settledInvoices = append(settledInvoices, invoice)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return settledInvoices, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func putInvoice(invoices, invoiceIndex, addIndex *bbolt.Bucket,
|
|
||||||
i *Invoice, invoiceNum uint32, paymentHash lntypes.Hash) (
|
|
||||||
uint64, error) {
|
|
||||||
|
|
||||||
// Create the invoice key which is just the big-endian representation
|
|
||||||
// of the invoice number.
|
|
||||||
var invoiceKey [4]byte
|
|
||||||
byteOrder.PutUint32(invoiceKey[:], invoiceNum)
|
|
||||||
|
|
||||||
// Increment the num invoice counter index so the next invoice bares
|
|
||||||
// the proper ID.
|
|
||||||
var scratch [4]byte
|
|
||||||
invoiceCounter := invoiceNum + 1
|
|
||||||
byteOrder.PutUint32(scratch[:], invoiceCounter)
|
|
||||||
if err := invoiceIndex.Put(numInvoicesKey, scratch[:]); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add the payment hash to the invoice index. This will let us quickly
|
|
||||||
// identify if we can settle an incoming payment, and also to possibly
|
|
||||||
// allow a single invoice to have multiple payment installations.
|
|
||||||
err := invoiceIndex.Put(paymentHash[:], invoiceKey[:])
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Next, we'll obtain the next add invoice index (sequence
|
|
||||||
// number), so we can properly place this invoice within this
|
|
||||||
// event stream.
|
|
||||||
nextAddSeqNo, err := addIndex.NextSequence()
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// With the next sequence obtained, we'll updating the event series in
|
|
||||||
// the add index bucket to map this current add counter to the index of
|
|
||||||
// this new invoice.
|
|
||||||
var seqNoBytes [8]byte
|
|
||||||
byteOrder.PutUint64(seqNoBytes[:], nextAddSeqNo)
|
|
||||||
if err := addIndex.Put(seqNoBytes[:], invoiceKey[:]); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
i.AddIndex = nextAddSeqNo
|
|
||||||
|
|
||||||
// Finally, serialize the invoice itself to be written to the disk.
|
|
||||||
var buf bytes.Buffer
|
|
||||||
if err := serializeInvoice(&buf, i); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := invoices.Put(invoiceKey[:], buf.Bytes()); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nextAddSeqNo, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// serializeInvoice serializes an invoice to a writer.
|
// serializeInvoice serializes an invoice to a writer.
|
||||||
//
|
//
|
||||||
// Note: this function is in use for a migration. Before making changes that
|
// Note: this function is in use for a migration. Before making changes that
|
||||||
@ -1006,17 +410,6 @@ func serializeHtlcs(w io.Writer, htlcs map[CircuitKey]*InvoiceHTLC) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func fetchInvoice(invoiceNum []byte, invoices *bbolt.Bucket) (Invoice, error) {
|
|
||||||
invoiceBytes := invoices.Get(invoiceNum)
|
|
||||||
if invoiceBytes == nil {
|
|
||||||
return Invoice{}, ErrInvoiceNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
invoiceReader := bytes.NewReader(invoiceBytes)
|
|
||||||
|
|
||||||
return deserializeInvoice(invoiceReader)
|
|
||||||
}
|
|
||||||
|
|
||||||
func deserializeInvoice(r io.Reader) (Invoice, error) {
|
func deserializeInvoice(r io.Reader) (Invoice, error) {
|
||||||
var err error
|
var err error
|
||||||
invoice := Invoice{}
|
invoice := Invoice{}
|
||||||
@ -1155,166 +548,3 @@ func deserializeHtlcs(r io.Reader) (map[CircuitKey]*InvoiceHTLC, error) {
|
|||||||
|
|
||||||
return htlcs, nil
|
return htlcs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// copySlice allocates a new slice and copies the source into it.
|
|
||||||
func copySlice(src []byte) []byte {
|
|
||||||
dest := make([]byte, len(src))
|
|
||||||
copy(dest, src)
|
|
||||||
return dest
|
|
||||||
}
|
|
||||||
|
|
||||||
// copyInvoice makes a deep copy of the supplied invoice.
|
|
||||||
func copyInvoice(src *Invoice) *Invoice {
|
|
||||||
dest := Invoice{
|
|
||||||
Memo: copySlice(src.Memo),
|
|
||||||
Receipt: copySlice(src.Receipt),
|
|
||||||
PaymentRequest: copySlice(src.PaymentRequest),
|
|
||||||
FinalCltvDelta: src.FinalCltvDelta,
|
|
||||||
CreationDate: src.CreationDate,
|
|
||||||
SettleDate: src.SettleDate,
|
|
||||||
Terms: src.Terms,
|
|
||||||
AddIndex: src.AddIndex,
|
|
||||||
SettleIndex: src.SettleIndex,
|
|
||||||
AmtPaid: src.AmtPaid,
|
|
||||||
Htlcs: make(
|
|
||||||
map[CircuitKey]*InvoiceHTLC, len(src.Htlcs),
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
for k, v := range src.Htlcs {
|
|
||||||
dest.Htlcs[k] = v
|
|
||||||
}
|
|
||||||
|
|
||||||
return &dest
|
|
||||||
}
|
|
||||||
|
|
||||||
// updateInvoice fetches the invoice, obtains the update descriptor from the
|
|
||||||
// callback and applies the updates in a single db transaction.
|
|
||||||
func (d *DB) updateInvoice(hash lntypes.Hash, invoices, settleIndex *bbolt.Bucket,
|
|
||||||
invoiceNum []byte, callback InvoiceUpdateCallback) (*Invoice, error) {
|
|
||||||
|
|
||||||
invoice, err := fetchInvoice(invoiceNum, invoices)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
preUpdateState := invoice.Terms.State
|
|
||||||
|
|
||||||
// Create deep copy to prevent any accidental modification in the
|
|
||||||
// callback.
|
|
||||||
copy := copyInvoice(&invoice)
|
|
||||||
|
|
||||||
// Call the callback and obtain the update descriptor.
|
|
||||||
update, err := callback(copy)
|
|
||||||
if err != nil {
|
|
||||||
return &invoice, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update invoice state.
|
|
||||||
invoice.Terms.State = update.State
|
|
||||||
|
|
||||||
now := d.now()
|
|
||||||
|
|
||||||
// Update htlc set.
|
|
||||||
for key, htlcUpdate := range update.Htlcs {
|
|
||||||
htlc, ok := invoice.Htlcs[key]
|
|
||||||
|
|
||||||
// No update means the htlc needs to be canceled.
|
|
||||||
if htlcUpdate == nil {
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("unknown htlc %v", key)
|
|
||||||
}
|
|
||||||
if htlc.State != HtlcStateAccepted {
|
|
||||||
return nil, fmt.Errorf("can only cancel " +
|
|
||||||
"accepted htlcs")
|
|
||||||
}
|
|
||||||
|
|
||||||
htlc.State = HtlcStateCanceled
|
|
||||||
htlc.ResolveTime = now
|
|
||||||
invoice.AmtPaid -= htlc.Amt
|
|
||||||
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add new htlc paying to the invoice.
|
|
||||||
if ok {
|
|
||||||
return nil, fmt.Errorf("htlc %v already exists", key)
|
|
||||||
}
|
|
||||||
htlc = &InvoiceHTLC{
|
|
||||||
Amt: htlcUpdate.Amt,
|
|
||||||
Expiry: htlcUpdate.Expiry,
|
|
||||||
AcceptHeight: uint32(htlcUpdate.AcceptHeight),
|
|
||||||
AcceptTime: now,
|
|
||||||
}
|
|
||||||
if preUpdateState == ContractSettled {
|
|
||||||
htlc.State = HtlcStateSettled
|
|
||||||
htlc.ResolveTime = now
|
|
||||||
} else {
|
|
||||||
htlc.State = HtlcStateAccepted
|
|
||||||
}
|
|
||||||
|
|
||||||
invoice.Htlcs[key] = htlc
|
|
||||||
invoice.AmtPaid += htlc.Amt
|
|
||||||
}
|
|
||||||
|
|
||||||
// If invoice moved to the settled state, update settle index and settle
|
|
||||||
// time.
|
|
||||||
if preUpdateState != invoice.Terms.State &&
|
|
||||||
invoice.Terms.State == ContractSettled {
|
|
||||||
|
|
||||||
if update.Preimage.Hash() != hash {
|
|
||||||
return nil, fmt.Errorf("preimage does not match")
|
|
||||||
}
|
|
||||||
invoice.Terms.PaymentPreimage = update.Preimage
|
|
||||||
|
|
||||||
// Settle all accepted htlcs.
|
|
||||||
for _, htlc := range invoice.Htlcs {
|
|
||||||
if htlc.State != HtlcStateAccepted {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
htlc.State = HtlcStateSettled
|
|
||||||
htlc.ResolveTime = now
|
|
||||||
}
|
|
||||||
|
|
||||||
err := setSettleFields(settleIndex, invoiceNum, &invoice, now)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var buf bytes.Buffer
|
|
||||||
if err := serializeInvoice(&buf, &invoice); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := invoices.Put(invoiceNum[:], buf.Bytes()); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &invoice, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func setSettleFields(settleIndex *bbolt.Bucket, invoiceNum []byte,
|
|
||||||
invoice *Invoice, now time.Time) error {
|
|
||||||
|
|
||||||
// Now that we know the invoice hasn't already been settled, we'll
|
|
||||||
// update the settle index so we can place this settle event in the
|
|
||||||
// proper location within our time series.
|
|
||||||
nextSettleSeqNo, err := settleIndex.NextSequence()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
var seqNoBytes [8]byte
|
|
||||||
byteOrder.PutUint64(seqNoBytes[:], nextSettleSeqNo)
|
|
||||||
if err := settleIndex.Put(seqNoBytes[:], invoiceNum); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
invoice.Terms.State = ContractSettled
|
|
||||||
invoice.SettleDate = now
|
|
||||||
invoice.SettleIndex = nextSettleSeqNo
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
@ -1,316 +0,0 @@
|
|||||||
package migration_01_to_11
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/btcsuite/btcd/btcec"
|
|
||||||
"github.com/btcsuite/btcd/wire"
|
|
||||||
"github.com/coreos/bbolt"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
// nodeInfoBucket stores metadata pertaining to nodes that we've had
|
|
||||||
// direct channel-based correspondence with. This bucket allows one to
|
|
||||||
// query for all open channels pertaining to the node by exploring each
|
|
||||||
// node's sub-bucket within the openChanBucket.
|
|
||||||
nodeInfoBucket = []byte("nib")
|
|
||||||
)
|
|
||||||
|
|
||||||
// LinkNode stores metadata related to node's that we have/had a direct
|
|
||||||
// channel open with. Information such as the Bitcoin network the node
|
|
||||||
// advertised, and its identity public key are also stored. Additionally, this
|
|
||||||
// struct and the bucket its stored within have store data similar to that of
|
|
||||||
// Bitcoin's addrmanager. The TCP address information stored within the struct
|
|
||||||
// can be used to establish persistent connections will all channel
|
|
||||||
// counterparties on daemon startup.
|
|
||||||
//
|
|
||||||
// TODO(roasbeef): also add current OnionKey plus rotation schedule?
|
|
||||||
// TODO(roasbeef): add bitfield for supported services
|
|
||||||
// * possibly add a wire.NetAddress type, type
|
|
||||||
type LinkNode struct {
|
|
||||||
// Network indicates the Bitcoin network that the LinkNode advertises
|
|
||||||
// for incoming channel creation.
|
|
||||||
Network wire.BitcoinNet
|
|
||||||
|
|
||||||
// IdentityPub is the node's current identity public key. Any
|
|
||||||
// channel/topology related information received by this node MUST be
|
|
||||||
// signed by this public key.
|
|
||||||
IdentityPub *btcec.PublicKey
|
|
||||||
|
|
||||||
// LastSeen tracks the last time this node was seen within the network.
|
|
||||||
// A node should be marked as seen if the daemon either is able to
|
|
||||||
// establish an outgoing connection to the node or receives a new
|
|
||||||
// incoming connection from the node. This timestamp (stored in unix
|
|
||||||
// epoch) may be used within a heuristic which aims to determine when a
|
|
||||||
// channel should be unilaterally closed due to inactivity.
|
|
||||||
//
|
|
||||||
// TODO(roasbeef): replace with block hash/height?
|
|
||||||
// * possibly add a time-value metric into the heuristic?
|
|
||||||
LastSeen time.Time
|
|
||||||
|
|
||||||
// Addresses is a list of IP address in which either we were able to
|
|
||||||
// reach the node over in the past, OR we received an incoming
|
|
||||||
// authenticated connection for the stored identity public key.
|
|
||||||
Addresses []net.Addr
|
|
||||||
|
|
||||||
db *DB
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewLinkNode creates a new LinkNode from the provided parameters, which is
|
|
||||||
// backed by an instance of channeldb.
|
|
||||||
func (db *DB) NewLinkNode(bitNet wire.BitcoinNet, pub *btcec.PublicKey,
|
|
||||||
addrs ...net.Addr) *LinkNode {
|
|
||||||
|
|
||||||
return &LinkNode{
|
|
||||||
Network: bitNet,
|
|
||||||
IdentityPub: pub,
|
|
||||||
LastSeen: time.Now(),
|
|
||||||
Addresses: addrs,
|
|
||||||
db: db,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateLastSeen updates the last time this node was directly encountered on
|
|
||||||
// the Lightning Network.
|
|
||||||
func (l *LinkNode) UpdateLastSeen(lastSeen time.Time) error {
|
|
||||||
l.LastSeen = lastSeen
|
|
||||||
|
|
||||||
return l.Sync()
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddAddress appends the specified TCP address to the list of known addresses
|
|
||||||
// this node is/was known to be reachable at.
|
|
||||||
func (l *LinkNode) AddAddress(addr net.Addr) error {
|
|
||||||
for _, a := range l.Addresses {
|
|
||||||
if a.String() == addr.String() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
l.Addresses = append(l.Addresses, addr)
|
|
||||||
|
|
||||||
return l.Sync()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sync performs a full database sync which writes the current up-to-date data
|
|
||||||
// within the struct to the database.
|
|
||||||
func (l *LinkNode) Sync() error {
|
|
||||||
|
|
||||||
// Finally update the database by storing the link node and updating
|
|
||||||
// any relevant indexes.
|
|
||||||
return l.db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
nodeMetaBucket := tx.Bucket(nodeInfoBucket)
|
|
||||||
if nodeMetaBucket == nil {
|
|
||||||
return ErrLinkNodesNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
return putLinkNode(nodeMetaBucket, l)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// putLinkNode serializes then writes the encoded version of the passed link
|
|
||||||
// node into the nodeMetaBucket. This function is provided in order to allow
|
|
||||||
// the ability to re-use a database transaction across many operations.
|
|
||||||
func putLinkNode(nodeMetaBucket *bbolt.Bucket, l *LinkNode) error {
|
|
||||||
// First serialize the LinkNode into its raw-bytes encoding.
|
|
||||||
var b bytes.Buffer
|
|
||||||
if err := serializeLinkNode(&b, l); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Finally insert the link-node into the node metadata bucket keyed
|
|
||||||
// according to the its pubkey serialized in compressed form.
|
|
||||||
nodePub := l.IdentityPub.SerializeCompressed()
|
|
||||||
return nodeMetaBucket.Put(nodePub, b.Bytes())
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteLinkNode removes the link node with the given identity from the
|
|
||||||
// database.
|
|
||||||
func (db *DB) DeleteLinkNode(identity *btcec.PublicKey) error {
|
|
||||||
return db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
return db.deleteLinkNode(tx, identity)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db *DB) deleteLinkNode(tx *bbolt.Tx, identity *btcec.PublicKey) error {
|
|
||||||
nodeMetaBucket := tx.Bucket(nodeInfoBucket)
|
|
||||||
if nodeMetaBucket == nil {
|
|
||||||
return ErrLinkNodesNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
pubKey := identity.SerializeCompressed()
|
|
||||||
return nodeMetaBucket.Delete(pubKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
// FetchLinkNode attempts to lookup the data for a LinkNode based on a target
|
|
||||||
// identity public key. If a particular LinkNode for the passed identity public
|
|
||||||
// key cannot be found, then ErrNodeNotFound if returned.
|
|
||||||
func (db *DB) FetchLinkNode(identity *btcec.PublicKey) (*LinkNode, error) {
|
|
||||||
var linkNode *LinkNode
|
|
||||||
err := db.View(func(tx *bbolt.Tx) error {
|
|
||||||
node, err := fetchLinkNode(tx, identity)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
linkNode = node
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
return linkNode, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func fetchLinkNode(tx *bbolt.Tx, targetPub *btcec.PublicKey) (*LinkNode, error) {
|
|
||||||
// First fetch the bucket for storing node metadata, bailing out early
|
|
||||||
// if it hasn't been created yet.
|
|
||||||
nodeMetaBucket := tx.Bucket(nodeInfoBucket)
|
|
||||||
if nodeMetaBucket == nil {
|
|
||||||
return nil, ErrLinkNodesNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
// If a link node for that particular public key cannot be located,
|
|
||||||
// then exit early with an ErrNodeNotFound.
|
|
||||||
pubKey := targetPub.SerializeCompressed()
|
|
||||||
nodeBytes := nodeMetaBucket.Get(pubKey)
|
|
||||||
if nodeBytes == nil {
|
|
||||||
return nil, ErrNodeNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
// Finally, decode and allocate a fresh LinkNode object to be returned
|
|
||||||
// to the caller.
|
|
||||||
nodeReader := bytes.NewReader(nodeBytes)
|
|
||||||
return deserializeLinkNode(nodeReader)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(roasbeef): update link node addrs in server upon connection
|
|
||||||
|
|
||||||
// FetchAllLinkNodes starts a new database transaction to fetch all nodes with
|
|
||||||
// whom we have active channels with.
|
|
||||||
func (db *DB) FetchAllLinkNodes() ([]*LinkNode, error) {
|
|
||||||
var linkNodes []*LinkNode
|
|
||||||
err := db.View(func(tx *bbolt.Tx) error {
|
|
||||||
nodes, err := db.fetchAllLinkNodes(tx)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
linkNodes = nodes
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return linkNodes, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// fetchAllLinkNodes uses an existing database transaction to fetch all nodes
|
|
||||||
// with whom we have active channels with.
|
|
||||||
func (db *DB) fetchAllLinkNodes(tx *bbolt.Tx) ([]*LinkNode, error) {
|
|
||||||
nodeMetaBucket := tx.Bucket(nodeInfoBucket)
|
|
||||||
if nodeMetaBucket == nil {
|
|
||||||
return nil, ErrLinkNodesNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
var linkNodes []*LinkNode
|
|
||||||
err := nodeMetaBucket.ForEach(func(k, v []byte) error {
|
|
||||||
if v == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
nodeReader := bytes.NewReader(v)
|
|
||||||
linkNode, err := deserializeLinkNode(nodeReader)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
linkNodes = append(linkNodes, linkNode)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return linkNodes, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func serializeLinkNode(w io.Writer, l *LinkNode) error {
|
|
||||||
var buf [8]byte
|
|
||||||
|
|
||||||
byteOrder.PutUint32(buf[:4], uint32(l.Network))
|
|
||||||
if _, err := w.Write(buf[:4]); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
serializedID := l.IdentityPub.SerializeCompressed()
|
|
||||||
if _, err := w.Write(serializedID); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
seenUnix := uint64(l.LastSeen.Unix())
|
|
||||||
byteOrder.PutUint64(buf[:], seenUnix)
|
|
||||||
if _, err := w.Write(buf[:]); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
numAddrs := uint32(len(l.Addresses))
|
|
||||||
byteOrder.PutUint32(buf[:4], numAddrs)
|
|
||||||
if _, err := w.Write(buf[:4]); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, addr := range l.Addresses {
|
|
||||||
if err := serializeAddr(w, addr); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func deserializeLinkNode(r io.Reader) (*LinkNode, error) {
|
|
||||||
var (
|
|
||||||
err error
|
|
||||||
buf [8]byte
|
|
||||||
)
|
|
||||||
|
|
||||||
node := &LinkNode{}
|
|
||||||
|
|
||||||
if _, err := io.ReadFull(r, buf[:4]); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
node.Network = wire.BitcoinNet(byteOrder.Uint32(buf[:4]))
|
|
||||||
|
|
||||||
var pub [33]byte
|
|
||||||
if _, err := io.ReadFull(r, pub[:]); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
node.IdentityPub, err = btcec.ParsePubKey(pub[:], btcec.S256())
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := io.ReadFull(r, buf[:]); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
node.LastSeen = time.Unix(int64(byteOrder.Uint64(buf[:])), 0)
|
|
||||||
|
|
||||||
if _, err := io.ReadFull(r, buf[:4]); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
numAddrs := byteOrder.Uint32(buf[:4])
|
|
||||||
|
|
||||||
node.Addresses = make([]net.Addr, numAddrs)
|
|
||||||
for i := uint32(0); i < numAddrs; i++ {
|
|
||||||
addr, err := deserializeAddr(r)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
node.Addresses[i] = addr
|
|
||||||
}
|
|
||||||
|
|
||||||
return node, nil
|
|
||||||
}
|
|
@ -1,140 +0,0 @@
|
|||||||
package migration_01_to_11
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"net"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/btcsuite/btcd/btcec"
|
|
||||||
"github.com/btcsuite/btcd/wire"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestLinkNodeEncodeDecode(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
cdb, cleanUp, err := makeTestDB()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test database: %v", err)
|
|
||||||
}
|
|
||||||
defer cleanUp()
|
|
||||||
|
|
||||||
// First we'll create some initial data to use for populating our test
|
|
||||||
// LinkNode instances.
|
|
||||||
_, pub1 := btcec.PrivKeyFromBytes(btcec.S256(), key[:])
|
|
||||||
_, pub2 := btcec.PrivKeyFromBytes(btcec.S256(), rev[:])
|
|
||||||
addr1, err := net.ResolveTCPAddr("tcp", "10.0.0.1:9000")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create test addr: %v", err)
|
|
||||||
}
|
|
||||||
addr2, err := net.ResolveTCPAddr("tcp", "10.0.0.2:9000")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create test addr: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create two fresh link node instances with the above dummy data, then
|
|
||||||
// fully sync both instances to disk.
|
|
||||||
node1 := cdb.NewLinkNode(wire.MainNet, pub1, addr1)
|
|
||||||
node2 := cdb.NewLinkNode(wire.TestNet3, pub2, addr2)
|
|
||||||
if err := node1.Sync(); err != nil {
|
|
||||||
t.Fatalf("unable to sync node: %v", err)
|
|
||||||
}
|
|
||||||
if err := node2.Sync(); err != nil {
|
|
||||||
t.Fatalf("unable to sync node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fetch all current link nodes from the database, they should exactly
|
|
||||||
// match the two created above.
|
|
||||||
originalNodes := []*LinkNode{node2, node1}
|
|
||||||
linkNodes, err := cdb.FetchAllLinkNodes()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to fetch nodes: %v", err)
|
|
||||||
}
|
|
||||||
for i, node := range linkNodes {
|
|
||||||
if originalNodes[i].Network != node.Network {
|
|
||||||
t.Fatalf("node networks don't match: expected %v, got %v",
|
|
||||||
originalNodes[i].Network, node.Network)
|
|
||||||
}
|
|
||||||
|
|
||||||
originalPubkey := originalNodes[i].IdentityPub.SerializeCompressed()
|
|
||||||
dbPubkey := node.IdentityPub.SerializeCompressed()
|
|
||||||
if !bytes.Equal(originalPubkey, dbPubkey) {
|
|
||||||
t.Fatalf("node pubkeys don't match: expected %x, got %x",
|
|
||||||
originalPubkey, dbPubkey)
|
|
||||||
}
|
|
||||||
if originalNodes[i].LastSeen.Unix() != node.LastSeen.Unix() {
|
|
||||||
t.Fatalf("last seen timestamps don't match: expected %v got %v",
|
|
||||||
originalNodes[i].LastSeen.Unix(), node.LastSeen.Unix())
|
|
||||||
}
|
|
||||||
if originalNodes[i].Addresses[0].String() != node.Addresses[0].String() {
|
|
||||||
t.Fatalf("addresses don't match: expected %v, got %v",
|
|
||||||
originalNodes[i].Addresses, node.Addresses)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Next, we'll exercise the methods to append additional IP
|
|
||||||
// addresses, and also to update the last seen time.
|
|
||||||
if err := node1.UpdateLastSeen(time.Now()); err != nil {
|
|
||||||
t.Fatalf("unable to update last seen: %v", err)
|
|
||||||
}
|
|
||||||
if err := node1.AddAddress(addr2); err != nil {
|
|
||||||
t.Fatalf("unable to update addr: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fetch the same node from the database according to its public key.
|
|
||||||
node1DB, err := cdb.FetchLinkNode(pub1)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to find node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Both the last seen timestamp and the list of reachable addresses for
|
|
||||||
// the node should be updated.
|
|
||||||
if node1DB.LastSeen.Unix() != node1.LastSeen.Unix() {
|
|
||||||
t.Fatalf("last seen timestamps don't match: expected %v got %v",
|
|
||||||
node1.LastSeen.Unix(), node1DB.LastSeen.Unix())
|
|
||||||
}
|
|
||||||
if len(node1DB.Addresses) != 2 {
|
|
||||||
t.Fatalf("wrong length for node1 addresses: expected %v, got %v",
|
|
||||||
2, len(node1DB.Addresses))
|
|
||||||
}
|
|
||||||
if node1DB.Addresses[0].String() != addr1.String() {
|
|
||||||
t.Fatalf("wrong address for node: expected %v, got %v",
|
|
||||||
addr1.String(), node1DB.Addresses[0].String())
|
|
||||||
}
|
|
||||||
if node1DB.Addresses[1].String() != addr2.String() {
|
|
||||||
t.Fatalf("wrong address for node: expected %v, got %v",
|
|
||||||
addr2.String(), node1DB.Addresses[1].String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDeleteLinkNode(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
cdb, cleanUp, err := makeTestDB()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test database: %v", err)
|
|
||||||
}
|
|
||||||
defer cleanUp()
|
|
||||||
|
|
||||||
_, pubKey := btcec.PrivKeyFromBytes(btcec.S256(), key[:])
|
|
||||||
addr := &net.TCPAddr{
|
|
||||||
IP: net.ParseIP("127.0.0.1"),
|
|
||||||
Port: 1337,
|
|
||||||
}
|
|
||||||
linkNode := cdb.NewLinkNode(wire.TestNet3, pubKey, addr)
|
|
||||||
if err := linkNode.Sync(); err != nil {
|
|
||||||
t.Fatalf("unable to write link node to db: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := cdb.FetchLinkNode(pubKey); err != nil {
|
|
||||||
t.Fatalf("unable to find link node: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := cdb.DeleteLinkNode(pubKey); err != nil {
|
|
||||||
t.Fatalf("unable to delete link node from db: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := cdb.FetchLinkNode(pubKey); err == nil {
|
|
||||||
t.Fatal("should not have found link node in db, but did")
|
|
||||||
}
|
|
||||||
}
|
|
@ -39,24 +39,3 @@ func DefaultOptions() Options {
|
|||||||
|
|
||||||
// OptionModifier is a function signature for modifying the default Options.
|
// OptionModifier is a function signature for modifying the default Options.
|
||||||
type OptionModifier func(*Options)
|
type OptionModifier func(*Options)
|
||||||
|
|
||||||
// OptionSetRejectCacheSize sets the RejectCacheSize to n.
|
|
||||||
func OptionSetRejectCacheSize(n int) OptionModifier {
|
|
||||||
return func(o *Options) {
|
|
||||||
o.RejectCacheSize = n
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// OptionSetChannelCacheSize sets the ChannelCacheSize to n.
|
|
||||||
func OptionSetChannelCacheSize(n int) OptionModifier {
|
|
||||||
return func(o *Options) {
|
|
||||||
o.ChannelCacheSize = n
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// OptionSetSyncFreelist allows the database to sync its freelist.
|
|
||||||
func OptionSetSyncFreelist(b bool) OptionModifier {
|
|
||||||
return func(o *Options) {
|
|
||||||
o.NoFreelistSync = !b
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -1,373 +1,9 @@
|
|||||||
package migration_01_to_11
|
package migration_01_to_11
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/coreos/bbolt"
|
"github.com/coreos/bbolt"
|
||||||
"github.com/lightningnetwork/lnd/lntypes"
|
|
||||||
"github.com/lightningnetwork/lnd/routing/route"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
// ErrAlreadyPaid signals we have already paid this payment hash.
|
|
||||||
ErrAlreadyPaid = errors.New("invoice is already paid")
|
|
||||||
|
|
||||||
// ErrPaymentInFlight signals that payment for this payment hash is
|
|
||||||
// already "in flight" on the network.
|
|
||||||
ErrPaymentInFlight = errors.New("payment is in transition")
|
|
||||||
|
|
||||||
// ErrPaymentNotInitiated is returned if payment wasn't initiated in
|
|
||||||
// switch.
|
|
||||||
ErrPaymentNotInitiated = errors.New("payment isn't initiated")
|
|
||||||
|
|
||||||
// ErrPaymentAlreadySucceeded is returned in the event we attempt to
|
|
||||||
// change the status of a payment already succeeded.
|
|
||||||
ErrPaymentAlreadySucceeded = errors.New("payment is already succeeded")
|
|
||||||
|
|
||||||
// ErrPaymentAlreadyFailed is returned in the event we attempt to
|
|
||||||
// re-fail a failed payment.
|
|
||||||
ErrPaymentAlreadyFailed = errors.New("payment has already failed")
|
|
||||||
|
|
||||||
// ErrUnknownPaymentStatus is returned when we do not recognize the
|
|
||||||
// existing state of a payment.
|
|
||||||
ErrUnknownPaymentStatus = errors.New("unknown payment status")
|
|
||||||
|
|
||||||
// errNoAttemptInfo is returned when no attempt info is stored yet.
|
|
||||||
errNoAttemptInfo = errors.New("unable to find attempt info for " +
|
|
||||||
"inflight payment")
|
|
||||||
)
|
|
||||||
|
|
||||||
// PaymentControl implements persistence for payments and payment attempts.
|
|
||||||
type PaymentControl struct {
|
|
||||||
db *DB
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewPaymentControl creates a new instance of the PaymentControl.
|
|
||||||
func NewPaymentControl(db *DB) *PaymentControl {
|
|
||||||
return &PaymentControl{
|
|
||||||
db: db,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// InitPayment checks or records the given PaymentCreationInfo with the DB,
|
|
||||||
// making sure it does not already exist as an in-flight payment. Then this
|
|
||||||
// method returns successfully, the payment is guranteeed to be in the InFlight
|
|
||||||
// state.
|
|
||||||
func (p *PaymentControl) InitPayment(paymentHash lntypes.Hash,
|
|
||||||
info *PaymentCreationInfo) error {
|
|
||||||
|
|
||||||
var b bytes.Buffer
|
|
||||||
if err := serializePaymentCreationInfo(&b, info); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
infoBytes := b.Bytes()
|
|
||||||
|
|
||||||
var updateErr error
|
|
||||||
err := p.db.Batch(func(tx *bbolt.Tx) error {
|
|
||||||
// Reset the update error, to avoid carrying over an error
|
|
||||||
// from a previous execution of the batched db transaction.
|
|
||||||
updateErr = nil
|
|
||||||
|
|
||||||
bucket, err := createPaymentBucket(tx, paymentHash)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the existing status of this payment, if any.
|
|
||||||
paymentStatus := fetchPaymentStatus(bucket)
|
|
||||||
|
|
||||||
switch paymentStatus {
|
|
||||||
|
|
||||||
// We allow retrying failed payments.
|
|
||||||
case StatusFailed:
|
|
||||||
|
|
||||||
// This is a new payment that is being initialized for the
|
|
||||||
// first time.
|
|
||||||
case StatusUnknown:
|
|
||||||
|
|
||||||
// We already have an InFlight payment on the network. We will
|
|
||||||
// disallow any new payments.
|
|
||||||
case StatusInFlight:
|
|
||||||
updateErr = ErrPaymentInFlight
|
|
||||||
return nil
|
|
||||||
|
|
||||||
// We've already succeeded a payment to this payment hash,
|
|
||||||
// forbid the switch from sending another.
|
|
||||||
case StatusSucceeded:
|
|
||||||
updateErr = ErrAlreadyPaid
|
|
||||||
return nil
|
|
||||||
|
|
||||||
default:
|
|
||||||
updateErr = ErrUnknownPaymentStatus
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Obtain a new sequence number for this payment. This is used
|
|
||||||
// to sort the payments in order of creation, and also acts as
|
|
||||||
// a unique identifier for each payment.
|
|
||||||
sequenceNum, err := nextPaymentSequence(tx)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = bucket.Put(paymentSequenceKey, sequenceNum)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add the payment info to the bucket, which contains the
|
|
||||||
// static information for this payment
|
|
||||||
err = bucket.Put(paymentCreationInfoKey, infoBytes)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// We'll delete any lingering attempt info to start with, in
|
|
||||||
// case we are initializing a payment that was attempted
|
|
||||||
// earlier, but left in a state where we could retry.
|
|
||||||
err = bucket.Delete(paymentAttemptInfoKey)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Also delete any lingering failure info now that we are
|
|
||||||
// re-attempting.
|
|
||||||
return bucket.Delete(paymentFailInfoKey)
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return updateErr
|
|
||||||
}
|
|
||||||
|
|
||||||
// RegisterAttempt atomically records the provided PaymentAttemptInfo to the
|
|
||||||
// DB.
|
|
||||||
func (p *PaymentControl) RegisterAttempt(paymentHash lntypes.Hash,
|
|
||||||
attempt *PaymentAttemptInfo) error {
|
|
||||||
|
|
||||||
// Serialize the information before opening the db transaction.
|
|
||||||
var a bytes.Buffer
|
|
||||||
if err := serializePaymentAttemptInfo(&a, attempt); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
attemptBytes := a.Bytes()
|
|
||||||
|
|
||||||
var updateErr error
|
|
||||||
err := p.db.Batch(func(tx *bbolt.Tx) error {
|
|
||||||
// Reset the update error, to avoid carrying over an error
|
|
||||||
// from a previous execution of the batched db transaction.
|
|
||||||
updateErr = nil
|
|
||||||
|
|
||||||
bucket, err := fetchPaymentBucket(tx, paymentHash)
|
|
||||||
if err == ErrPaymentNotInitiated {
|
|
||||||
updateErr = ErrPaymentNotInitiated
|
|
||||||
return nil
|
|
||||||
} else if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// We can only register attempts for payments that are
|
|
||||||
// in-flight.
|
|
||||||
if err := ensureInFlight(bucket); err != nil {
|
|
||||||
updateErr = err
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add the payment attempt to the payments bucket.
|
|
||||||
return bucket.Put(paymentAttemptInfoKey, attemptBytes)
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return updateErr
|
|
||||||
}
|
|
||||||
|
|
||||||
// Success transitions a payment into the Succeeded state. After invoking this
|
|
||||||
// method, InitPayment should always return an error to prevent us from making
|
|
||||||
// duplicate payments to the same payment hash. The provided preimage is
|
|
||||||
// atomically saved to the DB for record keeping.
|
|
||||||
func (p *PaymentControl) Success(paymentHash lntypes.Hash,
|
|
||||||
preimage lntypes.Preimage) (*route.Route, error) {
|
|
||||||
|
|
||||||
var (
|
|
||||||
updateErr error
|
|
||||||
route *route.Route
|
|
||||||
)
|
|
||||||
err := p.db.Batch(func(tx *bbolt.Tx) error {
|
|
||||||
// Reset the update error, to avoid carrying over an error
|
|
||||||
// from a previous execution of the batched db transaction.
|
|
||||||
updateErr = nil
|
|
||||||
|
|
||||||
bucket, err := fetchPaymentBucket(tx, paymentHash)
|
|
||||||
if err == ErrPaymentNotInitiated {
|
|
||||||
updateErr = ErrPaymentNotInitiated
|
|
||||||
return nil
|
|
||||||
} else if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// We can only mark in-flight payments as succeeded.
|
|
||||||
if err := ensureInFlight(bucket); err != nil {
|
|
||||||
updateErr = err
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Record the successful payment info atomically to the
|
|
||||||
// payments record.
|
|
||||||
err = bucket.Put(paymentSettleInfoKey, preimage[:])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Retrieve attempt info for the notification.
|
|
||||||
attempt, err := fetchPaymentAttempt(bucket)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
route = &attempt.Route
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return route, updateErr
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fail transitions a payment into the Failed state, and records the reason the
|
|
||||||
// payment failed. After invoking this method, InitPayment should return nil on
|
|
||||||
// its next call for this payment hash, allowing the switch to make a
|
|
||||||
// subsequent payment.
|
|
||||||
func (p *PaymentControl) Fail(paymentHash lntypes.Hash,
|
|
||||||
reason FailureReason) (*route.Route, error) {
|
|
||||||
|
|
||||||
var (
|
|
||||||
updateErr error
|
|
||||||
route *route.Route
|
|
||||||
)
|
|
||||||
err := p.db.Batch(func(tx *bbolt.Tx) error {
|
|
||||||
// Reset the update error, to avoid carrying over an error
|
|
||||||
// from a previous execution of the batched db transaction.
|
|
||||||
updateErr = nil
|
|
||||||
|
|
||||||
bucket, err := fetchPaymentBucket(tx, paymentHash)
|
|
||||||
if err == ErrPaymentNotInitiated {
|
|
||||||
updateErr = ErrPaymentNotInitiated
|
|
||||||
return nil
|
|
||||||
} else if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// We can only mark in-flight payments as failed.
|
|
||||||
if err := ensureInFlight(bucket); err != nil {
|
|
||||||
updateErr = err
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Put the failure reason in the bucket for record keeping.
|
|
||||||
v := []byte{byte(reason)}
|
|
||||||
err = bucket.Put(paymentFailInfoKey, v)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Retrieve attempt info for the notification, if available.
|
|
||||||
attempt, err := fetchPaymentAttempt(bucket)
|
|
||||||
if err != nil && err != errNoAttemptInfo {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err != errNoAttemptInfo {
|
|
||||||
route = &attempt.Route
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return route, updateErr
|
|
||||||
}
|
|
||||||
|
|
||||||
// FetchPayment returns information about a payment from the database.
|
|
||||||
func (p *PaymentControl) FetchPayment(paymentHash lntypes.Hash) (
|
|
||||||
*Payment, error) {
|
|
||||||
|
|
||||||
var payment *Payment
|
|
||||||
err := p.db.View(func(tx *bbolt.Tx) error {
|
|
||||||
bucket, err := fetchPaymentBucket(tx, paymentHash)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
payment, err = fetchPayment(bucket)
|
|
||||||
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return payment, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// createPaymentBucket creates or fetches the sub-bucket assigned to this
|
|
||||||
// payment hash.
|
|
||||||
func createPaymentBucket(tx *bbolt.Tx, paymentHash lntypes.Hash) (
|
|
||||||
*bbolt.Bucket, error) {
|
|
||||||
|
|
||||||
payments, err := tx.CreateBucketIfNotExists(paymentsRootBucket)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return payments.CreateBucketIfNotExists(paymentHash[:])
|
|
||||||
}
|
|
||||||
|
|
||||||
// fetchPaymentBucket fetches the sub-bucket assigned to this payment hash. If
|
|
||||||
// the bucket does not exist, it returns ErrPaymentNotInitiated.
|
|
||||||
func fetchPaymentBucket(tx *bbolt.Tx, paymentHash lntypes.Hash) (
|
|
||||||
*bbolt.Bucket, error) {
|
|
||||||
|
|
||||||
payments := tx.Bucket(paymentsRootBucket)
|
|
||||||
if payments == nil {
|
|
||||||
return nil, ErrPaymentNotInitiated
|
|
||||||
}
|
|
||||||
|
|
||||||
bucket := payments.Bucket(paymentHash[:])
|
|
||||||
if bucket == nil {
|
|
||||||
return nil, ErrPaymentNotInitiated
|
|
||||||
}
|
|
||||||
|
|
||||||
return bucket, nil
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// nextPaymentSequence returns the next sequence number to store for a new
|
|
||||||
// payment.
|
|
||||||
func nextPaymentSequence(tx *bbolt.Tx) ([]byte, error) {
|
|
||||||
payments, err := tx.CreateBucketIfNotExists(paymentsRootBucket)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
seq, err := payments.NextSequence()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
b := make([]byte, 8)
|
|
||||||
binary.BigEndian.PutUint64(b, seq)
|
|
||||||
return b, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// fetchPaymentStatus fetches the payment status of the payment. If the payment
|
// fetchPaymentStatus fetches the payment status of the payment. If the payment
|
||||||
// isn't found, it will default to "StatusUnknown".
|
// isn't found, it will default to "StatusUnknown".
|
||||||
func fetchPaymentStatus(bucket *bbolt.Bucket) PaymentStatus {
|
func fetchPaymentStatus(bucket *bbolt.Bucket) PaymentStatus {
|
||||||
@ -385,113 +21,3 @@ func fetchPaymentStatus(bucket *bbolt.Bucket) PaymentStatus {
|
|||||||
|
|
||||||
return StatusUnknown
|
return StatusUnknown
|
||||||
}
|
}
|
||||||
|
|
||||||
// ensureInFlight checks whether the payment found in the given bucket has
|
|
||||||
// status InFlight, and returns an error otherwise. This should be used to
|
|
||||||
// ensure we only mark in-flight payments as succeeded or failed.
|
|
||||||
func ensureInFlight(bucket *bbolt.Bucket) error {
|
|
||||||
paymentStatus := fetchPaymentStatus(bucket)
|
|
||||||
|
|
||||||
switch {
|
|
||||||
|
|
||||||
// The payment was indeed InFlight, return.
|
|
||||||
case paymentStatus == StatusInFlight:
|
|
||||||
return nil
|
|
||||||
|
|
||||||
// Our records show the payment as unknown, meaning it never
|
|
||||||
// should have left the switch.
|
|
||||||
case paymentStatus == StatusUnknown:
|
|
||||||
return ErrPaymentNotInitiated
|
|
||||||
|
|
||||||
// The payment succeeded previously.
|
|
||||||
case paymentStatus == StatusSucceeded:
|
|
||||||
return ErrPaymentAlreadySucceeded
|
|
||||||
|
|
||||||
// The payment was already failed.
|
|
||||||
case paymentStatus == StatusFailed:
|
|
||||||
return ErrPaymentAlreadyFailed
|
|
||||||
|
|
||||||
default:
|
|
||||||
return ErrUnknownPaymentStatus
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// fetchPaymentAttempt fetches the payment attempt from the bucket.
|
|
||||||
func fetchPaymentAttempt(bucket *bbolt.Bucket) (*PaymentAttemptInfo, error) {
|
|
||||||
attemptData := bucket.Get(paymentAttemptInfoKey)
|
|
||||||
if attemptData == nil {
|
|
||||||
return nil, errNoAttemptInfo
|
|
||||||
}
|
|
||||||
|
|
||||||
r := bytes.NewReader(attemptData)
|
|
||||||
return deserializePaymentAttemptInfo(r)
|
|
||||||
}
|
|
||||||
|
|
||||||
// InFlightPayment is a wrapper around a payment that has status InFlight.
|
|
||||||
type InFlightPayment struct {
|
|
||||||
// Info is the PaymentCreationInfo of the in-flight payment.
|
|
||||||
Info *PaymentCreationInfo
|
|
||||||
|
|
||||||
// Attempt contains information about the last payment attempt that was
|
|
||||||
// made to this payment hash.
|
|
||||||
//
|
|
||||||
// NOTE: Might be nil.
|
|
||||||
Attempt *PaymentAttemptInfo
|
|
||||||
}
|
|
||||||
|
|
||||||
// FetchInFlightPayments returns all payments with status InFlight.
|
|
||||||
func (p *PaymentControl) FetchInFlightPayments() ([]*InFlightPayment, error) {
|
|
||||||
var inFlights []*InFlightPayment
|
|
||||||
err := p.db.View(func(tx *bbolt.Tx) error {
|
|
||||||
payments := tx.Bucket(paymentsRootBucket)
|
|
||||||
if payments == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return payments.ForEach(func(k, _ []byte) error {
|
|
||||||
bucket := payments.Bucket(k)
|
|
||||||
if bucket == nil {
|
|
||||||
return fmt.Errorf("non bucket element")
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the status is not InFlight, we can return early.
|
|
||||||
paymentStatus := fetchPaymentStatus(bucket)
|
|
||||||
if paymentStatus != StatusInFlight {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
inFlight = &InFlightPayment{}
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
|
|
||||||
// Get the CreationInfo.
|
|
||||||
b := bucket.Get(paymentCreationInfoKey)
|
|
||||||
if b == nil {
|
|
||||||
return fmt.Errorf("unable to find creation " +
|
|
||||||
"info for inflight payment")
|
|
||||||
}
|
|
||||||
|
|
||||||
r := bytes.NewReader(b)
|
|
||||||
inFlight.Info, err = deserializePaymentCreationInfo(r)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now get the attempt info. It could be that there is
|
|
||||||
// no attempt info yet.
|
|
||||||
inFlight.Attempt, err = fetchPaymentAttempt(bucket)
|
|
||||||
if err != nil && err != errNoAttemptInfo {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
inFlights = append(inFlights, inFlight)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return inFlights, nil
|
|
||||||
}
|
|
||||||
|
@ -1,550 +0,0 @@
|
|||||||
package migration_01_to_11
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"crypto/rand"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"io/ioutil"
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/btcsuite/fastsha256"
|
|
||||||
"github.com/coreos/bbolt"
|
|
||||||
"github.com/davecgh/go-spew/spew"
|
|
||||||
"github.com/lightningnetwork/lnd/lntypes"
|
|
||||||
"github.com/lightningnetwork/lnd/routing/route"
|
|
||||||
)
|
|
||||||
|
|
||||||
func initDB() (*DB, error) {
|
|
||||||
tempPath, err := ioutil.TempDir("", "switchdb")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
db, err := Open(tempPath)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return db, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func genPreimage() ([32]byte, error) {
|
|
||||||
var preimage [32]byte
|
|
||||||
if _, err := io.ReadFull(rand.Reader, preimage[:]); err != nil {
|
|
||||||
return preimage, err
|
|
||||||
}
|
|
||||||
return preimage, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func genInfo() (*PaymentCreationInfo, *PaymentAttemptInfo,
|
|
||||||
lntypes.Preimage, error) {
|
|
||||||
|
|
||||||
preimage, err := genPreimage()
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, preimage, fmt.Errorf("unable to "+
|
|
||||||
"generate preimage: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
rhash := fastsha256.Sum256(preimage[:])
|
|
||||||
return &PaymentCreationInfo{
|
|
||||||
PaymentHash: rhash,
|
|
||||||
Value: 1,
|
|
||||||
CreationDate: time.Unix(time.Now().Unix(), 0),
|
|
||||||
PaymentRequest: []byte("hola"),
|
|
||||||
},
|
|
||||||
&PaymentAttemptInfo{
|
|
||||||
PaymentID: 1,
|
|
||||||
SessionKey: priv,
|
|
||||||
Route: testRoute,
|
|
||||||
}, preimage, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestPaymentControlSwitchFail checks that payment status returns to Failed
|
|
||||||
// status after failing, and that InitPayment allows another HTLC for the
|
|
||||||
// same payment hash.
|
|
||||||
func TestPaymentControlSwitchFail(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
db, err := initDB()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to init db: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
pControl := NewPaymentControl(db)
|
|
||||||
|
|
||||||
info, attempt, preimg, err := genInfo()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to generate htlc message: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sends base htlc message which initiate StatusInFlight.
|
|
||||||
err = pControl.InitPayment(info.PaymentHash, info)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to send htlc message: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight)
|
|
||||||
assertPaymentInfo(
|
|
||||||
t, db, info.PaymentHash, info, nil, lntypes.Preimage{},
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
|
|
||||||
// Fail the payment, which should moved it to Failed.
|
|
||||||
failReason := FailureReasonNoRoute
|
|
||||||
_, err = pControl.Fail(info.PaymentHash, failReason)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to fail payment hash: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify the status is indeed Failed.
|
|
||||||
assertPaymentStatus(t, db, info.PaymentHash, StatusFailed)
|
|
||||||
assertPaymentInfo(
|
|
||||||
t, db, info.PaymentHash, info, nil, lntypes.Preimage{},
|
|
||||||
&failReason,
|
|
||||||
)
|
|
||||||
|
|
||||||
// Sends the htlc again, which should succeed since the prior payment
|
|
||||||
// failed.
|
|
||||||
err = pControl.InitPayment(info.PaymentHash, info)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to send htlc message: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight)
|
|
||||||
assertPaymentInfo(
|
|
||||||
t, db, info.PaymentHash, info, nil, lntypes.Preimage{},
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
|
|
||||||
// Record a new attempt.
|
|
||||||
attempt.PaymentID = 2
|
|
||||||
err = pControl.RegisterAttempt(info.PaymentHash, attempt)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to send htlc message: %v", err)
|
|
||||||
}
|
|
||||||
assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight)
|
|
||||||
assertPaymentInfo(
|
|
||||||
t, db, info.PaymentHash, info, attempt, lntypes.Preimage{},
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
|
|
||||||
// Verifies that status was changed to StatusSucceeded.
|
|
||||||
var route *route.Route
|
|
||||||
route, err = pControl.Success(info.PaymentHash, preimg)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("error shouldn't have been received, got: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = assertRouteEqual(route, &attempt.Route)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected route returned: %v vs %v: %v",
|
|
||||||
spew.Sdump(attempt.Route), spew.Sdump(*route), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
assertPaymentStatus(t, db, info.PaymentHash, StatusSucceeded)
|
|
||||||
assertPaymentInfo(t, db, info.PaymentHash, info, attempt, preimg, nil)
|
|
||||||
|
|
||||||
// Attempt a final payment, which should now fail since the prior
|
|
||||||
// payment succeed.
|
|
||||||
err = pControl.InitPayment(info.PaymentHash, info)
|
|
||||||
if err != ErrAlreadyPaid {
|
|
||||||
t.Fatalf("unable to send htlc message: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestPaymentControlSwitchDoubleSend checks the ability of payment control to
|
|
||||||
// prevent double sending of htlc message, when message is in StatusInFlight.
|
|
||||||
func TestPaymentControlSwitchDoubleSend(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
db, err := initDB()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to init db: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
pControl := NewPaymentControl(db)
|
|
||||||
|
|
||||||
info, attempt, preimg, err := genInfo()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to generate htlc message: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sends base htlc message which initiate base status and move it to
|
|
||||||
// StatusInFlight and verifies that it was changed.
|
|
||||||
err = pControl.InitPayment(info.PaymentHash, info)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to send htlc message: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight)
|
|
||||||
assertPaymentInfo(
|
|
||||||
t, db, info.PaymentHash, info, nil, lntypes.Preimage{},
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
|
|
||||||
// Try to initiate double sending of htlc message with the same
|
|
||||||
// payment hash, should result in error indicating that payment has
|
|
||||||
// already been sent.
|
|
||||||
err = pControl.InitPayment(info.PaymentHash, info)
|
|
||||||
if err != ErrPaymentInFlight {
|
|
||||||
t.Fatalf("payment control wrong behaviour: " +
|
|
||||||
"double sending must trigger ErrPaymentInFlight error")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Record an attempt.
|
|
||||||
err = pControl.RegisterAttempt(info.PaymentHash, attempt)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to send htlc message: %v", err)
|
|
||||||
}
|
|
||||||
assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight)
|
|
||||||
assertPaymentInfo(
|
|
||||||
t, db, info.PaymentHash, info, attempt, lntypes.Preimage{},
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
|
|
||||||
// Sends base htlc message which initiate StatusInFlight.
|
|
||||||
err = pControl.InitPayment(info.PaymentHash, info)
|
|
||||||
if err != ErrPaymentInFlight {
|
|
||||||
t.Fatalf("payment control wrong behaviour: " +
|
|
||||||
"double sending must trigger ErrPaymentInFlight error")
|
|
||||||
}
|
|
||||||
|
|
||||||
// After settling, the error should be ErrAlreadyPaid.
|
|
||||||
if _, err := pControl.Success(info.PaymentHash, preimg); err != nil {
|
|
||||||
t.Fatalf("error shouldn't have been received, got: %v", err)
|
|
||||||
}
|
|
||||||
assertPaymentStatus(t, db, info.PaymentHash, StatusSucceeded)
|
|
||||||
assertPaymentInfo(t, db, info.PaymentHash, info, attempt, preimg, nil)
|
|
||||||
|
|
||||||
err = pControl.InitPayment(info.PaymentHash, info)
|
|
||||||
if err != ErrAlreadyPaid {
|
|
||||||
t.Fatalf("unable to send htlc message: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestPaymentControlSuccessesWithoutInFlight checks that the payment
|
|
||||||
// control will disallow calls to Success when no payment is in flight.
|
|
||||||
func TestPaymentControlSuccessesWithoutInFlight(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
db, err := initDB()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to init db: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
pControl := NewPaymentControl(db)
|
|
||||||
|
|
||||||
info, _, preimg, err := genInfo()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to generate htlc message: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Attempt to complete the payment should fail.
|
|
||||||
_, err = pControl.Success(info.PaymentHash, preimg)
|
|
||||||
if err != ErrPaymentNotInitiated {
|
|
||||||
t.Fatalf("expected ErrPaymentNotInitiated, got %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
assertPaymentStatus(t, db, info.PaymentHash, StatusUnknown)
|
|
||||||
assertPaymentInfo(
|
|
||||||
t, db, info.PaymentHash, nil, nil, lntypes.Preimage{},
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestPaymentControlFailsWithoutInFlight checks that a strict payment
|
|
||||||
// control will disallow calls to Fail when no payment is in flight.
|
|
||||||
func TestPaymentControlFailsWithoutInFlight(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
db, err := initDB()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to init db: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
pControl := NewPaymentControl(db)
|
|
||||||
|
|
||||||
info, _, _, err := genInfo()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to generate htlc message: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Calling Fail should return an error.
|
|
||||||
_, err = pControl.Fail(info.PaymentHash, FailureReasonNoRoute)
|
|
||||||
if err != ErrPaymentNotInitiated {
|
|
||||||
t.Fatalf("expected ErrPaymentNotInitiated, got %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
assertPaymentStatus(t, db, info.PaymentHash, StatusUnknown)
|
|
||||||
assertPaymentInfo(
|
|
||||||
t, db, info.PaymentHash, nil, nil, lntypes.Preimage{}, nil,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestPaymentControlDeleteNonInFlight checks that calling DeletaPayments only
|
|
||||||
// deletes payments from the database that are not in-flight.
|
|
||||||
func TestPaymentControlDeleteNonInFligt(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
db, err := initDB()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to init db: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
pControl := NewPaymentControl(db)
|
|
||||||
|
|
||||||
payments := []struct {
|
|
||||||
failed bool
|
|
||||||
success bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
failed: true,
|
|
||||||
success: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
failed: false,
|
|
||||||
success: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
failed: false,
|
|
||||||
success: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, p := range payments {
|
|
||||||
info, attempt, preimg, err := genInfo()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to generate htlc message: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sends base htlc message which initiate StatusInFlight.
|
|
||||||
err = pControl.InitPayment(info.PaymentHash, info)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to send htlc message: %v", err)
|
|
||||||
}
|
|
||||||
err = pControl.RegisterAttempt(info.PaymentHash, attempt)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to send htlc message: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if p.failed {
|
|
||||||
// Fail the payment, which should moved it to Failed.
|
|
||||||
failReason := FailureReasonNoRoute
|
|
||||||
_, err = pControl.Fail(info.PaymentHash, failReason)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to fail payment hash: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify the status is indeed Failed.
|
|
||||||
assertPaymentStatus(t, db, info.PaymentHash, StatusFailed)
|
|
||||||
assertPaymentInfo(
|
|
||||||
t, db, info.PaymentHash, info, attempt,
|
|
||||||
lntypes.Preimage{}, &failReason,
|
|
||||||
)
|
|
||||||
} else if p.success {
|
|
||||||
// Verifies that status was changed to StatusSucceeded.
|
|
||||||
_, err := pControl.Success(info.PaymentHash, preimg)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("error shouldn't have been received, got: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
assertPaymentStatus(t, db, info.PaymentHash, StatusSucceeded)
|
|
||||||
assertPaymentInfo(
|
|
||||||
t, db, info.PaymentHash, info, attempt, preimg, nil,
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight)
|
|
||||||
assertPaymentInfo(
|
|
||||||
t, db, info.PaymentHash, info, attempt,
|
|
||||||
lntypes.Preimage{}, nil,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete payments.
|
|
||||||
if err := db.DeletePayments(); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// This should leave the in-flight payment.
|
|
||||||
dbPayments, err := db.FetchPayments()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(dbPayments) != 1 {
|
|
||||||
t.Fatalf("expected one payment, got %d", len(dbPayments))
|
|
||||||
}
|
|
||||||
|
|
||||||
status := dbPayments[0].Status
|
|
||||||
if status != StatusInFlight {
|
|
||||||
t.Fatalf("expected in-fligth status, got %v", status)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func assertPaymentStatus(t *testing.T, db *DB,
|
|
||||||
hash [32]byte, expStatus PaymentStatus) {
|
|
||||||
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
var paymentStatus = StatusUnknown
|
|
||||||
err := db.View(func(tx *bbolt.Tx) error {
|
|
||||||
payments := tx.Bucket(paymentsRootBucket)
|
|
||||||
if payments == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
bucket := payments.Bucket(hash[:])
|
|
||||||
if bucket == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the existing status of this payment, if any.
|
|
||||||
paymentStatus = fetchPaymentStatus(bucket)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to fetch payment status: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if paymentStatus != expStatus {
|
|
||||||
t.Fatalf("payment status mismatch: expected %v, got %v",
|
|
||||||
expStatus, paymentStatus)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func checkPaymentCreationInfo(bucket *bbolt.Bucket, c *PaymentCreationInfo) error {
|
|
||||||
b := bucket.Get(paymentCreationInfoKey)
|
|
||||||
switch {
|
|
||||||
case b == nil && c == nil:
|
|
||||||
return nil
|
|
||||||
case b == nil:
|
|
||||||
return fmt.Errorf("expected creation info not found")
|
|
||||||
case c == nil:
|
|
||||||
return fmt.Errorf("unexpected creation info found")
|
|
||||||
}
|
|
||||||
|
|
||||||
r := bytes.NewReader(b)
|
|
||||||
c2, err := deserializePaymentCreationInfo(r)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(c, c2) {
|
|
||||||
return fmt.Errorf("PaymentCreationInfos don't match: %v vs %v",
|
|
||||||
spew.Sdump(c), spew.Sdump(c2))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func checkPaymentAttemptInfo(bucket *bbolt.Bucket, a *PaymentAttemptInfo) error {
|
|
||||||
b := bucket.Get(paymentAttemptInfoKey)
|
|
||||||
switch {
|
|
||||||
case b == nil && a == nil:
|
|
||||||
return nil
|
|
||||||
case b == nil:
|
|
||||||
return fmt.Errorf("expected attempt info not found")
|
|
||||||
case a == nil:
|
|
||||||
return fmt.Errorf("unexpected attempt info found")
|
|
||||||
}
|
|
||||||
|
|
||||||
r := bytes.NewReader(b)
|
|
||||||
a2, err := deserializePaymentAttemptInfo(r)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return assertRouteEqual(&a.Route, &a2.Route)
|
|
||||||
}
|
|
||||||
|
|
||||||
func checkSettleInfo(bucket *bbolt.Bucket, preimg lntypes.Preimage) error {
|
|
||||||
zero := lntypes.Preimage{}
|
|
||||||
b := bucket.Get(paymentSettleInfoKey)
|
|
||||||
switch {
|
|
||||||
case b == nil && preimg == zero:
|
|
||||||
return nil
|
|
||||||
case b == nil:
|
|
||||||
return fmt.Errorf("expected preimage not found")
|
|
||||||
case preimg == zero:
|
|
||||||
return fmt.Errorf("unexpected preimage found")
|
|
||||||
}
|
|
||||||
|
|
||||||
var pre2 lntypes.Preimage
|
|
||||||
copy(pre2[:], b[:])
|
|
||||||
if preimg != pre2 {
|
|
||||||
return fmt.Errorf("Preimages don't match: %x vs %x",
|
|
||||||
preimg, pre2)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func checkFailInfo(bucket *bbolt.Bucket, failReason *FailureReason) error {
|
|
||||||
b := bucket.Get(paymentFailInfoKey)
|
|
||||||
switch {
|
|
||||||
case b == nil && failReason == nil:
|
|
||||||
return nil
|
|
||||||
case b == nil:
|
|
||||||
return fmt.Errorf("expected fail info not found")
|
|
||||||
case failReason == nil:
|
|
||||||
return fmt.Errorf("unexpected fail info found")
|
|
||||||
}
|
|
||||||
|
|
||||||
failReason2 := FailureReason(b[0])
|
|
||||||
if *failReason != failReason2 {
|
|
||||||
return fmt.Errorf("Failure infos don't match: %v vs %v",
|
|
||||||
*failReason, failReason2)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func assertPaymentInfo(t *testing.T, db *DB, hash lntypes.Hash,
|
|
||||||
c *PaymentCreationInfo, a *PaymentAttemptInfo, s lntypes.Preimage,
|
|
||||||
f *FailureReason) {
|
|
||||||
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
err := db.View(func(tx *bbolt.Tx) error {
|
|
||||||
payments := tx.Bucket(paymentsRootBucket)
|
|
||||||
if payments == nil && c == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if payments == nil {
|
|
||||||
return fmt.Errorf("sent payments not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
bucket := payments.Bucket(hash[:])
|
|
||||||
if bucket == nil && c == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if bucket == nil {
|
|
||||||
return fmt.Errorf("payment not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := checkPaymentCreationInfo(bucket, c); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := checkPaymentAttemptInfo(bucket, a); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := checkSettleInfo(bucket, s); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := checkFailInfo(bucket, f); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("assert payment info failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
@ -375,48 +375,6 @@ func fetchPayment(bucket *bbolt.Bucket) (*Payment, error) {
|
|||||||
return p, nil
|
return p, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeletePayments deletes all completed and failed payments from the DB.
|
|
||||||
func (db *DB) DeletePayments() error {
|
|
||||||
return db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
payments := tx.Bucket(paymentsRootBucket)
|
|
||||||
if payments == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var deleteBuckets [][]byte
|
|
||||||
err := payments.ForEach(func(k, _ []byte) error {
|
|
||||||
bucket := payments.Bucket(k)
|
|
||||||
if bucket == nil {
|
|
||||||
// We only expect sub-buckets to be found in
|
|
||||||
// this top-level bucket.
|
|
||||||
return fmt.Errorf("non bucket element in " +
|
|
||||||
"payments bucket")
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the status is InFlight, we cannot safely delete
|
|
||||||
// the payment information, so we return early.
|
|
||||||
paymentStatus := fetchPaymentStatus(bucket)
|
|
||||||
if paymentStatus == StatusInFlight {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
deleteBuckets = append(deleteBuckets, k)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, k := range deleteBuckets {
|
|
||||||
if err := payments.DeleteBucket(k); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func serializePaymentCreationInfo(w io.Writer, c *PaymentCreationInfo) error {
|
func serializePaymentCreationInfo(w io.Writer, c *PaymentCreationInfo) error {
|
||||||
var scratch [8]byte
|
var scratch [8]byte
|
||||||
|
|
||||||
|
@ -2,55 +2,17 @@ package migration_01_to_11
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/btcsuite/btcd/btcec"
|
"github.com/btcsuite/btcd/btcec"
|
||||||
"github.com/davecgh/go-spew/spew"
|
|
||||||
"github.com/lightningnetwork/lnd/lntypes"
|
|
||||||
"github.com/lightningnetwork/lnd/lnwire"
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
"github.com/lightningnetwork/lnd/routing/route"
|
|
||||||
"github.com/lightningnetwork/lnd/tlv"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
priv, _ = btcec.NewPrivateKey(btcec.S256())
|
priv, _ = btcec.NewPrivateKey(btcec.S256())
|
||||||
pub = priv.PubKey()
|
pub = priv.PubKey()
|
||||||
|
|
||||||
tlvBytes = []byte{1, 2, 3}
|
|
||||||
tlvEncoder = tlv.StubEncoder(tlvBytes)
|
|
||||||
testHop1 = &route.Hop{
|
|
||||||
PubKeyBytes: route.NewVertex(pub),
|
|
||||||
ChannelID: 12345,
|
|
||||||
OutgoingTimeLock: 111,
|
|
||||||
AmtToForward: 555,
|
|
||||||
TLVRecords: []tlv.Record{
|
|
||||||
tlv.MakeStaticRecord(1, nil, 3, tlvEncoder, nil),
|
|
||||||
tlv.MakeStaticRecord(2, nil, 3, tlvEncoder, nil),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
testHop2 = &route.Hop{
|
|
||||||
PubKeyBytes: route.NewVertex(pub),
|
|
||||||
ChannelID: 12345,
|
|
||||||
OutgoingTimeLock: 111,
|
|
||||||
AmtToForward: 555,
|
|
||||||
LegacyPayload: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
testRoute = route.Route{
|
|
||||||
TotalTimeLock: 123,
|
|
||||||
TotalAmount: 1234567,
|
|
||||||
SourcePubKey: route.NewVertex(pub),
|
|
||||||
Hops: []*route.Hop{
|
|
||||||
testHop1,
|
|
||||||
testHop2,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func makeFakePayment() *outgoingPayment {
|
func makeFakePayment() *outgoingPayment {
|
||||||
@ -81,27 +43,6 @@ func makeFakePayment() *outgoingPayment {
|
|||||||
return fakePayment
|
return fakePayment
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeFakeInfo() (*PaymentCreationInfo, *PaymentAttemptInfo) {
|
|
||||||
var preimg lntypes.Preimage
|
|
||||||
copy(preimg[:], rev[:])
|
|
||||||
|
|
||||||
c := &PaymentCreationInfo{
|
|
||||||
PaymentHash: preimg.Hash(),
|
|
||||||
Value: 1000,
|
|
||||||
// Use single second precision to avoid false positive test
|
|
||||||
// failures due to the monotonic time component.
|
|
||||||
CreationDate: time.Unix(time.Now().Unix(), 0),
|
|
||||||
PaymentRequest: []byte(""),
|
|
||||||
}
|
|
||||||
|
|
||||||
a := &PaymentAttemptInfo{
|
|
||||||
PaymentID: 44,
|
|
||||||
SessionKey: priv,
|
|
||||||
Route: testRoute,
|
|
||||||
}
|
|
||||||
return c, a
|
|
||||||
}
|
|
||||||
|
|
||||||
// randomBytes creates random []byte with length in range [minLen, maxLen)
|
// randomBytes creates random []byte with length in range [minLen, maxLen)
|
||||||
func randomBytes(minLen, maxLen int) ([]byte, error) {
|
func randomBytes(minLen, maxLen int) ([]byte, error) {
|
||||||
randBuf := make([]byte, minLen+rand.Intn(maxLen-minLen))
|
randBuf := make([]byte, minLen+rand.Intn(maxLen-minLen))
|
||||||
@ -165,160 +106,3 @@ func makeRandomFakePayment() (*outgoingPayment, error) {
|
|||||||
|
|
||||||
return fakePayment, nil
|
return fakePayment, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSentPaymentSerialization(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
c, s := makeFakeInfo()
|
|
||||||
|
|
||||||
var b bytes.Buffer
|
|
||||||
if err := serializePaymentCreationInfo(&b, c); err != nil {
|
|
||||||
t.Fatalf("unable to serialize creation info: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
newCreationInfo, err := deserializePaymentCreationInfo(&b)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to deserialize creation info: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !reflect.DeepEqual(c, newCreationInfo) {
|
|
||||||
t.Fatalf("Payments do not match after "+
|
|
||||||
"serialization/deserialization %v vs %v",
|
|
||||||
spew.Sdump(c), spew.Sdump(newCreationInfo),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
b.Reset()
|
|
||||||
if err := serializePaymentAttemptInfo(&b, s); err != nil {
|
|
||||||
t.Fatalf("unable to serialize info: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
newAttemptInfo, err := deserializePaymentAttemptInfo(&b)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to deserialize info: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// First we verify all the records match up porperly, as they aren't
|
|
||||||
// able to be properly compared using reflect.DeepEqual.
|
|
||||||
err = assertRouteEqual(&s.Route, &newAttemptInfo.Route)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Routes do not match after "+
|
|
||||||
"serialization/deserialization: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clear routes to allow DeepEqual to compare the remaining fields.
|
|
||||||
newAttemptInfo.Route = route.Route{}
|
|
||||||
s.Route = route.Route{}
|
|
||||||
|
|
||||||
if !reflect.DeepEqual(s, newAttemptInfo) {
|
|
||||||
s.SessionKey.Curve = nil
|
|
||||||
newAttemptInfo.SessionKey.Curve = nil
|
|
||||||
t.Fatalf("Payments do not match after "+
|
|
||||||
"serialization/deserialization %v vs %v",
|
|
||||||
spew.Sdump(s), spew.Sdump(newAttemptInfo),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// assertRouteEquals compares to routes for equality and returns an error if
|
|
||||||
// they are not equal.
|
|
||||||
func assertRouteEqual(a, b *route.Route) error {
|
|
||||||
err := assertRouteHopRecordsEqual(a, b)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// TLV records have already been compared and need to be cleared to
|
|
||||||
// properly compare the remaining fields using DeepEqual.
|
|
||||||
copyRouteNoHops := func(r *route.Route) *route.Route {
|
|
||||||
copy := *r
|
|
||||||
copy.Hops = make([]*route.Hop, len(r.Hops))
|
|
||||||
for i, hop := range r.Hops {
|
|
||||||
hopCopy := *hop
|
|
||||||
hopCopy.TLVRecords = nil
|
|
||||||
copy.Hops[i] = &hopCopy
|
|
||||||
}
|
|
||||||
return ©
|
|
||||||
}
|
|
||||||
|
|
||||||
if !reflect.DeepEqual(copyRouteNoHops(a), copyRouteNoHops(b)) {
|
|
||||||
return fmt.Errorf("PaymentAttemptInfos don't match: %v vs %v",
|
|
||||||
spew.Sdump(a), spew.Sdump(b))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func assertRouteHopRecordsEqual(r1, r2 *route.Route) error {
|
|
||||||
if len(r1.Hops) != len(r2.Hops) {
|
|
||||||
return errors.New("route hop count mismatch")
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < len(r1.Hops); i++ {
|
|
||||||
records1 := r1.Hops[i].TLVRecords
|
|
||||||
records2 := r2.Hops[i].TLVRecords
|
|
||||||
if len(records1) != len(records2) {
|
|
||||||
return fmt.Errorf("route record count for hop %v "+
|
|
||||||
"mismatch", i)
|
|
||||||
}
|
|
||||||
|
|
||||||
for j := 0; j < len(records1); j++ {
|
|
||||||
expectedRecord := records1[j]
|
|
||||||
newRecord := records2[j]
|
|
||||||
|
|
||||||
err := assertHopRecordsEqual(expectedRecord, newRecord)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("route record mismatch: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func assertHopRecordsEqual(h1, h2 tlv.Record) error {
|
|
||||||
if h1.Type() != h2.Type() {
|
|
||||||
return fmt.Errorf("wrong type: expected %v, got %v", h1.Type(),
|
|
||||||
h2.Type())
|
|
||||||
}
|
|
||||||
|
|
||||||
var b bytes.Buffer
|
|
||||||
if err := h2.Encode(&b); err != nil {
|
|
||||||
return fmt.Errorf("unable to encode record: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !bytes.Equal(b.Bytes(), tlvBytes) {
|
|
||||||
return fmt.Errorf("wrong raw record: expected %x, got %x",
|
|
||||||
tlvBytes, b.Bytes())
|
|
||||||
}
|
|
||||||
|
|
||||||
if h1.Size() != h2.Size() {
|
|
||||||
return fmt.Errorf("wrong size: expected %v, "+
|
|
||||||
"got %v", h1.Size(), h2.Size())
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRouteSerialization(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
var b bytes.Buffer
|
|
||||||
if err := SerializeRoute(&b, testRoute); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
r := bytes.NewReader(b.Bytes())
|
|
||||||
route2, err := DeserializeRoute(r)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// First we verify all the records match up porperly, as they aren't
|
|
||||||
// able to be properly compared using reflect.DeepEqual.
|
|
||||||
err = assertRouteEqual(&testRoute, &route2)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("routes not equal: \n%v vs \n%v",
|
|
||||||
spew.Sdump(testRoute), spew.Sdump(route2))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -1,95 +0,0 @@
|
|||||||
package migration_01_to_11
|
|
||||||
|
|
||||||
// rejectFlags is a compact representation of various metadata stored by the
|
|
||||||
// reject cache about a particular channel.
|
|
||||||
type rejectFlags uint8
|
|
||||||
|
|
||||||
const (
|
|
||||||
// rejectFlagExists is a flag indicating whether the channel exists,
|
|
||||||
// i.e. the channel is open and has a recent channel update. If this
|
|
||||||
// flag is not set, the channel is either a zombie or unknown.
|
|
||||||
rejectFlagExists rejectFlags = 1 << iota
|
|
||||||
|
|
||||||
// rejectFlagZombie is a flag indicating whether the channel is a
|
|
||||||
// zombie, i.e. the channel is open but has no recent channel updates.
|
|
||||||
rejectFlagZombie
|
|
||||||
)
|
|
||||||
|
|
||||||
// packRejectFlags computes the rejectFlags corresponding to the passed boolean
|
|
||||||
// values indicating whether the edge exists or is a zombie.
|
|
||||||
func packRejectFlags(exists, isZombie bool) rejectFlags {
|
|
||||||
var flags rejectFlags
|
|
||||||
if exists {
|
|
||||||
flags |= rejectFlagExists
|
|
||||||
}
|
|
||||||
if isZombie {
|
|
||||||
flags |= rejectFlagZombie
|
|
||||||
}
|
|
||||||
|
|
||||||
return flags
|
|
||||||
}
|
|
||||||
|
|
||||||
// unpack returns the booleans packed into the rejectFlags. The first indicates
|
|
||||||
// if the edge exists in our graph, the second indicates if the edge is a
|
|
||||||
// zombie.
|
|
||||||
func (f rejectFlags) unpack() (bool, bool) {
|
|
||||||
return f&rejectFlagExists == rejectFlagExists,
|
|
||||||
f&rejectFlagZombie == rejectFlagZombie
|
|
||||||
}
|
|
||||||
|
|
||||||
// rejectCacheEntry caches frequently accessed information about a channel,
|
|
||||||
// including the timestamps of its latest edge policies and whether or not the
|
|
||||||
// channel exists in the graph.
|
|
||||||
type rejectCacheEntry struct {
|
|
||||||
upd1Time int64
|
|
||||||
upd2Time int64
|
|
||||||
flags rejectFlags
|
|
||||||
}
|
|
||||||
|
|
||||||
// rejectCache is an in-memory cache used to improve the performance of
|
|
||||||
// HasChannelEdge. It caches information about the whether or channel exists, as
|
|
||||||
// well as the most recent timestamps for each policy (if they exists).
|
|
||||||
type rejectCache struct {
|
|
||||||
n int
|
|
||||||
edges map[uint64]rejectCacheEntry
|
|
||||||
}
|
|
||||||
|
|
||||||
// newRejectCache creates a new rejectCache with maximum capacity of n entries.
|
|
||||||
func newRejectCache(n int) *rejectCache {
|
|
||||||
return &rejectCache{
|
|
||||||
n: n,
|
|
||||||
edges: make(map[uint64]rejectCacheEntry, n),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// get returns the entry from the cache for chanid, if it exists.
|
|
||||||
func (c *rejectCache) get(chanid uint64) (rejectCacheEntry, bool) {
|
|
||||||
entry, ok := c.edges[chanid]
|
|
||||||
return entry, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
// insert adds the entry to the reject cache. If an entry for chanid already
|
|
||||||
// exists, it will be replaced with the new entry. If the entry doesn't exists,
|
|
||||||
// it will be inserted to the cache, performing a random eviction if the cache
|
|
||||||
// is at capacity.
|
|
||||||
func (c *rejectCache) insert(chanid uint64, entry rejectCacheEntry) {
|
|
||||||
// If entry exists, replace it.
|
|
||||||
if _, ok := c.edges[chanid]; ok {
|
|
||||||
c.edges[chanid] = entry
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Otherwise, evict an entry at random and insert.
|
|
||||||
if len(c.edges) == c.n {
|
|
||||||
for id := range c.edges {
|
|
||||||
delete(c.edges, id)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
c.edges[chanid] = entry
|
|
||||||
}
|
|
||||||
|
|
||||||
// remove deletes an entry for chanid from the cache, if it exists.
|
|
||||||
func (c *rejectCache) remove(chanid uint64) {
|
|
||||||
delete(c.edges, chanid)
|
|
||||||
}
|
|
@ -1,107 +0,0 @@
|
|||||||
package migration_01_to_11
|
|
||||||
|
|
||||||
import (
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TestRejectCache checks the behavior of the rejectCache with respect to insertion,
|
|
||||||
// eviction, and removal of cache entries.
|
|
||||||
func TestRejectCache(t *testing.T) {
|
|
||||||
const cacheSize = 100
|
|
||||||
|
|
||||||
// Create a new reject cache with the configured max size.
|
|
||||||
c := newRejectCache(cacheSize)
|
|
||||||
|
|
||||||
// As a sanity check, assert that querying the empty cache does not
|
|
||||||
// return an entry.
|
|
||||||
_, ok := c.get(0)
|
|
||||||
if ok {
|
|
||||||
t.Fatalf("reject cache should be empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now, fill up the cache entirely.
|
|
||||||
for i := uint64(0); i < cacheSize; i++ {
|
|
||||||
c.insert(i, entryForInt(i))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Assert that the cache has all of the entries just inserted, since no
|
|
||||||
// eviction should occur until we try to surpass the max size.
|
|
||||||
assertHasEntries(t, c, 0, cacheSize)
|
|
||||||
|
|
||||||
// Now, insert a new element that causes the cache to evict an element.
|
|
||||||
c.insert(cacheSize, entryForInt(cacheSize))
|
|
||||||
|
|
||||||
// Assert that the cache has this last entry, as the cache should evict
|
|
||||||
// some prior element and not the newly inserted one.
|
|
||||||
assertHasEntries(t, c, cacheSize, cacheSize)
|
|
||||||
|
|
||||||
// Iterate over all inserted elements and construct a set of the evicted
|
|
||||||
// elements.
|
|
||||||
evicted := make(map[uint64]struct{})
|
|
||||||
for i := uint64(0); i < cacheSize+1; i++ {
|
|
||||||
_, ok := c.get(i)
|
|
||||||
if !ok {
|
|
||||||
evicted[i] = struct{}{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Assert that exactly one element has been evicted.
|
|
||||||
numEvicted := len(evicted)
|
|
||||||
if numEvicted != 1 {
|
|
||||||
t.Fatalf("expected one evicted entry, got: %d", numEvicted)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove the highest item which initially caused the eviction and
|
|
||||||
// reinsert the element that was evicted prior.
|
|
||||||
c.remove(cacheSize)
|
|
||||||
for i := range evicted {
|
|
||||||
c.insert(i, entryForInt(i))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Since the removal created an extra slot, the last insertion should
|
|
||||||
// not have caused an eviction and the entries for all channels in the
|
|
||||||
// original set that filled the cache should be present.
|
|
||||||
assertHasEntries(t, c, 0, cacheSize)
|
|
||||||
|
|
||||||
// Finally, reinsert the existing set back into the cache and test that
|
|
||||||
// the cache still has all the entries. If the randomized eviction were
|
|
||||||
// happening on inserts for existing cache items, we expect this to fail
|
|
||||||
// with high probability.
|
|
||||||
for i := uint64(0); i < cacheSize; i++ {
|
|
||||||
c.insert(i, entryForInt(i))
|
|
||||||
}
|
|
||||||
assertHasEntries(t, c, 0, cacheSize)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// assertHasEntries queries the reject cache for all channels in the range [start,
|
|
||||||
// end), asserting that they exist and their value matches the entry produced by
|
|
||||||
// entryForInt.
|
|
||||||
func assertHasEntries(t *testing.T, c *rejectCache, start, end uint64) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
for i := start; i < end; i++ {
|
|
||||||
entry, ok := c.get(i)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("reject cache should contain chan %d", i)
|
|
||||||
}
|
|
||||||
|
|
||||||
expEntry := entryForInt(i)
|
|
||||||
if !reflect.DeepEqual(entry, expEntry) {
|
|
||||||
t.Fatalf("entry mismatch, want: %v, got: %v",
|
|
||||||
expEntry, entry)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// entryForInt generates a unique rejectCacheEntry given an integer.
|
|
||||||
func entryForInt(i uint64) rejectCacheEntry {
|
|
||||||
exists := i%2 == 0
|
|
||||||
isZombie := i%3 == 0
|
|
||||||
return rejectCacheEntry{
|
|
||||||
upd1Time: int64(2 * i),
|
|
||||||
upd2Time: int64(2*i + 1),
|
|
||||||
flags: packRejectFlags(exists, isZombie),
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,251 +0,0 @@
|
|||||||
package migration_01_to_11
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"io"
|
|
||||||
|
|
||||||
"bytes"
|
|
||||||
|
|
||||||
"github.com/coreos/bbolt"
|
|
||||||
"github.com/go-errors/errors"
|
|
||||||
"github.com/lightningnetwork/lnd/lnwire"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
// waitingProofsBucketKey byte string name of the waiting proofs store.
|
|
||||||
waitingProofsBucketKey = []byte("waitingproofs")
|
|
||||||
|
|
||||||
// ErrWaitingProofNotFound is returned if waiting proofs haven't been
|
|
||||||
// found by db.
|
|
||||||
ErrWaitingProofNotFound = errors.New("waiting proofs haven't been " +
|
|
||||||
"found")
|
|
||||||
|
|
||||||
// ErrWaitingProofAlreadyExist is returned if waiting proofs haven't been
|
|
||||||
// found by db.
|
|
||||||
ErrWaitingProofAlreadyExist = errors.New("waiting proof with such " +
|
|
||||||
"key already exist")
|
|
||||||
)
|
|
||||||
|
|
||||||
// WaitingProofStore is the bold db map-like storage for half announcement
|
|
||||||
// signatures. The one responsibility of this storage is to be able to
|
|
||||||
// retrieve waiting proofs after client restart.
|
|
||||||
type WaitingProofStore struct {
|
|
||||||
// cache is used in order to reduce the number of redundant get
|
|
||||||
// calls, when object isn't stored in it.
|
|
||||||
cache map[WaitingProofKey]struct{}
|
|
||||||
db *DB
|
|
||||||
mu sync.RWMutex
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewWaitingProofStore creates new instance of proofs storage.
|
|
||||||
func NewWaitingProofStore(db *DB) (*WaitingProofStore, error) {
|
|
||||||
s := &WaitingProofStore{
|
|
||||||
db: db,
|
|
||||||
cache: make(map[WaitingProofKey]struct{}),
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.ForAll(func(proof *WaitingProof) error {
|
|
||||||
s.cache[proof.Key()] = struct{}{}
|
|
||||||
return nil
|
|
||||||
}); err != nil && err != ErrWaitingProofNotFound {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return s, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add adds new waiting proof in the storage.
|
|
||||||
func (s *WaitingProofStore) Add(proof *WaitingProof) error {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
|
|
||||||
err := s.db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
var err error
|
|
||||||
var b bytes.Buffer
|
|
||||||
|
|
||||||
// Get or create the bucket.
|
|
||||||
bucket, err := tx.CreateBucketIfNotExists(waitingProofsBucketKey)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Encode the objects and place it in the bucket.
|
|
||||||
if err := proof.Encode(&b); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
key := proof.Key()
|
|
||||||
|
|
||||||
return bucket.Put(key[:], b.Bytes())
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Knowing that the write succeeded, we can now update the in-memory
|
|
||||||
// cache with the proof's key.
|
|
||||||
s.cache[proof.Key()] = struct{}{}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove removes the proof from storage by its key.
|
|
||||||
func (s *WaitingProofStore) Remove(key WaitingProofKey) error {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
|
|
||||||
if _, ok := s.cache[key]; !ok {
|
|
||||||
return ErrWaitingProofNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
err := s.db.Update(func(tx *bbolt.Tx) error {
|
|
||||||
// Get or create the top bucket.
|
|
||||||
bucket := tx.Bucket(waitingProofsBucketKey)
|
|
||||||
if bucket == nil {
|
|
||||||
return ErrWaitingProofNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
return bucket.Delete(key[:])
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Since the proof was successfully deleted from the store, we can now
|
|
||||||
// remove it from the in-memory cache.
|
|
||||||
delete(s.cache, key)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ForAll iterates thought all waiting proofs and passing the waiting proof
|
|
||||||
// in the given callback.
|
|
||||||
func (s *WaitingProofStore) ForAll(cb func(*WaitingProof) error) error {
|
|
||||||
return s.db.View(func(tx *bbolt.Tx) error {
|
|
||||||
bucket := tx.Bucket(waitingProofsBucketKey)
|
|
||||||
if bucket == nil {
|
|
||||||
return ErrWaitingProofNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
// Iterate over objects buckets.
|
|
||||||
return bucket.ForEach(func(k, v []byte) error {
|
|
||||||
// Skip buckets fields.
|
|
||||||
if v == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
r := bytes.NewReader(v)
|
|
||||||
proof := &WaitingProof{}
|
|
||||||
if err := proof.Decode(r); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return cb(proof)
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get returns the object which corresponds to the given index.
|
|
||||||
func (s *WaitingProofStore) Get(key WaitingProofKey) (*WaitingProof, error) {
|
|
||||||
proof := &WaitingProof{}
|
|
||||||
|
|
||||||
s.mu.RLock()
|
|
||||||
defer s.mu.RUnlock()
|
|
||||||
|
|
||||||
if _, ok := s.cache[key]; !ok {
|
|
||||||
return nil, ErrWaitingProofNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
err := s.db.View(func(tx *bbolt.Tx) error {
|
|
||||||
bucket := tx.Bucket(waitingProofsBucketKey)
|
|
||||||
if bucket == nil {
|
|
||||||
return ErrWaitingProofNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
// Iterate over objects buckets.
|
|
||||||
v := bucket.Get(key[:])
|
|
||||||
if v == nil {
|
|
||||||
return ErrWaitingProofNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
r := bytes.NewReader(v)
|
|
||||||
return proof.Decode(r)
|
|
||||||
})
|
|
||||||
|
|
||||||
return proof, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// WaitingProofKey is the proof key which uniquely identifies the waiting
|
|
||||||
// proof object. The goal of this key is distinguish the local and remote
|
|
||||||
// proof for the same channel id.
|
|
||||||
type WaitingProofKey [9]byte
|
|
||||||
|
|
||||||
// WaitingProof is the storable object, which encapsulate the half proof and
|
|
||||||
// the information about from which side this proof came. This structure is
|
|
||||||
// needed to make channel proof exchange persistent, so that after client
|
|
||||||
// restart we may receive remote/local half proof and process it.
|
|
||||||
type WaitingProof struct {
|
|
||||||
*lnwire.AnnounceSignatures
|
|
||||||
isRemote bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewWaitingProof constructs a new waiting prof instance.
|
|
||||||
func NewWaitingProof(isRemote bool, proof *lnwire.AnnounceSignatures) *WaitingProof {
|
|
||||||
return &WaitingProof{
|
|
||||||
AnnounceSignatures: proof,
|
|
||||||
isRemote: isRemote,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// OppositeKey returns the key which uniquely identifies opposite waiting proof.
|
|
||||||
func (p *WaitingProof) OppositeKey() WaitingProofKey {
|
|
||||||
var key [9]byte
|
|
||||||
binary.BigEndian.PutUint64(key[:8], p.ShortChannelID.ToUint64())
|
|
||||||
|
|
||||||
if !p.isRemote {
|
|
||||||
key[8] = 1
|
|
||||||
}
|
|
||||||
return key
|
|
||||||
}
|
|
||||||
|
|
||||||
// Key returns the key which uniquely identifies waiting proof.
|
|
||||||
func (p *WaitingProof) Key() WaitingProofKey {
|
|
||||||
var key [9]byte
|
|
||||||
binary.BigEndian.PutUint64(key[:8], p.ShortChannelID.ToUint64())
|
|
||||||
|
|
||||||
if p.isRemote {
|
|
||||||
key[8] = 1
|
|
||||||
}
|
|
||||||
return key
|
|
||||||
}
|
|
||||||
|
|
||||||
// Encode writes the internal representation of waiting proof in byte stream.
|
|
||||||
func (p *WaitingProof) Encode(w io.Writer) error {
|
|
||||||
if err := binary.Write(w, byteOrder, p.isRemote); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := p.AnnounceSignatures.Encode(w, 0); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Decode reads the data from the byte stream and initializes the
|
|
||||||
// waiting proof object with it.
|
|
||||||
func (p *WaitingProof) Decode(r io.Reader) error {
|
|
||||||
if err := binary.Read(r, byteOrder, &p.isRemote); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := &lnwire.AnnounceSignatures{}
|
|
||||||
if err := msg.Decode(r, 0); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
(*p).AnnounceSignatures = msg
|
|
||||||
return nil
|
|
||||||
}
|
|
@ -1,59 +0,0 @@
|
|||||||
package migration_01_to_11
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"reflect"
|
|
||||||
|
|
||||||
"github.com/go-errors/errors"
|
|
||||||
"github.com/lightningnetwork/lnd/lnwire"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TestWaitingProofStore tests add/get/remove functions of the waiting proof
|
|
||||||
// storage.
|
|
||||||
func TestWaitingProofStore(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
db, cleanup, err := makeTestDB()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to make test database: %s", err)
|
|
||||||
}
|
|
||||||
defer cleanup()
|
|
||||||
|
|
||||||
proof1 := NewWaitingProof(true, &lnwire.AnnounceSignatures{
|
|
||||||
NodeSignature: wireSig,
|
|
||||||
BitcoinSignature: wireSig,
|
|
||||||
})
|
|
||||||
|
|
||||||
store, err := NewWaitingProofStore(db)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to create the waiting proofs storage: %v",
|
|
||||||
err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := store.Add(proof1); err != nil {
|
|
||||||
t.Fatalf("unable add proof to storage: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
proof2, err := store.Get(proof1.Key())
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable retrieve proof from storage: %v", err)
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(proof1, proof2) {
|
|
||||||
t.Fatal("wrong proof retrieved")
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := store.Get(proof1.OppositeKey()); err != ErrWaitingProofNotFound {
|
|
||||||
t.Fatalf("proof shouldn't be found: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := store.Remove(proof1.Key()); err != nil {
|
|
||||||
t.Fatalf("unable remove proof from storage: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := store.ForAll(func(proof *WaitingProof) error {
|
|
||||||
return errors.New("storage should be empty")
|
|
||||||
}); err != nil && err != ErrWaitingProofNotFound {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,229 +0,0 @@
|
|||||||
package migration_01_to_11
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/coreos/bbolt"
|
|
||||||
"github.com/lightningnetwork/lnd/lntypes"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
// ErrNoWitnesses is an error that's returned when no new witnesses have
|
|
||||||
// been added to the WitnessCache.
|
|
||||||
ErrNoWitnesses = fmt.Errorf("no witnesses")
|
|
||||||
|
|
||||||
// ErrUnknownWitnessType is returned if a caller attempts to
|
|
||||||
ErrUnknownWitnessType = fmt.Errorf("unknown witness type")
|
|
||||||
)
|
|
||||||
|
|
||||||
// WitnessType is enum that denotes what "type" of witness is being
|
|
||||||
// stored/retrieved. As the WitnessCache itself is agnostic and doesn't enforce
|
|
||||||
// any structure on added witnesses, we use this type to partition the
|
|
||||||
// witnesses on disk, and also to know how to map a witness to its look up key.
|
|
||||||
type WitnessType uint8
|
|
||||||
|
|
||||||
var (
|
|
||||||
// Sha256HashWitness is a witness that is simply the pre image to a
|
|
||||||
// hash image. In order to map to its key, we'll use sha256.
|
|
||||||
Sha256HashWitness WitnessType = 1
|
|
||||||
)
|
|
||||||
|
|
||||||
// toDBKey is a helper method that maps a witness type to the key that we'll
|
|
||||||
// use to store it within the database.
|
|
||||||
func (w WitnessType) toDBKey() ([]byte, error) {
|
|
||||||
switch w {
|
|
||||||
|
|
||||||
case Sha256HashWitness:
|
|
||||||
return []byte{byte(w)}, nil
|
|
||||||
|
|
||||||
default:
|
|
||||||
return nil, ErrUnknownWitnessType
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
// witnessBucketKey is the name of the bucket that we use to store all
|
|
||||||
// witnesses encountered. Within this bucket, we'll create a sub-bucket for
|
|
||||||
// each witness type.
|
|
||||||
witnessBucketKey = []byte("byte")
|
|
||||||
)
|
|
||||||
|
|
||||||
// WitnessCache is a persistent cache of all witnesses we've encountered on the
|
|
||||||
// network. In the case of multi-hop, multi-step contracts, a cache of all
|
|
||||||
// witnesses can be useful in the case of partial contract resolution. If
|
|
||||||
// negotiations break down, we may be forced to locate the witness for a
|
|
||||||
// portion of the contract on-chain. In this case, we'll then add that witness
|
|
||||||
// to the cache so the incoming contract can fully resolve witness.
|
|
||||||
// Additionally, as one MUST always use a unique witness on the network, we may
|
|
||||||
// use this cache to detect duplicate witnesses.
|
|
||||||
//
|
|
||||||
// TODO(roasbeef): need expiry policy?
|
|
||||||
// * encrypt?
|
|
||||||
type WitnessCache struct {
|
|
||||||
db *DB
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewWitnessCache returns a new instance of the witness cache.
|
|
||||||
func (d *DB) NewWitnessCache() *WitnessCache {
|
|
||||||
return &WitnessCache{
|
|
||||||
db: d,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// witnessEntry is a key-value struct that holds each key -> witness pair, used
|
|
||||||
// when inserting records into the cache.
|
|
||||||
type witnessEntry struct {
|
|
||||||
key []byte
|
|
||||||
witness []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddSha256Witnesses adds a batch of new sha256 preimages into the witness
|
|
||||||
// cache. This is an alias for AddWitnesses that uses Sha256HashWitness as the
|
|
||||||
// preimages' witness type.
|
|
||||||
func (w *WitnessCache) AddSha256Witnesses(preimages ...lntypes.Preimage) error {
|
|
||||||
// Optimistically compute the preimages' hashes before attempting to
|
|
||||||
// start the db transaction.
|
|
||||||
entries := make([]witnessEntry, 0, len(preimages))
|
|
||||||
for i := range preimages {
|
|
||||||
hash := preimages[i].Hash()
|
|
||||||
entries = append(entries, witnessEntry{
|
|
||||||
key: hash[:],
|
|
||||||
witness: preimages[i][:],
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
return w.addWitnessEntries(Sha256HashWitness, entries)
|
|
||||||
}
|
|
||||||
|
|
||||||
// addWitnessEntries inserts the witnessEntry key-value pairs into the cache,
|
|
||||||
// using the appropriate witness type to segment the namespace of possible
|
|
||||||
// witness types.
|
|
||||||
func (w *WitnessCache) addWitnessEntries(wType WitnessType,
|
|
||||||
entries []witnessEntry) error {
|
|
||||||
|
|
||||||
// Exit early if there are no witnesses to add.
|
|
||||||
if len(entries) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return w.db.Batch(func(tx *bbolt.Tx) error {
|
|
||||||
witnessBucket, err := tx.CreateBucketIfNotExists(witnessBucketKey)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
witnessTypeBucketKey, err := wType.toDBKey()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
witnessTypeBucket, err := witnessBucket.CreateBucketIfNotExists(
|
|
||||||
witnessTypeBucketKey,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, entry := range entries {
|
|
||||||
err = witnessTypeBucket.Put(entry.key, entry.witness)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// LookupSha256Witness attempts to lookup the preimage for a sha256 hash. If
|
|
||||||
// the witness isn't found, ErrNoWitnesses will be returned.
|
|
||||||
func (w *WitnessCache) LookupSha256Witness(hash lntypes.Hash) (lntypes.Preimage, error) {
|
|
||||||
witness, err := w.lookupWitness(Sha256HashWitness, hash[:])
|
|
||||||
if err != nil {
|
|
||||||
return lntypes.Preimage{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return lntypes.MakePreimage(witness)
|
|
||||||
}
|
|
||||||
|
|
||||||
// lookupWitness attempts to lookup a witness according to its type and also
|
|
||||||
// its witness key. In the case that the witness isn't found, ErrNoWitnesses
|
|
||||||
// will be returned.
|
|
||||||
func (w *WitnessCache) lookupWitness(wType WitnessType, witnessKey []byte) ([]byte, error) {
|
|
||||||
var witness []byte
|
|
||||||
err := w.db.View(func(tx *bbolt.Tx) error {
|
|
||||||
witnessBucket := tx.Bucket(witnessBucketKey)
|
|
||||||
if witnessBucket == nil {
|
|
||||||
return ErrNoWitnesses
|
|
||||||
}
|
|
||||||
|
|
||||||
witnessTypeBucketKey, err := wType.toDBKey()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
witnessTypeBucket := witnessBucket.Bucket(witnessTypeBucketKey)
|
|
||||||
if witnessTypeBucket == nil {
|
|
||||||
return ErrNoWitnesses
|
|
||||||
}
|
|
||||||
|
|
||||||
dbWitness := witnessTypeBucket.Get(witnessKey)
|
|
||||||
if dbWitness == nil {
|
|
||||||
return ErrNoWitnesses
|
|
||||||
}
|
|
||||||
|
|
||||||
witness = make([]byte, len(dbWitness))
|
|
||||||
copy(witness[:], dbWitness)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return witness, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteSha256Witness attempts to delete a sha256 preimage identified by hash.
|
|
||||||
func (w *WitnessCache) DeleteSha256Witness(hash lntypes.Hash) error {
|
|
||||||
return w.deleteWitness(Sha256HashWitness, hash[:])
|
|
||||||
}
|
|
||||||
|
|
||||||
// deleteWitness attempts to delete a particular witness from the database.
|
|
||||||
func (w *WitnessCache) deleteWitness(wType WitnessType, witnessKey []byte) error {
|
|
||||||
return w.db.Batch(func(tx *bbolt.Tx) error {
|
|
||||||
witnessBucket, err := tx.CreateBucketIfNotExists(witnessBucketKey)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
witnessTypeBucketKey, err := wType.toDBKey()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
witnessTypeBucket, err := witnessBucket.CreateBucketIfNotExists(
|
|
||||||
witnessTypeBucketKey,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return witnessTypeBucket.Delete(witnessKey)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteWitnessClass attempts to delete an *entire* class of witnesses. After
|
|
||||||
// this function return with a non-nil error,
|
|
||||||
func (w *WitnessCache) DeleteWitnessClass(wType WitnessType) error {
|
|
||||||
return w.db.Batch(func(tx *bbolt.Tx) error {
|
|
||||||
witnessBucket, err := tx.CreateBucketIfNotExists(witnessBucketKey)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
witnessTypeBucketKey, err := wType.toDBKey()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return witnessBucket.DeleteBucket(witnessTypeBucketKey)
|
|
||||||
})
|
|
||||||
}
|
|
@ -1,238 +0,0 @@
|
|||||||
package migration_01_to_11
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/sha256"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/lightningnetwork/lnd/lntypes"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TestWitnessCacheSha256Retrieval tests that we're able to add and lookup new
|
|
||||||
// sha256 preimages to the witness cache.
|
|
||||||
func TestWitnessCacheSha256Retrieval(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
cdb, cleanUp, err := makeTestDB()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test database: %v", err)
|
|
||||||
}
|
|
||||||
defer cleanUp()
|
|
||||||
|
|
||||||
wCache := cdb.NewWitnessCache()
|
|
||||||
|
|
||||||
// We'll be attempting to add then lookup two simple sha256 preimages
|
|
||||||
// within this test.
|
|
||||||
preimage1 := lntypes.Preimage(rev)
|
|
||||||
preimage2 := lntypes.Preimage(key)
|
|
||||||
|
|
||||||
preimages := []lntypes.Preimage{preimage1, preimage2}
|
|
||||||
hashes := []lntypes.Hash{preimage1.Hash(), preimage2.Hash()}
|
|
||||||
|
|
||||||
// First, we'll attempt to add the preimages to the database.
|
|
||||||
err = wCache.AddSha256Witnesses(preimages...)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to add witness: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// With the preimages stored, we'll now attempt to look them up.
|
|
||||||
for i, hash := range hashes {
|
|
||||||
preimage := preimages[i]
|
|
||||||
|
|
||||||
// We should get back the *exact* same preimage as we originally
|
|
||||||
// stored.
|
|
||||||
dbPreimage, err := wCache.LookupSha256Witness(hash)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to look up witness: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if preimage != dbPreimage {
|
|
||||||
t.Fatalf("witnesses don't match: expected %x, got %x",
|
|
||||||
preimage[:], dbPreimage[:])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestWitnessCacheSha256Deletion tests that we're able to delete a single
|
|
||||||
// sha256 preimage, and also a class of witnesses from the cache.
|
|
||||||
func TestWitnessCacheSha256Deletion(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
cdb, cleanUp, err := makeTestDB()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test database: %v", err)
|
|
||||||
}
|
|
||||||
defer cleanUp()
|
|
||||||
|
|
||||||
wCache := cdb.NewWitnessCache()
|
|
||||||
|
|
||||||
// We'll start by adding two preimages to the cache.
|
|
||||||
preimage1 := lntypes.Preimage(key)
|
|
||||||
hash1 := preimage1.Hash()
|
|
||||||
|
|
||||||
preimage2 := lntypes.Preimage(rev)
|
|
||||||
hash2 := preimage2.Hash()
|
|
||||||
|
|
||||||
if err := wCache.AddSha256Witnesses(preimage1); err != nil {
|
|
||||||
t.Fatalf("unable to add witness: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := wCache.AddSha256Witnesses(preimage2); err != nil {
|
|
||||||
t.Fatalf("unable to add witness: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// We'll now delete the first preimage. If we attempt to look it up, we
|
|
||||||
// should get ErrNoWitnesses.
|
|
||||||
err = wCache.DeleteSha256Witness(hash1)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to delete witness: %v", err)
|
|
||||||
}
|
|
||||||
_, err = wCache.LookupSha256Witness(hash1)
|
|
||||||
if err != ErrNoWitnesses {
|
|
||||||
t.Fatalf("expected ErrNoWitnesses instead got: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Next, we'll attempt to delete the entire witness class itself. When
|
|
||||||
// we try to lookup the second preimage, we should again get
|
|
||||||
// ErrNoWitnesses.
|
|
||||||
if err := wCache.DeleteWitnessClass(Sha256HashWitness); err != nil {
|
|
||||||
t.Fatalf("unable to delete witness class: %v", err)
|
|
||||||
}
|
|
||||||
_, err = wCache.LookupSha256Witness(hash2)
|
|
||||||
if err != ErrNoWitnesses {
|
|
||||||
t.Fatalf("expected ErrNoWitnesses instead got: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestWitnessCacheUnknownWitness tests that we get an error if we attempt to
|
|
||||||
// query/add/delete an unknown witness.
|
|
||||||
func TestWitnessCacheUnknownWitness(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
cdb, cleanUp, err := makeTestDB()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test database: %v", err)
|
|
||||||
}
|
|
||||||
defer cleanUp()
|
|
||||||
|
|
||||||
wCache := cdb.NewWitnessCache()
|
|
||||||
|
|
||||||
// We'll attempt to add a new, undefined witness type to the database.
|
|
||||||
// We should get an error.
|
|
||||||
err = wCache.legacyAddWitnesses(234, key[:])
|
|
||||||
if err != ErrUnknownWitnessType {
|
|
||||||
t.Fatalf("expected ErrUnknownWitnessType, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestAddSha256Witnesses tests that insertion using AddSha256Witnesses behaves
|
|
||||||
// identically to the insertion via the generalized interface.
|
|
||||||
func TestAddSha256Witnesses(t *testing.T) {
|
|
||||||
cdb, cleanUp, err := makeTestDB()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to make test database: %v", err)
|
|
||||||
}
|
|
||||||
defer cleanUp()
|
|
||||||
|
|
||||||
wCache := cdb.NewWitnessCache()
|
|
||||||
|
|
||||||
// We'll start by adding a witnesses to the cache using the generic
|
|
||||||
// AddWitnesses method.
|
|
||||||
witness1 := rev[:]
|
|
||||||
preimage1 := lntypes.Preimage(rev)
|
|
||||||
hash1 := preimage1.Hash()
|
|
||||||
|
|
||||||
witness2 := key[:]
|
|
||||||
preimage2 := lntypes.Preimage(key)
|
|
||||||
hash2 := preimage2.Hash()
|
|
||||||
|
|
||||||
var (
|
|
||||||
witnesses = [][]byte{witness1, witness2}
|
|
||||||
preimages = []lntypes.Preimage{preimage1, preimage2}
|
|
||||||
hashes = []lntypes.Hash{hash1, hash2}
|
|
||||||
)
|
|
||||||
|
|
||||||
err = wCache.legacyAddWitnesses(Sha256HashWitness, witnesses...)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to add witness: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, hash := range hashes {
|
|
||||||
preimage := preimages[i]
|
|
||||||
|
|
||||||
dbPreimage, err := wCache.LookupSha256Witness(hash)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to lookup witness: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Assert that the retrieved witness matches the original.
|
|
||||||
if dbPreimage != preimage {
|
|
||||||
t.Fatalf("retrieved witness mismatch, want: %x, "+
|
|
||||||
"got: %x", preimage, dbPreimage)
|
|
||||||
}
|
|
||||||
|
|
||||||
// We'll now delete the witness, as we'll be reinserting it
|
|
||||||
// using the specialized AddSha256Witnesses method.
|
|
||||||
err = wCache.DeleteSha256Witness(hash)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to delete witness: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now, add the same witnesses using the type-safe interface for
|
|
||||||
// lntypes.Preimages..
|
|
||||||
err = wCache.AddSha256Witnesses(preimages...)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to add sha256 preimage: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Finally, iterate over the keys and assert that the returned witnesses
|
|
||||||
// match the original witnesses. This asserts that the specialized
|
|
||||||
// insertion method behaves identically to the generalized interface.
|
|
||||||
for i, hash := range hashes {
|
|
||||||
preimage := preimages[i]
|
|
||||||
|
|
||||||
dbPreimage, err := wCache.LookupSha256Witness(hash)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to lookup witness: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Assert that the retrieved witness matches the original.
|
|
||||||
if dbPreimage != preimage {
|
|
||||||
t.Fatalf("retrieved witness mismatch, want: %x, "+
|
|
||||||
"got: %x", preimage, dbPreimage)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// legacyAddWitnesses adds a batch of new witnesses of wType to the witness
|
|
||||||
// cache. The type of the witness will be used to map each witness to the key
|
|
||||||
// that will be used to look it up. All witnesses should be of the same
|
|
||||||
// WitnessType.
|
|
||||||
//
|
|
||||||
// NOTE: Previously this method exposed a generic interface for adding
|
|
||||||
// witnesses, which has since been deprecated in favor of a strongly typed
|
|
||||||
// interface for each witness class. We keep this method around to assert the
|
|
||||||
// correctness of specialized witness adding methods.
|
|
||||||
func (w *WitnessCache) legacyAddWitnesses(wType WitnessType,
|
|
||||||
witnesses ...[]byte) error {
|
|
||||||
|
|
||||||
// Optimistically compute the witness keys before attempting to start
|
|
||||||
// the db transaction.
|
|
||||||
entries := make([]witnessEntry, 0, len(witnesses))
|
|
||||||
for _, witness := range witnesses {
|
|
||||||
// Map each witness to its key by applying the appropriate
|
|
||||||
// transformation for the given witness type.
|
|
||||||
switch wType {
|
|
||||||
case Sha256HashWitness:
|
|
||||||
key := sha256.Sum256(witness)
|
|
||||||
entries = append(entries, witnessEntry{
|
|
||||||
key: key[:],
|
|
||||||
witness: witness,
|
|
||||||
})
|
|
||||||
default:
|
|
||||||
return ErrUnknownWitnessType
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return w.addWitnessEntries(wType, entries)
|
|
||||||
}
|
|
Loading…
Reference in New Issue
Block a user