diff --git a/channeldb/graph.go b/channeldb/graph.go index d9f21ef1..74c58908 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -3,8 +3,10 @@ package channeldb import ( "bytes" "encoding/binary" + "fmt" "image/color" "io" + "math" "net" "time" @@ -86,12 +88,17 @@ var ( // number of channels, etc. graphMetaBucket = []byte("graph-meta") - // pruneTipKey is a key within the above graphMetaBucket that stores - // the best known blockhash+height that the channel graph has been - // known to be pruned to. Once a new block is discovered, any channels - // that have been closed (by spending the outpoint) can safely be - // removed from the graph. - pruneTipKey = []byte("prune-tip") + // pruneLogBucket is a bucket within the graphMetaBucket that stores + // a mapping from the block height to the hash for the blocks used to + // prune the graph. + // Once a new block is discovered, any channels that have been closed + // (by spending the outpoint) can safely be removed from the graph, and + // the block is added to the prune log. We need to keep such a log for + // the case where a reorg happens, and we must "rewind" the state of the + // graph by removing channels that were previously confirmed. In such a + // case we'll remove all entries from the prune log with a block height + // that no longer exists. + pruneLogBucket = []byte("prune-log") edgeBloomKey = []byte("edge-bloom") nodeBloomKey = []byte("node-bloom") @@ -560,11 +567,12 @@ func (c *ChannelGraph) UpdateChannelEdge(edge *ChannelEdgeInfo) error { } const ( - // pruneTipBytes is the total size of the value which stores the - // current prune tip of the graph. The prune tip indicates if the - // channel graph is in sync with the current UTXO state. The structure - // is: blockHash || blockHeight, taking 36 bytes total. - pruneTipBytes = 32 + 4 + // pruneTipBytes is the total size of the value which stores a prune + // entry of the graph in the prune log. The "prune tip" is the last + // entry in the prune log, and indicates if the channel graph is in + // sync with the current UTXO state. The structure of the value + // is: blockHash, taking 32 bytes total. + pruneTipBytes = 32 ) // PruneGraph prunes newly closed channels from the channel graph in response @@ -641,14 +649,21 @@ func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint, return err } - // With the graph pruned, update the current "prune tip" which - // can be used to check if the graph is fully synced with the - // current UTXO state. + pruneBucket, err := metaBucket.CreateBucketIfNotExists(pruneLogBucket) + if err != nil { + return err + } + + // With the graph pruned, add a new entry to the prune log, + // which can be used to check if the graph is fully synced with + // the current UTXO state. + var blockHeightBytes [4]byte + byteOrder.PutUint32(blockHeightBytes[:], blockHeight) + var newTip [pruneTipBytes]byte copy(newTip[:], blockHash[:]) - byteOrder.PutUint32(newTip[32:], blockHeight) - return metaBucket.Put(pruneTipKey, newTip[:]) + return pruneBucket.Put(blockHeightBytes[:], newTip[:]) }) if err != nil { return nil, err @@ -657,15 +672,115 @@ func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint, return chansClosed, nil } +// DisconnectBlockAtHeight is used to indicate that the block specified +// by the passed height has been disconnected from the main chain. This +// will "rewind" the graph back to the height below, deleting channels +// that are no longer confirmed from the graph. The prune log will be +// set to the last prune height valid for the remaining chain. +// Channels that were removed from the graph resulting from the +// disconnected block are returned. +func (c *ChannelGraph) DisconnectBlockAtHeight(height uint32) ([]*ChannelEdgeInfo, + error) { + + // Every channel having a ShortChannelID starting at 'height' + // will no longer be confirmed. + startShortChanID := lnwire.ShortChannelID{ + BlockHeight: height, + } + + // Delete everything after this height from the db. + endShortChanID := lnwire.ShortChannelID{ + BlockHeight: math.MaxUint32 & 0x00ffffff, + TxIndex: math.MaxUint32 & 0x00ffffff, + TxPosition: math.MaxUint16, + } + // The block height will be the 3 first bytes of the channel IDs. + var chanIDStart [8]byte + byteOrder.PutUint64(chanIDStart[:], startShortChanID.ToUint64()) + var chanIDEnd [8]byte + byteOrder.PutUint64(chanIDEnd[:], endShortChanID.ToUint64()) + + // Keep track of the channels that are removed from the graph. + var removedChans []*ChannelEdgeInfo + + if err := c.db.Update(func(tx *bolt.Tx) error { + edges, err := tx.CreateBucketIfNotExists(edgeBucket) + if err != nil { + return err + } + + edgeIndex, err := edges.CreateBucketIfNotExists(edgeIndexBucket) + if err != nil { + return err + } + + chanIndex, err := edges.CreateBucketIfNotExists(channelPointBucket) + if err != nil { + return err + } + + // Scan from chanIDStart to chanIDEnd, deleting every + // found edge. + cursor := edgeIndex.Cursor() + for k, v := cursor.Seek(chanIDStart[:]); k != nil && + bytes.Compare(k, chanIDEnd[:]) <= 0; k, v = cursor.Next() { + + edgeInfoReader := bytes.NewReader(v) + edgeInfo, err := deserializeChanEdgeInfo(edgeInfoReader) + if err != nil { + return err + } + err = delChannelByEdge(edges, edgeIndex, chanIndex, + &edgeInfo.ChannelPoint) + if err != nil && err != ErrEdgeNotFound { + return err + } + + removedChans = append(removedChans, edgeInfo) + } + + // Delete all the entries in the prune log having a height + // greater or equal to the block disconnected. + metaBucket, err := tx.CreateBucketIfNotExists(graphMetaBucket) + if err != nil { + return err + } + + pruneBucket, err := metaBucket.CreateBucketIfNotExists(pruneLogBucket) + if err != nil { + return err + } + + var pruneKeyStart [4]byte + byteOrder.PutUint32(pruneKeyStart[:], height) + + var pruneKeyEnd [4]byte + byteOrder.PutUint32(pruneKeyEnd[:], math.MaxUint32) + + pruneCursor := pruneBucket.Cursor() + for k, _ := pruneCursor.Seek(pruneKeyStart[:]); k != nil && + bytes.Compare(k, pruneKeyEnd[:]) <= 0; k, _ = pruneCursor.Next() { + if err := pruneCursor.Delete(); err != nil { + return err + } + } + + return nil + }); err != nil { + return nil, err + } + + return removedChans, nil +} + // PruneTip returns the block height and hash of the latest block that has been // used to prune channels in the graph. Knowing the "prune tip" allows callers // to tell if the graph is currently in sync with the current best known UTXO // state. func (c *ChannelGraph) PruneTip() (*chainhash.Hash, uint32, error) { var ( - currentTip [pruneTipBytes]byte - tipHash chainhash.Hash - tipHeight uint32 + tipHash chainhash.Hash + tipHeight uint32 ) err := c.db.View(func(tx *bolt.Tx) error { @@ -673,12 +788,24 @@ func (c *ChannelGraph) PruneTip() (*chainhash.Hash, uint32, error) { if graphMeta == nil { return ErrGraphNotFound } - - tipBytes := graphMeta.Get(pruneTipKey) - if tipBytes == nil { + pruneBucket := graphMeta.Bucket(pruneLogBucket) + if pruneBucket == nil { return ErrGraphNeverPruned } - copy(currentTip[:], tipBytes) + + pruneCursor := pruneBucket.Cursor() + + // The prune key with the largest block height will be our + // prune tip. + k, v := pruneCursor.Last() + if k == nil { + return ErrGraphNeverPruned + } + + // Once we have the prune tip, the value will be the block hash, + // and the key the block height. + copy(tipHash[:], v[:]) + tipHeight = byteOrder.Uint32(k[:]) return nil }) @@ -686,11 +813,6 @@ func (c *ChannelGraph) PruneTip() (*chainhash.Hash, uint32, error) { return nil, 0, err } - // Once we have the prune tip, the first 32 bytes are the block hash, - // with the latter 4 bytes being the block height. - copy(tipHash[:], currentTip[:32]) - tipHeight = byteOrder.Uint32(currentTip[32:]) - return &tipHash, tipHeight, nil } @@ -778,6 +900,10 @@ func delChannelByEdge(edges *bolt.Bucket, edgeIndex *bolt.Bucket, // the keys which house both of the directed edges for this // channel. nodeKeys := edgeIndex.Get(chanID) + if nodeKeys == nil { + return fmt.Errorf("could not find nodekeys for chanID %v", + chanID) + } // The edge key is of the format pubKey || chanID. First we // construct the latter half, populating the channel ID. diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go index 587ade7d..e48d49ab 100644 --- a/channeldb/graph_test.go +++ b/channeldb/graph_test.go @@ -5,6 +5,7 @@ import ( "crypto/sha256" "fmt" "image/color" + "math" "math/big" prand "math/rand" "net" @@ -354,6 +355,168 @@ func TestEdgeInsertionDeletion(t *testing.T) { } } +// TestDisconnecteBlockAtHeight checks that the pruned state of the channel +// database is what we expect after calling DisconnectBlockAtHeight. +func TestDisconnecteBlockAtHeight(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + // We'd like to test the insertion/deletion of edges, so we create two + // vertexes to connect. + node1, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + node2, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + + // In addition to the fake vertexes we create some fake channel + // identifiers. + var spendOutputs []*wire.OutPoint + var blockHash chainhash.Hash + copy(blockHash[:], bytes.Repeat([]byte{1}, 32)) + + // Prune the graph a few times to make sure we have entries in the + // prune log. + _, err = graph.PruneGraph(spendOutputs, &blockHash, 155) + if err != nil { + t.Fatalf("unable to prune graph: %v", err) + } + var blockHash2 chainhash.Hash + copy(blockHash2[:], bytes.Repeat([]byte{2}, 32)) + + _, err = graph.PruneGraph(spendOutputs, &blockHash2, 156) + if err != nil { + t.Fatalf("unable to prune graph: %v", err) + } + + // We'll create 3 almost identical edges, so first create a helper + // method containing all logic for doing so. + createEdge := func(height uint32, txIndex uint32, txPosition uint16, + outPointIndex uint32) ChannelEdgeInfo { + shortChanID := lnwire.ShortChannelID{ + BlockHeight: height, + TxIndex: txIndex, + TxPosition: txPosition, + } + outpoint := wire.OutPoint{ + Hash: rev, + Index: outPointIndex, + } + + edgeInfo := ChannelEdgeInfo{ + ChannelID: shortChanID.ToUint64(), + ChainHash: key, + NodeKey1: node1.PubKey, + NodeKey2: node2.PubKey, + BitcoinKey1: node1.PubKey, + BitcoinKey2: node2.PubKey, + AuthProof: &ChannelAuthProof{ + NodeSig1: testSig, + NodeSig2: testSig, + BitcoinSig1: testSig, + BitcoinSig2: testSig, + }, + ChannelPoint: outpoint, + Capacity: 9000, + } + return edgeInfo + } + + // Create an edge which has its block height at 156. + height := uint32(156) + edgeInfo := createEdge(height, 0, 0, 0) + + // Create an edge with block height 157. We give it + // maximum values for tx index and position, to make + // sure our database range scan get edges from the + // entire range. + edgeInfo2 := createEdge(height+1, math.MaxUint32&0x00ffffff, + math.MaxUint16, 1) + + // Create a third edge, this with a block height of 155. + edgeInfo3 := createEdge(height-1, 0, 0, 2) + + // Now add all these new edges to the database. + if err := graph.AddChannelEdge(&edgeInfo); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + + if err := graph.AddChannelEdge(&edgeInfo2); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + + if err := graph.AddChannelEdge(&edgeInfo3); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + + // Call DisconnectBlockAtHeight, which should prune every channel + // that has an funding height of 'height' or greater. + removed, err := graph.DisconnectBlockAtHeight(uint32(height)) + if err != nil { + t.Fatalf("unable to prune %v", err) + } + + // The two edges should have been removed. + if len(removed) != 2 { + t.Fatalf("expected two edges to be removed from graph, "+ + "only %d were", len(removed)) + } + if removed[0].ChannelID != edgeInfo.ChannelID { + t.Fatalf("expected edge to be removed from graph") + } + if removed[1].ChannelID != edgeInfo2.ChannelID { + t.Fatalf("expected edge to be removed from graph") + } + + // The two first edges should be removed from the db. + _, _, has, err := graph.HasChannelEdge(edgeInfo.ChannelID) + if err != nil { + t.Fatalf("unable to query for edge: %v", err) + } + if has { + t.Fatalf("edge1 was not pruned from the graph") + } + _, _, has, err = graph.HasChannelEdge(edgeInfo2.ChannelID) + if err != nil { + t.Fatalf("unable to query for edge: %v", err) + } + if has { + t.Fatalf("edge2 was not pruned from the graph") + } + + // Edge 3 should not be removed. + _, _, has, err = graph.HasChannelEdge(edgeInfo3.ChannelID) + if err != nil { + t.Fatalf("unable to query for edge: %v", err) + } + if !has { + t.Fatalf("edge3 was pruned from the graph") + } + + // PruneTip should be set to the blockHash we specified for the block + // at height 155. + hash, h, err := graph.PruneTip() + if err != nil { + t.Fatalf("unable to get prune tip: %v", err) + } + if !blockHash.IsEqual(hash) { + t.Fatalf("expected best block to be %x, was %x", blockHash, hash) + } + if h != height-1 { + t.Fatalf("expected best block height to be %d, was %d", height-1, h) + } +} + func assertEdgeInfoEqual(t *testing.T, e1 *ChannelEdgeInfo, e2 *ChannelEdgeInfo) { diff --git a/lnd_test.go b/lnd_test.go index 92b5c730..0553db8b 100644 --- a/lnd_test.go +++ b/lnd_test.go @@ -410,6 +410,186 @@ func testBasicChannelFunding(net *networkHarness, t *harnessTest) { closeChannelAndAssert(ctxt, t, net, net.Alice, chanPoint, false) } +// testOpenChannelAfterReorg tests that in the case where we have an open +// channel where the funding tx gets reorged out, the channel will no +// longer be present in the node's routing table. +func testOpenChannelAfterReorg(net *networkHarness, t *harnessTest) { + timeout := time.Duration(time.Second * 5) + ctxb := context.Background() + + // Set up a new miner that we can use to cause a reorg. + args := []string{"--rejectnonstd"} + miner, err := rpctest.New(harnessNetParams, + &rpcclient.NotificationHandlers{}, args) + if err != nil { + t.Fatalf("unable to create mining node: %v", err) + } + if err := miner.SetUp(true, 50); err != nil { + t.Fatalf("unable to set up mining node: %v", err) + } + defer miner.TearDown() + + if err := miner.Node.NotifyNewTransactions(false); err != nil { + t.Fatalf("unable to request transaction notifications: %v", err) + } + + // We start by connecting the new miner to our original miner, + // such that it will sync to our original chain. + if err := rpctest.ConnectNode(net.Miner, miner); err != nil { + t.Fatalf("unable to connect harnesses: %v", err) + } + nodeSlice := []*rpctest.Harness{net.Miner, miner} + if err := rpctest.JoinNodes(nodeSlice, rpctest.Blocks); err != nil { + t.Fatalf("unable to join node on blocks: %v", err) + } + + // The two should be on the same blockheight. + _, newNodeHeight, err := miner.Node.GetBestBlock() + if err != nil { + t.Fatalf("unable to get current blockheight %v", err) + } + + _, orgNodeHeight, err := net.Miner.Node.GetBestBlock() + if err != nil { + t.Fatalf("unable to get current blockheight %v", err) + } + + if newNodeHeight != orgNodeHeight { + t.Fatalf("expected new miner(%d) and original miner(%d) to "+ + "be on the same height", newNodeHeight, orgNodeHeight) + } + + // We disconnect the two nodes, such that we can start mining on them + // individually without the other one learning about the new blocks. + err = net.Miner.Node.AddNode(miner.P2PAddress(), rpcclient.ANRemove) + if err != nil { + t.Fatalf("unable to remove node: %v", err) + } + + // Create a new channel that requires 1 confs before it's considered + // open, then broadcast the funding transaction + chanAmt := maxFundingAmount + pushAmt := btcutil.Amount(0) + ctxt, _ := context.WithTimeout(ctxb, timeout) + pendingUpdate, err := net.OpenPendingChannel(ctxt, net.Alice, net.Bob, + chanAmt, pushAmt) + if err != nil { + t.Fatalf("unable to open channel: %v", err) + } + + // At this point, the channel's funding transaction will have been + // broadcast, but not confirmed, and the channel should be pending. + ctxt, _ = context.WithTimeout(ctxb, timeout) + assertNumOpenChannelsPending(ctxt, t, net.Alice, net.Bob, 1) + + fundingTxID, err := chainhash.NewHash(pendingUpdate.Txid) + if err != nil { + t.Fatalf("unable to convert funding txid into chainhash.Hash:"+ + " %v", err) + } + + // We now cause a fork, by letting our original miner mine 10 blocks, + // and our new miner mine 15. This will also confirm our pending + // channel, which should be considered open. + block := mineBlocks(t, net, 10)[0] + assertTxInBlock(t, block, fundingTxID) + miner.Node.Generate(15) + + // Ensure the chain lengths are what we expect. + _, newNodeHeight, err = miner.Node.GetBestBlock() + if err != nil { + t.Fatalf("unable to get current blockheight %v", err) + } + + _, orgNodeHeight, err = net.Miner.Node.GetBestBlock() + if err != nil { + t.Fatalf("unable to get current blockheight %v", err) + } + + if newNodeHeight != orgNodeHeight+5 { + t.Fatalf("expected new miner(%d) to be 5 blocks ahead of "+ + "original miner(%d)", newNodeHeight, orgNodeHeight) + } + + chanPoint := &lnrpc.ChannelPoint{ + FundingTxid: pendingUpdate.Txid, + OutputIndex: pendingUpdate.OutputIndex, + } + + // Ensure channel is no longer pending. + assertNumOpenChannelsPending(ctxt, t, net.Alice, net.Bob, 0) + + // Wait for Alice and Bob to recognize and advertise the new channel + // generated above. + ctxt, _ = context.WithTimeout(ctxb, timeout) + err = net.Alice.WaitForNetworkChannelOpen(ctxt, chanPoint) + if err != nil { + t.Fatalf("alice didn't advertise channel before "+ + "timeout: %v", err) + } + err = net.Bob.WaitForNetworkChannelOpen(ctxt, chanPoint) + if err != nil { + t.Fatalf("bob didn't advertise channel before "+ + "timeout: %v", err) + } + + // Alice should now have 1 edge in her graph. + req := &lnrpc.ChannelGraphRequest{} + chanGraph, err := net.Alice.DescribeGraph(ctxb, req) + if err != nil { + t.Fatalf("unable to query for alice's routing table: %v", err) + } + + numEdges := len(chanGraph.Edges) + if numEdges != 1 { + t.Fatalf("expected to find one edge in the graph, found %d", + numEdges) + } + + // Connecting the two miners should now cause our original one to sync + // to the new, and longer chain. + if err := rpctest.ConnectNode(net.Miner, miner); err != nil { + t.Fatalf("unable to connect harnesses: %v", err) + } + + if err := rpctest.JoinNodes(nodeSlice, rpctest.Blocks); err != nil { + t.Fatalf("unable to join node on blocks: %v", err) + } + + // Once again they should be on the same chain. + _, newNodeHeight, err = miner.Node.GetBestBlock() + if err != nil { + t.Fatalf("unable to get current blockheight %v", err) + } + + _, orgNodeHeight, err = net.Miner.Node.GetBestBlock() + if err != nil { + t.Fatalf("unable to get current blockheight %v", err) + } + + if newNodeHeight != orgNodeHeight { + t.Fatalf("expected new miner(%d) and original miner(%d) to "+ + "be on the same height", newNodeHeight, orgNodeHeight) + } + + // Since the fundingtx was reorged out, Alice should now have no edges + // in her graph. + req = &lnrpc.ChannelGraphRequest{} + chanGraph, err = net.Alice.DescribeGraph(ctxb, req) + if err != nil { + t.Fatalf("unable to query for alice's routing table: %v", err) + } + + numEdges = len(chanGraph.Edges) + if numEdges != 0 { + t.Fatalf("expected to find no edge in the graph, found %d", + numEdges) + } + + ctxt, _ = context.WithTimeout(ctxb, timeout) + closeChannelAndAssert(ctxt, t, net, net.Alice, chanPoint, false) +} + // testDisconnectingTargetPeer performs a test which // disconnects Alice-peer from Bob-peer and then re-connects them again func testDisconnectingTargetPeer(net *networkHarness, t *harnessTest) { @@ -3751,6 +3931,10 @@ var testsCases = []*testCase{ name: "basic funding flow", test: testBasicChannelFunding, }, + { + name: "open channel reorg test", + test: testOpenChannelAfterReorg, + }, { name: "disconnecting target peer", test: testDisconnectingTargetPeer, diff --git a/routing/chainview/btcd.go b/routing/chainview/btcd.go index 41cc5736..04c4ce84 100644 --- a/routing/chainview/btcd.go +++ b/routing/chainview/btcd.go @@ -1,14 +1,17 @@ package chainview import ( + "bytes" + "encoding/hex" "fmt" "sync" "sync/atomic" - "time" + "github.com/roasbeef/btcd/btcjson" "github.com/roasbeef/btcd/chaincfg/chainhash" "github.com/roasbeef/btcd/rpcclient" "github.com/roasbeef/btcd/wire" + "github.com/roasbeef/btcutil" ) // BtcdFilteredChainView is an implementation of the FilteredChainView @@ -17,34 +20,27 @@ type BtcdFilteredChainView struct { started int32 stopped int32 - // bestHash is the hash of the latest block in the main chain. - bestHash chainhash.Hash - - // bestHeight is the height of the latest block in the main chain. - bestHeight int32 + // bestHeight is the height of the latest block added to the + // blockQueue from the onFilteredConnectedMethod. It is used to + // determine up to what height we would need to rescan in case + // of a filter update. + bestHeightMtx sync.Mutex + bestHeight uint32 btcdConn *rpcclient.Client - // newBlocks is the channel in which new filtered blocks are sent over. - newBlocks chan *FilteredBlock - - // staleBlocks is the channel in which blocks that have been - // disconnected from the mainchain are sent over. - staleBlocks chan *FilteredBlock + // blockEventQueue is the ordered queue used to keep the order + // of connected and disconnected blocks sent to the reader of the + // chainView. + blockQueue *blockEventQueue // filterUpdates is a channel in which updates to the utxo filter // attached to this instance are sent over. filterUpdates chan filterUpdate - // The three field below are used to implement a synchronized queue - // that lets use instantly handle sent notifications without blocking - // the main websockets notification loop. - chainUpdates []*chainUpdate - chainUpdateSignal chan struct{} - chainUpdateMtx sync.Mutex - // chainFilter is the set of utox's that we're currently watching // spends for within the chain. + filterMtx sync.RWMutex chainFilter map[wire.OutPoint]struct{} // filterBlockReqs is a channel in which requests to filter select @@ -63,18 +59,15 @@ var _ FilteredChainView = (*BtcdFilteredChainView)(nil) // RPC credentials for an active btcd instance. func NewBtcdFilteredChainView(config rpcclient.ConnConfig) (*BtcdFilteredChainView, error) { chainView := &BtcdFilteredChainView{ - newBlocks: make(chan *FilteredBlock), - staleBlocks: make(chan *FilteredBlock), - chainUpdateSignal: make(chan struct{}), - chainFilter: make(map[wire.OutPoint]struct{}), - filterUpdates: make(chan filterUpdate), - filterBlockReqs: make(chan *filterBlockReq), - quit: make(chan struct{}), + chainFilter: make(map[wire.OutPoint]struct{}), + filterUpdates: make(chan filterUpdate), + filterBlockReqs: make(chan *filterBlockReq), + quit: make(chan struct{}), } ntfnCallbacks := &rpcclient.NotificationHandlers{ - OnBlockConnected: chainView.onBlockConnected, - OnBlockDisconnected: chainView.onBlockDisconnected, + OnFilteredBlockConnected: chainView.onFilteredBlockConnected, + OnFilteredBlockDisconnected: chainView.onFilteredBlockDisconnected, } // Disable connecting to btcd within the rpcclient.New method. We @@ -87,6 +80,8 @@ func NewBtcdFilteredChainView(config rpcclient.ConnConfig) (*BtcdFilteredChainVi } chainView.btcdConn = chainConn + chainView.blockQueue = newBlockEventQueue() + return chainView, nil } @@ -110,12 +105,16 @@ func (b *BtcdFilteredChainView) Start() error { return err } - bestHash, bestHeight, err := b.btcdConn.GetBestBlock() + _, bestHeight, err := b.btcdConn.GetBestBlock() if err != nil { return err } - b.bestHash, b.bestHeight = *bestHash, bestHeight + b.bestHeightMtx.Lock() + b.bestHeight = uint32(bestHeight) + b.bestHeightMtx.Unlock() + + b.blockQueue.Start() b.wg.Add(1) go b.chainFilterer() @@ -137,6 +136,8 @@ func (b *BtcdFilteredChainView) Stop() error { // cleans up all related resources. b.btcdConn.Shutdown() + b.blockQueue.Stop() + log.Infof("FilteredChainView stopping") close(b.quit) @@ -145,39 +146,68 @@ func (b *BtcdFilteredChainView) Stop() error { return nil } -// chainUpdate encapsulates an update to the current main chain. This struct is -// used as an element within an unbounded queue in order to avoid blocking the -// main rpc dispatch rule. -type chainUpdate struct { - blockHash *chainhash.Hash - blockHeight int32 +// onFilteredBlockConnected is called for each block that's connected to the +// end of the main chain. Based on our current chain filter, the block may or +// may not include any relevant transactions. +func (b *BtcdFilteredChainView) onFilteredBlockConnected(height int32, + header *wire.BlockHeader, txns []*btcutil.Tx) { + + mtxs := make([]*wire.MsgTx, len(txns)) + for i, tx := range txns { + mtx := tx.MsgTx() + mtxs[i] = mtx + + for _, txIn := range mtx.TxIn { + // We can delete this outpoint from the chainFilter, as + // we just received a block where it was spent. In case + // of a reorg, this outpoint might get "un-spent", but + // that's okay since it would never be wise to consider + // the channel open again (since a spending transaction + // exists on the network). + b.filterMtx.Lock() + delete(b.chainFilter, txIn.PreviousOutPoint) + b.filterMtx.Unlock() + } + + } + + // We record the height of the last connected block added to the + // blockQueue such that we can scan up to this height in case of + // a rescan. It must be protected by a mutex since a filter update + // might be trying to read it concurrently. + b.bestHeightMtx.Lock() + b.bestHeight = uint32(height) + b.bestHeightMtx.Unlock() + + block := &FilteredBlock{ + Hash: header.BlockHash(), + Height: uint32(height), + Transactions: mtxs, + } + + b.blockQueue.Add(&blockEvent{ + eventType: connected, + block: block, + }) } -// onBlockConnected implements on OnBlockConnected callback for rpcclient. -// Ingesting a block updates the wallet's internal utxo state based on the -// outputs created and destroyed within each block. -func (b *BtcdFilteredChainView) onBlockConnected(hash *chainhash.Hash, - height int32, t time.Time) { +// onFilteredBlockDisconnected is a callback which is executed once a block is +// disconnected from the end of the main chain. +func (b *BtcdFilteredChainView) onFilteredBlockDisconnected(height int32, + header *wire.BlockHeader) { - // Append this new chain update to the end of the queue of new chain - // updates. - b.chainUpdateMtx.Lock() - b.chainUpdates = append(b.chainUpdates, &chainUpdate{hash, height}) - b.chainUpdateMtx.Unlock() + log.Debugf("got disconnected block at height %d: %v", height, + header.BlockHash()) - // Launch a goroutine to signal the notification dispatcher that a new - // block update is available. We do this in a new goroutine in order to - // avoid blocking the main loop of the rpc client. - go func() { - b.chainUpdateSignal <- struct{}{} - }() -} + filteredBlock := &FilteredBlock{ + Hash: header.BlockHash(), + Height: uint32(height), + } -// onBlockDisconnected implements on OnBlockDisconnected callback for rpcclient. -func (b *BtcdFilteredChainView) onBlockDisconnected(hash *chainhash.Hash, - height int32, t time.Time) { - - // TODO(roasbeef): impl + b.blockQueue.Add(&blockEvent{ + eventType: disconnected, + block: filteredBlock, + }) } // filterBlockReq houses a request to manually filter a block specified by @@ -231,7 +261,9 @@ func (b *BtcdFilteredChainView) chainFilterer() { if _, ok := b.chainFilter[prevOp]; ok { filteredTxns = append(filteredTxns, tx) + b.filterMtx.Lock() delete(b.chainFilter, prevOp) + b.filterMtx.Unlock() break } @@ -241,87 +273,118 @@ func (b *BtcdFilteredChainView) chainFilterer() { return filteredTxns } + decodeJSONBlock := func(block *btcjson.RescannedBlock, + height uint32) (*FilteredBlock, error) { + hash, err := chainhash.NewHashFromStr(block.Hash) + if err != nil { + return nil, err + + } + txs := make([]*wire.MsgTx, 0, len(block.Transactions)) + for _, str := range block.Transactions { + b, err := hex.DecodeString(str) + if err != nil { + return nil, err + } + tx := &wire.MsgTx{} + err = tx.Deserialize(bytes.NewReader(b)) + if err != nil { + return nil, err + } + txs = append(txs, tx) + } + return &FilteredBlock{ + Hash: *hash, + Height: height, + Transactions: txs, + }, nil + } + for { select { - - // A new block has been connected to the end of the main chain. - // So we'll need to dispatch a new FilteredBlock notification. - case <-b.chainUpdateSignal: - // A new update is available, so pop the new chain - // update from the front of the update queue. - b.chainUpdateMtx.Lock() - update := b.chainUpdates[0] - b.chainUpdates[0] = nil // Set to nil to prevent GC leak. - b.chainUpdates = b.chainUpdates[1:] - b.chainUpdateMtx.Unlock() - - // Now that we have the new block has, fetch the new - // block itself. - newBlock, err := b.btcdConn.GetBlock(update.blockHash) - if err != nil { - log.Errorf("Unable to get block: %v", err) - continue - } - b.bestHash, b.bestHeight = *update.blockHash, update.blockHeight - - // Next, we'll scan this block to see if it modified - // any of the UTXO set that we're watching. - filteredTxns := filterBlock(newBlock) - - // Finally, launch a goroutine to dispatch this - // filtered block notification. - go func() { - b.newBlocks <- &FilteredBlock{ - Hash: *update.blockHash, - Height: uint32(update.blockHeight), - Transactions: filteredTxns, - } - }() - // The caller has just sent an update to the current chain // filter, so we'll apply the update, possibly rewinding our // state partially. case update := <-b.filterUpdates: + // First, we'll add all the new UTXO's to the set of // watched UTXO's, eliminating any duplicates in the // process. log.Debugf("Updating chain filter with new UTXO's: %v", update.newUtxos) for _, newOp := range update.newUtxos { + b.filterMtx.Lock() b.chainFilter[newOp] = struct{}{} + b.filterMtx.Unlock() } + // Apply the new TX filter to btcd, which will cause + // all following notifications from and calls to it + // return blocks filtered with the new filter. + b.btcdConn.LoadTxFilter(false, []btcutil.Address{}, + update.newUtxos) + + // All blocks gotten after we loaded the filter will + // have the filter applied, but we will need to rescan + // the blocks up to the height of the block we last + // added to the blockQueue. + b.bestHeightMtx.Lock() + bestHeight := b.bestHeight + b.bestHeightMtx.Unlock() + // If the update height matches our best known height, // then we don't need to do any rewinding. - if update.updateHeight == uint32(b.bestHeight) { + if update.updateHeight == bestHeight { continue } // Otherwise, we'll rewind the state to ensure the // caller doesn't miss any relevant notifications. // Starting from the height _after_ the update height, - // we'll walk forwards, manually filtering blocks. - for i := int32(update.updateHeight) + 1; i < b.bestHeight+1; i++ { + // we'll walk forwards, rescanning one block at a time + // with btcd applying the newly loaded filter to each + // block. + for i := update.updateHeight + 1; i < bestHeight+1; i++ { blockHash, err := b.btcdConn.GetBlockHash(int64(i)) if err != nil { - log.Errorf("Unable to get block hash: %v", err) + log.Warnf("Unable to get block hash "+ + "for block at height %d: %v", + i, err) continue } - block, err := b.btcdConn.GetBlock(blockHash) + + // To avoid dealing with the case where a reorg + // is happening while we rescan, we scan one + // block at a time, skipping blocks that might + // have gone missing. + rescanned, err := b.btcdConn.RescanBlocks( + []chainhash.Hash{*blockHash}) if err != nil { - log.Errorf("Unable to get block: %v", err) + log.Warnf("Unable to rescan block "+ + "with hash %v at height %d: %v", + blockHash, i, err) continue } - filteredTxns := filterBlock(block) - - go func(height uint32) { - b.newBlocks <- &FilteredBlock{ - Hash: *blockHash, - Height: height, - Transactions: filteredTxns, - } - }(uint32(i)) + // If no block was returned from the rescan, + // it means no maching transactions were found. + if len(rescanned) != 1 { + log.Debugf("no matching block found "+ + "for rescan of hash %v", + blockHash) + continue + } + decoded, err := decodeJSONBlock( + &rescanned[0], uint32(i)) + if err != nil { + log.Errorf("Unable to decode block: %v", + err) + continue + } + b.blockQueue.Add(&blockEvent{ + eventType: connected, + block: decoded, + }) } // We've received a new request to manually filter a block. @@ -393,7 +456,7 @@ func (b *BtcdFilteredChainView) UpdateFilter(ops []wire.OutPoint, updateHeight u // // NOTE: This is part of the FilteredChainView interface. func (b *BtcdFilteredChainView) FilteredBlocks() <-chan *FilteredBlock { - return b.newBlocks + return b.blockQueue.newBlocks } // DisconnectedBlocks returns a receive only channel which will be sent upon @@ -402,5 +465,5 @@ func (b *BtcdFilteredChainView) FilteredBlocks() <-chan *FilteredBlock { // // NOTE: This is part of the FilteredChainView interface. func (b *BtcdFilteredChainView) DisconnectedBlocks() <-chan *FilteredBlock { - return b.staleBlocks + return b.blockQueue.staleBlocks } diff --git a/routing/chainview/interface.go b/routing/chainview/interface.go index 9e12d1c7..701296ec 100644 --- a/routing/chainview/interface.go +++ b/routing/chainview/interface.go @@ -18,12 +18,19 @@ type FilteredChainView interface { // FilteredBlocks returns the channel that filtered blocks are to be // sent over. Each time a block is connected to the end of a main // chain, and appropriate FilteredBlock which contains the transactions - // which mutate our watched UTXO set is to be returned. + // which mutate our watched UTXO set is to be returned. In case of a + // UpdateFilter call with a updateHeight lower than the current best + // height, blocks with the updated filter will be resent, and must be + // handled by the receiver as an update to an already known block, NOT + // as a new block being connected to the chain. FilteredBlocks() <-chan *FilteredBlock // DisconnectedBlocks returns a receive only channel which will be sent // upon with the empty filtered blocks of blocks which are disconnected // from the main chain in the case of a re-org. + // NOTE: In case of a reorg, connected blocks will not be available to + // read from the FilteredBlocks() channel before all disconnected block + // have been received. DisconnectedBlocks() <-chan *FilteredBlock // UpdateFilter updates the UTXO filter which is to be consulted when @@ -32,7 +39,9 @@ type FilteredChainView interface { // _expand_ the size of the UTXO sub-set currently being watched. If // the set updateHeight is _lower_ than the best known height of the // implementation, then the state should be rewound to ensure all - // relevant notifications are dispatched. + // relevant notifications are dispatched, meaning blocks with a height + // lower than the best known height might be sent over the + // FilteredBlocks() channel. UpdateFilter(ops []wire.OutPoint, updateHeight uint32) error // FilterBlock takes a block hash, and returns a FilteredBlocks which diff --git a/routing/chainview/interface_test.go b/routing/chainview/interface_test.go index 89a882a1..d0b0da07 100644 --- a/routing/chainview/interface_test.go +++ b/routing/chainview/interface_test.go @@ -19,8 +19,8 @@ import ( "github.com/roasbeef/btcd/txscript" "github.com/roasbeef/btcd/wire" "github.com/roasbeef/btcutil" - "github.com/roasbeef/btcwallet/walletdb" + "github.com/roasbeef/btcwallet/walletdb" _ "github.com/roasbeef/btcwallet/walletdb/bdb" // Required to register the boltdb walletdb implementation. ) @@ -124,7 +124,8 @@ func assertFilteredBlock(t *testing.T, fb *FilteredBlock, expectedHeight int32, } func testFilterBlockNotifications(node *rpctest.Harness, - chainView FilteredChainView, t *testing.T) { + chainView FilteredChainView, chainViewInit chainViewInitFunc, + t *testing.T) { // To start the test, we'll create to fresh outputs paying to the // private key that we generated above. @@ -253,7 +254,8 @@ func testFilterBlockNotifications(node *rpctest.Harness, } } -func testUpdateFilterBackTrack(node *rpctest.Harness, chainView FilteredChainView, +func testUpdateFilterBackTrack(node *rpctest.Harness, + chainView FilteredChainView, chainViewInit chainViewInitFunc, t *testing.T) { // To start, we'll create a fresh output paying to the height generated @@ -321,6 +323,7 @@ func testUpdateFilterBackTrack(node *rpctest.Harness, chainView FilteredChainVie // After the block has been mined+notified we'll update the filter with // a _prior_ height so a "rewind" occurs. filter := []wire.OutPoint{*outPoint} + err = chainView.UpdateFilter(filter, uint32(currentHeight)) if err != nil { t.Fatalf("unable to update filter: %v", err) @@ -338,7 +341,7 @@ func testUpdateFilterBackTrack(node *rpctest.Harness, chainView FilteredChainVie } func testFilterSingleBlock(node *rpctest.Harness, chainView FilteredChainView, - t *testing.T) { + chainViewInit chainViewInitFunc, t *testing.T) { // In this test, we'll test the manual filtration of blocks, which can // be used by clients to manually rescan their sub-set of the UTXO set. @@ -451,9 +454,211 @@ func testFilterSingleBlock(node *rpctest.Harness, chainView FilteredChainView, expectedTxns) } +// testFilterBlockDisconnected triggers a reorg all the way back to genesis, +// and a small 5 block reorg, ensuring the chainView notifies about +// disconnected and connected blocks in the order we expect. +func testFilterBlockDisconnected(node *rpctest.Harness, + chainView FilteredChainView, chainViewInit chainViewInitFunc, + t *testing.T) { + + // Create a node that has a shorter chain than the main chain, so we + // can trigger a reorg. + reorgNode, err := rpctest.New(netParams, nil, nil) + if err != nil { + t.Fatalf("unable to create mining node: %v", err) + } + defer reorgNode.TearDown() + + // This node's chain will be 105 blocks. + if err := reorgNode.SetUp(true, 5); err != nil { + t.Fatalf("unable to set up mining node: %v", err) + } + + // Init a chain view that has this node as its block source. + cleanUpFunc, reorgView, err := chainViewInit(reorgNode.RPCConfig(), + reorgNode.P2PAddress()) + if err != nil { + t.Fatalf("unable to create chain view: %v", err) + } + defer func() { + if cleanUpFunc != nil { + cleanUpFunc() + } + }() + + if err = reorgView.Start(); err != nil { + t.Fatalf("unable to start btcd chain view: %v", err) + } + defer reorgView.Stop() + + newBlocks := reorgView.FilteredBlocks() + disconnectedBlocks := reorgView.DisconnectedBlocks() + + _, oldHeight, err := reorgNode.Node.GetBestBlock() + if err != nil { + t.Fatalf("unable to get current height: %v", err) + } + + // Now connect the node with the short chain to the main node, and wait + // for their chains to synchronize. The short chain will be reorged all + // the way back to genesis. + if err := rpctest.ConnectNode(reorgNode, node); err != nil { + t.Fatalf("unable to connect harnesses: %v", err) + } + nodeSlice := []*rpctest.Harness{node, reorgNode} + if err := rpctest.JoinNodes(nodeSlice, rpctest.Blocks); err != nil { + t.Fatalf("unable to join node on blocks: %v", err) + } + + _, newHeight, err := reorgNode.Node.GetBestBlock() + if err != nil { + t.Fatalf("unable to get current height: %v", err) + } + + // We should be getting oldHeight number of blocks marked as + // stale/disconnected. We expect to first get all stale blocks, + // then the new blocks. We also ensure a strict ordering. + for i := int32(0); i < oldHeight+newHeight; i++ { + select { + case block := <-newBlocks: + if i < oldHeight { + t.Fatalf("did not expect to get new block "+ + "in iteration %d", i) + } + expectedHeight := uint32(i - oldHeight + 1) + if block.Height != expectedHeight { + t.Fatalf("expected to receive connected "+ + "block at height %d, instead got at %d", + expectedHeight, block.Height) + } + case block := <-disconnectedBlocks: + if i >= oldHeight { + t.Fatalf("did not expect to get stale block "+ + "in iteration %d", i) + } + expectedHeight := uint32(oldHeight - i) + if block.Height != expectedHeight { + t.Fatalf("expected to receive disconencted "+ + "block at height %d, instead got at %d", + expectedHeight, block.Height) + } + case <-time.After(10 * time.Second): + t.Fatalf("timeout waiting for block") + } + } + + // Now we trigger a small reorg, by disconnecting the nodes, mining + // a few blocks on each, then connecting them again. + peers, err := reorgNode.Node.GetPeerInfo() + if err != nil { + t.Fatalf("unable to get peer info: %v", err) + } + numPeers := len(peers) + + // Disconnect the nodes. + err = reorgNode.Node.AddNode(node.P2PAddress(), rpcclient.ANRemove) + if err != nil { + t.Fatalf("unable to disconnect mining nodes: %v", err) + } + + // Wait for disconnection + for { + peers, err = reorgNode.Node.GetPeerInfo() + if err != nil { + t.Fatalf("unable to get peer info: %v", err) + } + if len(peers) < numPeers { + break + } + time.Sleep(100 * time.Millisecond) + } + + // Mine 10 blocks on the main chain, 5 on the chain that will be + // reorged out, + node.Node.Generate(10) + reorgNode.Node.Generate(5) + + // 5 new blocks should get notified. + for i := uint32(0); i < 5; i++ { + select { + case block := <-newBlocks: + expectedHeight := uint32(newHeight) + i + 1 + if block.Height != expectedHeight { + t.Fatalf("expected to receive connected "+ + "block at height %d, instead got at %d", + expectedHeight, block.Height) + } + case <-disconnectedBlocks: + t.Fatalf("did not expect to get stale block "+ + "in iteration %d", i) + case <-time.After(10 * time.Second): + t.Fatalf("did not get connected block") + } + } + + _, oldHeight, err = reorgNode.Node.GetBestBlock() + if err != nil { + t.Fatalf("unable to get current height: %v", err) + } + + // Now connect the two nodes, and wait for their chains to sync up. + if err := rpctest.ConnectNode(reorgNode, node); err != nil { + t.Fatalf("unable to connect harnesses: %v", err) + } + if err := rpctest.JoinNodes(nodeSlice, rpctest.Blocks); err != nil { + t.Fatalf("unable to join node on blocks: %v", err) + } + + _, newHeight, err = reorgNode.Node.GetBestBlock() + if err != nil { + t.Fatalf("unable to get current height: %v", err) + } + + // We should get 5 disconnected, 10 connected blocks. + for i := uint32(0); i < 15; i++ { + select { + case block := <-newBlocks: + if i < 5 { + t.Fatalf("did not expect to get new block "+ + "in iteration %d", i) + } + // The expected height for the connected block will be + // oldHeight - 5 (the 5 disconnected blocks) + (i-5) + // (subtract 5 since the 5 first iterations consumed + // disconnected blocks) + 1 + expectedHeight := uint32(oldHeight) - 9 + i + if block.Height != expectedHeight { + t.Fatalf("expected to receive connected "+ + "block at height %d, instead got at %d", + expectedHeight, block.Height) + } + case block := <-disconnectedBlocks: + if i >= 5 { + t.Fatalf("did not expect to get stale block "+ + "in iteration %d", i) + } + expectedHeight := uint32(oldHeight) - i + if block.Height != expectedHeight { + t.Fatalf("expected to receive disconnected "+ + "block at height %d, instead got at %d", + expectedHeight, block.Height) + } + case <-time.After(10 * time.Second): + t.Fatalf("did not get disconnected block") + } + } + + // Time for db access to finish between testcases. + time.Sleep(time.Millisecond * 500) +} + +type chainViewInitFunc func(rpcInfo rpcclient.ConnConfig, + p2pAddr string) (func(), FilteredChainView, error) + type testCase struct { name string - test func(*rpctest.Harness, FilteredChainView, *testing.T) + test func(*rpctest.Harness, FilteredChainView, chainViewInitFunc, + *testing.T) } var chainViewTests = []testCase{ @@ -469,12 +674,15 @@ var chainViewTests = []testCase{ name: "fitler single block", test: testFilterSingleBlock, }, + { + name: "filter block disconnected", + test: testFilterBlockDisconnected, + }, } var interfaceImpls = []struct { name string - chainViewInit func(rpcInfo rpcclient.ConnConfig, - p2pAddr string) (func(), FilteredChainView, error) + chainViewInit chainViewInitFunc }{ { name: "p2p_neutrino", @@ -569,9 +777,9 @@ func TestFilteredChainView(t *testing.T) { for _, chainViewTest := range chainViewTests { testName := fmt.Sprintf("%v: %v", chainViewImpl.name, chainViewTest.name) - success := t.Run(testName, func(t *testing.T) { - chainViewTest.test(miner, chainView, t) + chainViewTest.test(miner, chainView, + chainViewImpl.chainViewInit, t) }) if !success { diff --git a/routing/chainview/neutrino.go b/routing/chainview/neutrino.go index 15d415b9..b877ce94 100644 --- a/routing/chainview/neutrino.go +++ b/routing/chainview/neutrino.go @@ -35,16 +35,10 @@ type CfFilteredChainView struct { // rescan will be sent over. rescanErrChan <-chan error - // newBlocks is the channel in which new filtered blocks are sent over. - newBlocks chan *FilteredBlock - - // staleBlocks is the channel in which blocks that have been - // disconnected from the mainchain are sent over. - staleBlocks chan *FilteredBlock - - // filterUpdates is a channel in which updates to the utxo filter - // attached to this instance are sent over. - filterUpdates chan filterUpdate + // blockEventQueue is the ordered queue used to keep the order + // of connected and disconnected blocks sent to the reader of the + // chainView. + blockQueue *blockEventQueue // chainFilter is the filterMtx sync.RWMutex @@ -65,11 +59,9 @@ var _ FilteredChainView = (*CfFilteredChainView)(nil) // this function. func NewCfFilteredChainView(node *neutrino.ChainService) (*CfFilteredChainView, error) { return &CfFilteredChainView{ - newBlocks: make(chan *FilteredBlock), - staleBlocks: make(chan *FilteredBlock), + blockQueue: newBlockEventQueue(), quit: make(chan struct{}), rescanErrChan: make(chan error), - filterUpdates: make(chan filterUpdate), chainFilter: make(map[wire.OutPoint]struct{}), p2pNode: node, }, nil @@ -122,6 +114,8 @@ func (c *CfFilteredChainView) Start() error { c.chainView = c.p2pNode.NewRescan(rescanOptions...) c.rescanErrChan = c.chainView.Start() + c.blockQueue.Start() + c.wg.Add(1) go c.chainFilterer() @@ -140,6 +134,7 @@ func (c *CfFilteredChainView) Stop() error { log.Infof("FilteredChainView stopping") close(c.quit) + c.blockQueue.Stop() c.wg.Wait() return nil @@ -164,13 +159,16 @@ func (c *CfFilteredChainView) onFilteredBlockConnected(height int32, } - go func() { - c.newBlocks <- &FilteredBlock{ - Hash: header.BlockHash(), - Height: uint32(height), - Transactions: mtxs, - } - }() + block := &FilteredBlock{ + Hash: header.BlockHash(), + Height: uint32(height), + Transactions: mtxs, + } + + c.blockQueue.Add(&blockEvent{ + eventType: connected, + block: block, + }) } // onFilteredBlockDisconnected is a callback which is executed once a block is @@ -178,59 +176,29 @@ func (c *CfFilteredChainView) onFilteredBlockConnected(height int32, func (c *CfFilteredChainView) onFilteredBlockDisconnected(height int32, header *wire.BlockHeader) { + log.Debugf("got disconnected block at height %d: %v", height, + header.BlockHash()) + filteredBlock := &FilteredBlock{ Hash: header.BlockHash(), Height: uint32(height), } - go func() { - c.staleBlocks <- filteredBlock - }() + c.blockQueue.Add(&blockEvent{ + eventType: disconnected, + block: filteredBlock, + }) } // chainFilterer is the primary coordination goroutine within the -// CfFilteredChainView. This goroutine handles errors from the running rescan, -// and also filter updates. +// CfFilteredChainView. This goroutine handles errors from the running rescan. func (c *CfFilteredChainView) chainFilterer() { defer c.wg.Done() for { select { - case err := <-c.rescanErrChan: log.Errorf("Error encountered during rescan: %v", err) - - // We've received a new update to the filter from the caller to - // mutate their established chain view. - case update := <-c.filterUpdates: - log.Debugf("Updating chain filter with new UTXO's: %v", - update.newUtxos) - - // First, we'll update the current chain view, by - // adding any new UTXO's, ignoring duplicates int he - // process. - c.filterMtx.Lock() - for _, op := range update.newUtxos { - c.chainFilter[op] = struct{}{} - } - c.filterMtx.Unlock() - - // With our internal chain view update, we'll craft a - // new update to the chainView which includes our new - // UTXO's, and current update height. - rescanUpdate := []neutrino.UpdateOption{ - neutrino.AddOutPoints(update.newUtxos...), - neutrino.Rewind(update.updateHeight), - } - err := c.chainView.Update(rescanUpdate...) - if err != nil { - log.Errorf("unable to update rescan: %v", err) - } - - if update.done != nil { - close(update.done) - } - case <-c.quit: return } @@ -343,27 +311,32 @@ func (c *CfFilteredChainView) FilterBlock(blockHash *chainhash.Hash) (*FilteredB // rewound to ensure all relevant notifications are dispatched. // // NOTE: This is part of the FilteredChainView interface. -func (c *CfFilteredChainView) UpdateFilter(ops []wire.OutPoint, updateHeight uint32) error { - doneChan := make(chan struct{}) - update := filterUpdate{ - newUtxos: ops, - updateHeight: updateHeight, - done: doneChan, - } +func (c *CfFilteredChainView) UpdateFilter(ops []wire.OutPoint, + updateHeight uint32) error { + log.Debugf("Updating chain filter with new UTXO's: %v", ops) - select { - case c.filterUpdates <- update: - case <-c.quit: - return fmt.Errorf("chain filter shutting down") + // First, we'll update the current chain view, by + // adding any new UTXO's, ignoring duplicates in the + // process. + c.filterMtx.Lock() + for _, op := range ops { + c.chainFilter[op] = struct{}{} } + c.filterMtx.Unlock() - select { - case <-doneChan: - return nil - case <-c.quit: - return fmt.Errorf("chain filter shutting down") + // With our internal chain view update, we'll craft a + // new update to the chainView which includes our new + // UTXO's, and current update height. + rescanUpdate := []neutrino.UpdateOption{ + neutrino.AddOutPoints(ops...), + neutrino.Rewind(updateHeight), + neutrino.DisableDisconnectedNtfns(true), } - + err := c.chainView.Update(rescanUpdate...) + if err != nil { + return fmt.Errorf("unable to update rescan: %v", err) + } + return nil } // FilteredBlocks returns the channel that filtered blocks are to be sent over. @@ -373,7 +346,7 @@ func (c *CfFilteredChainView) UpdateFilter(ops []wire.OutPoint, updateHeight uin // // NOTE: This is part of the FilteredChainView interface. func (c *CfFilteredChainView) FilteredBlocks() <-chan *FilteredBlock { - return c.newBlocks + return c.blockQueue.newBlocks } // DisconnectedBlocks returns a receive only channel which will be sent upon @@ -382,5 +355,5 @@ func (c *CfFilteredChainView) FilteredBlocks() <-chan *FilteredBlock { // // NOTE: This is part of the FilteredChainView interface. func (c *CfFilteredChainView) DisconnectedBlocks() <-chan *FilteredBlock { - return c.staleBlocks + return c.blockQueue.staleBlocks } diff --git a/routing/chainview/queue.go b/routing/chainview/queue.go new file mode 100644 index 00000000..5d86c3c1 --- /dev/null +++ b/routing/chainview/queue.go @@ -0,0 +1,148 @@ +package chainview + +import "sync" + +// blockEventType is the possible types of a blockEvent. +type blockEventType uint8 + +const ( + // connected is the type of a blockEvent representing a block + // that was connected to our current chain. + connected blockEventType = iota + + // disconnected is the type of a blockEvent representing a + // block that is stale/disconnected from our current chain. + disconnected +) + +// blockEvent represent a block that was either connected +// or disconnected from the current chain. +type blockEvent struct { + eventType blockEventType + block *FilteredBlock +} + +// blockEventQueue is an ordered queue for block events sent from a +// FilteredChainView. The two types of possible block events are +// connected/new blocks, and disconnected/stale blocks. The +// blockEventQueue keeps the order of these events intact, while +// still being non-blocking. This is important in order for the +// chainView's call to onBlockConnected/onBlockDisconnected to not +// get blocked, and for the consumer of the block events to always +// get the events in the correct order. +type blockEventQueue struct { + queueCond *sync.Cond + queueMtx sync.Mutex + queue []*blockEvent + + // newBlocks is the channel where the consumer of the queue + // will receive connected/new blocks from the FilteredChainView. + newBlocks chan *FilteredBlock + + // stleBlocks is the channel where the consumer of the queue will + // receive disconnected/stale blocks from the FilteredChainView. + staleBlocks chan *FilteredBlock + + wg sync.WaitGroup + quit chan struct{} +} + +// newBlockEventQueue creates a new blockEventQueue. +func newBlockEventQueue() *blockEventQueue { + b := &blockEventQueue{ + newBlocks: make(chan *FilteredBlock), + staleBlocks: make(chan *FilteredBlock), + quit: make(chan struct{}), + } + b.queueCond = sync.NewCond(&b.queueMtx) + + return b +} + +// Start starts the blockEventQueue coordinator such that it can start handling +// events. +func (b *blockEventQueue) Start() { + b.wg.Add(1) + go b.queueCoordinator() +} + +// Stop signals the queue coordinator to stop, such that the queue can be +// shut down. +func (b *blockEventQueue) Stop() { + close(b.quit) + + b.queueCond.Signal() + b.wg.Wait() +} + +// queueCoordinator is the queue's main loop, handling incoming block events +// and handing them off to the correct output channel. +// +// NB: MUST be run as a goroutine from the Start() method. +func (b *blockEventQueue) queueCoordinator() { + defer b.wg.Done() + + for { + // First, we'll check our condition. If the queue of events is + // empty, then we'll wait until a new item is added. + b.queueCond.L.Lock() + for len(b.queue) == 0 { + b.queueCond.Wait() + + // If we were woke up in order to exit, then we'll do + // so. Otherwise, we'll check the queue for any new + // items. + select { + case <-b.quit: + b.queueCond.L.Unlock() + return + default: + } + } + + // Grab the first element in the queue, and nil the index to + // avoid gc leak. + event := b.queue[0] + b.queue[0] = nil + b.queue = b.queue[1:] + b.queueCond.L.Unlock() + + // In the case this is a connected block, we'll send it on the + // newBlocks channel. In case it is a disconnected block, we'll + // send it on the staleBlocks channel. This send will block + // until it is received by the consumer on the other end, making + // sure we won't try to send any other block event before the + // consumer is aware of this one. + switch event.eventType { + case connected: + select { + case b.newBlocks <- event.block: + case <-b.quit: + return + } + case disconnected: + select { + case b.staleBlocks <- event.block: + case <-b.quit: + return + } + } + } +} + +// Add puts the provided blockEvent at the end of the event queue, making sure +// it will first be received after all previous events. This method is +// non-blocking, in the sense that it will never wait for the consumer of the +// queue to read form the other end, making it safe to call from the +// FilteredChainView's onBlockConnected/onBlockDisconnected. +func (b *blockEventQueue) Add(event *blockEvent) { + + // Lock the condition, and add the event to the end of queue. + b.queueCond.L.Lock() + b.queue = append(b.queue, event) + b.queueCond.L.Unlock() + + // With the event added, we signal to the queueCoordinator that + // there are new events to handle. + b.queueCond.Signal() +} diff --git a/routing/notifications_test.go b/routing/notifications_test.go index e93bae46..1f24c398 100644 --- a/routing/notifications_test.go +++ b/routing/notifications_test.go @@ -147,7 +147,9 @@ func (m *mockChain) GetBestBlock() (*chainhash.Hash, int32, error) { m.RLock() defer m.RUnlock() - return nil, m.bestHeight, nil + blockHash := m.blockIndex[uint32(m.bestHeight)] + + return &blockHash, m.bestHeight, nil } func (m *mockChain) GetTransaction(txid *chainhash.Hash) (*wire.MsgTx, error) { @@ -162,7 +164,6 @@ func (m *mockChain) GetBlockHash(blockHeight int64) (*chainhash.Hash, error) { if !ok { return nil, fmt.Errorf("can't find block hash, for "+ "height %v", blockHeight) - } return &hash, nil @@ -185,9 +186,9 @@ func (m *mockChain) GetUtxo(op *wire.OutPoint, _ uint32) (*wire.TxOut, error) { return &utxo, nil } -func (m *mockChain) addBlock(block *wire.MsgBlock, height uint32) { +func (m *mockChain) addBlock(block *wire.MsgBlock, height uint32, nonce uint32) { m.Lock() - block.Header.Nonce = height + block.Header.Nonce = nonce hash := block.Header.BlockHash() m.blocks[hash] = block m.blockIndex[height] = hash @@ -250,6 +251,19 @@ func (m *mockChainView) notifyBlock(hash chainhash.Hash, height uint32, } } +func (m *mockChainView) notifyStaleBlock(hash chainhash.Hash, height uint32, + txns []*wire.MsgTx) { + + m.RLock() + defer m.RUnlock() + + m.staleBlocks <- &chainview.FilteredBlock{ + Hash: hash, + Height: height, + Transactions: txns, + } +} + func (m *mockChainView) FilteredBlocks() <-chan *chainview.FilteredBlock { return m.newBlocks } @@ -259,7 +273,7 @@ func (m *mockChainView) DisconnectedBlocks() <-chan *chainview.FilteredBlock { } func (m *mockChainView) FilterBlock(blockHash *chainhash.Hash) (*chainview.FilteredBlock, error) { - return nil, nil + return &chainview.FilteredBlock{}, nil } func (m *mockChainView) Start() error { @@ -295,7 +309,7 @@ func TestEdgeUpdateNotification(t *testing.T) { fundingBlock := &wire.MsgBlock{ Transactions: []*wire.MsgTx{fundingTx}, } - ctx.chain.addBlock(fundingBlock, chanID.BlockHeight) + ctx.chain.addBlock(fundingBlock, chanID.BlockHeight, chanID.BlockHeight) // Next we'll create two test nodes that the fake channel will be open // between. @@ -477,7 +491,7 @@ func TestNodeUpdateNotification(t *testing.T) { fundingBlock := &wire.MsgBlock{ Transactions: []*wire.MsgTx{fundingTx}, } - ctx.chain.addBlock(fundingBlock, chanID.BlockHeight) + ctx.chain.addBlock(fundingBlock, chanID.BlockHeight, chanID.BlockHeight) // Create two nodes acting as endpoints in the created channel, and use // them to trigger notifications by sending updated node announcement @@ -658,7 +672,7 @@ func TestNotificationCancellation(t *testing.T) { fundingBlock := &wire.MsgBlock{ Transactions: []*wire.MsgTx{fundingTx}, } - ctx.chain.addBlock(fundingBlock, chanID.BlockHeight) + ctx.chain.addBlock(fundingBlock, chanID.BlockHeight, chanID.BlockHeight) // We'll create a fresh new node topology update to feed to the channel // router. @@ -743,7 +757,7 @@ func TestChannelCloseNotification(t *testing.T) { fundingBlock := &wire.MsgBlock{ Transactions: []*wire.MsgTx{fundingTx}, } - ctx.chain.addBlock(fundingBlock, chanID.BlockHeight) + ctx.chain.addBlock(fundingBlock, chanID.BlockHeight, chanID.BlockHeight) // Next we'll create two test nodes that the fake channel will be open // between. @@ -797,7 +811,7 @@ func TestChannelCloseNotification(t *testing.T) { }, }, } - ctx.chain.addBlock(newBlock, blockHeight) + ctx.chain.addBlock(newBlock, blockHeight, blockHeight) ctx.chainView.notifyBlock(newBlock.Header.BlockHash(), blockHeight, newBlock.Transactions) diff --git a/routing/router.go b/routing/router.go index 59ceb22c..904aeade 100644 --- a/routing/router.go +++ b/routing/router.go @@ -185,9 +185,14 @@ type ChannelRouter struct { routeCache map[routeTuple][]*Route // newBlocks is a channel in which new blocks connected to the end of - // the main chain are sent over. + // the main chain are sent over, and blocks updated after a call to + // UpdateFilter. newBlocks <-chan *chainview.FilteredBlock + // staleBlocks is a channel in which blocks disconnected fromt the end + // of our currently known best chain are sent over. + staleBlocks <-chan *chainview.FilteredBlock + // networkUpdates is a channel that carries new topology updates // messages from outside the ChannelRouter to be processed by the // networkHandler. @@ -266,6 +271,7 @@ func (r *ChannelRouter) Start() error { // Once the instance is active, we'll fetch the channel we'll receive // notifications over. r.newBlocks = r.cfg.ChainView.FilteredBlocks() + r.staleBlocks = r.cfg.ChainView.DisconnectedBlocks() // Before we begin normal operation of the router, we first need to // synchronize the channel graph to the latest state of the UTXO set. @@ -352,6 +358,46 @@ func (r *ChannelRouter) syncGraphWithChain() error { return nil } + // If the main chain blockhash at prune height is different from the + // prune hash, this might indicate the database is on a stale branch. + mainBlockHash, err := r.cfg.Chain.GetBlockHash(int64(pruneHeight)) + if err != nil { + return err + } + + // While we are on a stale branch of the chain, walk backwards to find + // first common block. + for !pruneHash.IsEqual(mainBlockHash) { + log.Infof("channel graph is stale. Disconnecting block %v "+ + "(hash=%v)", pruneHeight, pruneHash) + // Prune the graph for every channel that was opened at height + // >= pruneHeigth. + _, err := r.cfg.Graph.DisconnectBlockAtHeight(pruneHeight) + if err != nil { + return err + } + + pruneHash, pruneHeight, err = r.cfg.Graph.PruneTip() + if err != nil { + switch { + // If at this point the graph has never been pruned, we + // can exit as this entails we are back to the point + // where it hasn't seen any block or created channels, + // alas there's nothing left to prune. + case err == channeldb.ErrGraphNeverPruned: + return nil + case err == channeldb.ErrGraphNotFound: + return nil + default: + return err + } + } + mainBlockHash, err = r.cfg.Chain.GetBlockHash(int64(pruneHeight)) + if err != nil { + return err + } + } + log.Infof("Syncing channel graph from height=%v (hash=%v) to height=%v "+ "(hash=%v)", pruneHeight, pruneHash, bestHeight, bestHash) @@ -449,6 +495,35 @@ func (r *ChannelRouter) networkHandler() { // after N blocks pass with no corresponding // announcements. + case chainUpdate, ok := <-r.staleBlocks: + // If the channel has been closed, then this indicates + // the daemon is shutting down, so we exit ourselves. + if !ok { + return + } + + // Since this block is stale, we update our best height + // to the previous block. + blockHeight := uint32(chainUpdate.Height) + r.bestHeight = blockHeight - 1 + + // Update the channel graph to reflect that this block + // was disconnected. + _, err := r.cfg.Graph.DisconnectBlockAtHeight(blockHeight) + if err != nil { + log.Errorf("unable to prune graph with stale "+ + "block: %v", err) + continue + } + + // Invalidate the route cache, as some channels might + // not be confirmed anymore. + r.routeCacheMtx.Lock() + r.routeCache = make(map[routeTuple][]*Route) + r.routeCacheMtx.Unlock() + + // TODO(halseth): notify client about the reorg? + // A new block has arrived, so we can prune the channel graph // of any channels which were closed in the block. case chainUpdate, ok := <-r.newBlocks: diff --git a/routing/router_test.go b/routing/router_test.go index 765b9297..c13c11c3 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -4,6 +4,7 @@ import ( "bytes" "fmt" "image/color" + "math/rand" "strings" "testing" "time" @@ -400,7 +401,7 @@ func TestAddProof(t *testing.T) { fundingBlock := &wire.MsgBlock{ Transactions: []*wire.MsgTx{fundingTx}, } - ctx.chain.addBlock(fundingBlock, chanID.BlockHeight) + ctx.chain.addBlock(fundingBlock, chanID.BlockHeight, chanID.BlockHeight) // After utxo was recreated adding the edge without the proof. edge := &channeldb.ChannelEdgeInfo{ @@ -502,7 +503,7 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { fundingBlock := &wire.MsgBlock{ Transactions: []*wire.MsgTx{fundingTx}, } - ctx.chain.addBlock(fundingBlock, chanID.BlockHeight) + ctx.chain.addBlock(fundingBlock, chanID.BlockHeight, chanID.BlockHeight) edge := &channeldb.ChannelEdgeInfo{ ChannelID: chanID.ToUint64(), @@ -600,7 +601,7 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { fundingBlock = &wire.MsgBlock{ Transactions: []*wire.MsgTx{fundingTx}, } - ctx.chain.addBlock(fundingBlock, chanID.BlockHeight) + ctx.chain.addBlock(fundingBlock, chanID.BlockHeight, chanID.BlockHeight) edge = &channeldb.ChannelEdgeInfo{ ChannelID: chanID.ToUint64(), @@ -718,3 +719,397 @@ func TestAddEdgeUnknownVertexes(t *testing.T) { t.Fatalf("fetched node not equal to original") } } + +// TestWakeUpOnStaleBranch tests that upon startup of the ChannelRouter, if the +// the chain previously reflected in the channel graph is stale (overtaken by a +// longer chain), the channel router will prune the graph for any channels +// confirmed on the stale chain, and resync to the main chain. +func TestWakeUpOnStaleBranch(t *testing.T) { + t.Parallel() + + const startingBlockHeight = 101 + ctx, cleanUp, err := createTestCtx(startingBlockHeight) + defer cleanUp() + if err != nil { + t.Fatalf("unable to create router: %v", err) + } + + const chanValue = 10000 + + // chanID1 will not be reorged out. + var chanID1 uint64 + + // chanID2 will be reorged out. + var chanID2 uint64 + + // Create 10 common blocks, confirming chanID1. + for i := uint32(1); i <= 10; i++ { + block := &wire.MsgBlock{ + Transactions: []*wire.MsgTx{}, + } + height := startingBlockHeight + i + if i == 5 { + fundingTx, _, chanID, err := createChannelEdge(ctx, + bitcoinKey1.SerializeCompressed(), + bitcoinKey2.SerializeCompressed(), + chanValue, height) + if err != nil { + t.Fatalf("unable create channel edge: %v", err) + } + block.Transactions = append(block.Transactions, + fundingTx) + chanID1 = chanID.ToUint64() + + } + ctx.chain.addBlock(block, height, rand.Uint32()) + ctx.chain.setBestBlock(int32(height)) + ctx.chainView.notifyBlock(block.BlockHash(), height, + []*wire.MsgTx{}) + } + + // Give time to process new blocks + time.Sleep(time.Millisecond * 500) + + _, forkHeight, err := ctx.chain.GetBestBlock() + if err != nil { + t.Fatalf("unable to ge best block: %v", err) + } + + // Create 10 blocks on the minority chain, confirming chanID2. + for i := uint32(1); i <= 10; i++ { + block := &wire.MsgBlock{ + Transactions: []*wire.MsgTx{}, + } + height := uint32(forkHeight) + i + if i == 5 { + fundingTx, _, chanID, err := createChannelEdge(ctx, + bitcoinKey1.SerializeCompressed(), + bitcoinKey2.SerializeCompressed(), + chanValue, height) + if err != nil { + t.Fatalf("unable create channel edge: %v", err) + } + block.Transactions = append(block.Transactions, + fundingTx) + chanID2 = chanID.ToUint64() + } + ctx.chain.addBlock(block, height, rand.Uint32()) + ctx.chain.setBestBlock(int32(height)) + ctx.chainView.notifyBlock(block.BlockHash(), height, + []*wire.MsgTx{}) + } + // Give time to process new blocks + time.Sleep(time.Millisecond * 500) + + // Now add the two edges to the channel graph, and check that they + // correctly show up in the database. + node1, err := createTestNode() + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + node2, err := createTestNode() + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + + edge1 := &channeldb.ChannelEdgeInfo{ + ChannelID: chanID1, + NodeKey1: node1.PubKey, + NodeKey2: node2.PubKey, + BitcoinKey1: bitcoinKey1, + BitcoinKey2: bitcoinKey2, + AuthProof: &channeldb.ChannelAuthProof{ + NodeSig1: testSig, + NodeSig2: testSig, + BitcoinSig1: testSig, + BitcoinSig2: testSig, + }, + } + + if err := ctx.router.AddEdge(edge1); err != nil { + t.Fatalf("unable to add edge: %v", err) + } + + edge2 := &channeldb.ChannelEdgeInfo{ + ChannelID: chanID2, + NodeKey1: node1.PubKey, + NodeKey2: node2.PubKey, + BitcoinKey1: bitcoinKey1, + BitcoinKey2: bitcoinKey2, + AuthProof: &channeldb.ChannelAuthProof{ + NodeSig1: testSig, + NodeSig2: testSig, + BitcoinSig1: testSig, + BitcoinSig2: testSig, + }, + } + + if err := ctx.router.AddEdge(edge2); err != nil { + t.Fatalf("unable to add edge: %v", err) + } + + // Check that the fundingTxs are in the graph db. + _, _, has, err := ctx.graph.HasChannelEdge(chanID1) + if err != nil { + t.Fatalf("error looking for edge: %v", chanID1) + } + if !has { + t.Fatalf("could not find edge in graph") + } + + _, _, has, err = ctx.graph.HasChannelEdge(chanID2) + if err != nil { + t.Fatalf("error looking for edge: %v", chanID2) + } + if !has { + t.Fatalf("could not find edge in graph") + } + + // Stop the router, so we can reorg the chain while its offline. + if err := ctx.router.Stop(); err != nil { + t.Fatalf("unable to stop router: %v", err) + } + + // Create a 15 block fork. + for i := uint32(1); i <= 15; i++ { + block := &wire.MsgBlock{ + Transactions: []*wire.MsgTx{}, + } + height := uint32(forkHeight) + i + ctx.chain.addBlock(block, height, rand.Uint32()) + ctx.chain.setBestBlock(int32(height)) + } + + // Give time to process new blocks. + time.Sleep(time.Millisecond * 500) + + // Create new router with same graph database. + router, err := New(Config{ + Graph: ctx.graph, + Chain: ctx.chain, + ChainView: ctx.chainView, + SendToSwitch: func(_ *btcec.PublicKey, + _ *lnwire.UpdateAddHTLC, _ *sphinx.Circuit) ([32]byte, error) { + return [32]byte{}, nil + }, + ChannelPruneExpiry: time.Hour * 24, + GraphPruneInterval: time.Hour * 2, + }) + if err != nil { + t.Fatalf("unable to create router %v", err) + } + + // It should resync to the longer chain on startup. + if err := router.Start(); err != nil { + t.Fatalf("unable to start router: %v", err) + } + + // The channel with chanID2 should not be in the database anymore, + // since it is not confirmed on the longest chain. chanID1 should + // still be. + _, _, has, err = ctx.graph.HasChannelEdge(chanID1) + if err != nil { + t.Fatalf("error looking for edge: %v", chanID1) + } + if !has { + t.Fatalf("did not find edge in graph") + } + + _, _, has, err = ctx.graph.HasChannelEdge(chanID2) + if err != nil { + t.Fatalf("error looking for edge: %v", chanID2) + } + if has { + t.Fatalf("found edge in graph") + } + +} + +// TestDisconnectedBlocks checks that the router handles a reorg happening +// when it is active. +func TestDisconnectedBlocks(t *testing.T) { + t.Parallel() + + const startingBlockHeight = 101 + ctx, cleanUp, err := createTestCtx(startingBlockHeight) + defer cleanUp() + if err != nil { + t.Fatalf("unable to create router: %v", err) + } + + const chanValue = 10000 + + // chanID1 will not be reorged out. + var chanID1 uint64 + + // chanID2 will be reorged out. + var chanID2 uint64 + + // Create 10 common blocks, confirming chanID1. + for i := uint32(1); i <= 10; i++ { + block := &wire.MsgBlock{ + Transactions: []*wire.MsgTx{}, + } + height := startingBlockHeight + i + if i == 5 { + fundingTx, _, chanID, err := createChannelEdge(ctx, + bitcoinKey1.SerializeCompressed(), + bitcoinKey2.SerializeCompressed(), + chanValue, height) + if err != nil { + t.Fatalf("unable create channel edge: %v", err) + } + block.Transactions = append(block.Transactions, + fundingTx) + chanID1 = chanID.ToUint64() + + } + ctx.chain.addBlock(block, height, rand.Uint32()) + ctx.chain.setBestBlock(int32(height)) + ctx.chainView.notifyBlock(block.BlockHash(), height, + []*wire.MsgTx{}) + } + + // Give time to process new blocks + time.Sleep(time.Millisecond * 500) + + _, forkHeight, err := ctx.chain.GetBestBlock() + if err != nil { + t.Fatalf("unable to get best block: %v", err) + } + + // Create 10 blocks on the minority chain, confirming chanID2. + var minorityChain []*wire.MsgBlock + for i := uint32(1); i <= 10; i++ { + block := &wire.MsgBlock{ + Transactions: []*wire.MsgTx{}, + } + height := uint32(forkHeight) + i + if i == 5 { + fundingTx, _, chanID, err := createChannelEdge(ctx, + bitcoinKey1.SerializeCompressed(), + bitcoinKey2.SerializeCompressed(), + chanValue, height) + if err != nil { + t.Fatalf("unable create channel edge: %v", err) + } + block.Transactions = append(block.Transactions, + fundingTx) + chanID2 = chanID.ToUint64() + } + minorityChain = append(minorityChain, block) + ctx.chain.addBlock(block, height, rand.Uint32()) + ctx.chain.setBestBlock(int32(height)) + ctx.chainView.notifyBlock(block.BlockHash(), height, + []*wire.MsgTx{}) + } + // Give time to process new blocks + time.Sleep(time.Millisecond * 500) + + // Now add the two edges to the channel graph, and check that they + // correctly show up in the database. + node1, err := createTestNode() + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + node2, err := createTestNode() + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + + edge1 := &channeldb.ChannelEdgeInfo{ + ChannelID: chanID1, + NodeKey1: node1.PubKey, + NodeKey2: node2.PubKey, + BitcoinKey1: bitcoinKey1, + BitcoinKey2: bitcoinKey2, + AuthProof: &channeldb.ChannelAuthProof{ + NodeSig1: testSig, + NodeSig2: testSig, + BitcoinSig1: testSig, + BitcoinSig2: testSig, + }, + } + + if err := ctx.router.AddEdge(edge1); err != nil { + t.Fatalf("unable to add edge: %v", err) + } + + edge2 := &channeldb.ChannelEdgeInfo{ + ChannelID: chanID2, + NodeKey1: node1.PubKey, + NodeKey2: node2.PubKey, + BitcoinKey1: bitcoinKey1, + BitcoinKey2: bitcoinKey2, + AuthProof: &channeldb.ChannelAuthProof{ + NodeSig1: testSig, + NodeSig2: testSig, + BitcoinSig1: testSig, + BitcoinSig2: testSig, + }, + } + + if err := ctx.router.AddEdge(edge2); err != nil { + t.Fatalf("unable to add edge: %v", err) + } + + // Check that the fundingTxs are in the graph db. + _, _, has, err := ctx.graph.HasChannelEdge(chanID1) + if err != nil { + t.Fatalf("error looking for edge: %v", chanID1) + } + if !has { + t.Fatalf("could not find edge in graph") + } + + _, _, has, err = ctx.graph.HasChannelEdge(chanID2) + if err != nil { + t.Fatalf("error looking for edge: %v", chanID2) + } + if !has { + t.Fatalf("could not find edge in graph") + } + + // Create a 15 block fork. We first let the chainView notify the + // router about stale blocks, before sending the now connected + // blocks. We do this because we expect this order from the + // chainview. + for i := len(minorityChain) - 1; i >= 0; i-- { + block := minorityChain[i] + height := uint32(forkHeight) + uint32(i) + 1 + ctx.chainView.notifyStaleBlock(block.BlockHash(), height, + block.Transactions) + } + for i := uint32(1); i <= 15; i++ { + block := &wire.MsgBlock{ + Transactions: []*wire.MsgTx{}, + } + height := uint32(forkHeight) + i + ctx.chain.addBlock(block, height, rand.Uint32()) + ctx.chain.setBestBlock(int32(height)) + ctx.chainView.notifyBlock(block.BlockHash(), height, + block.Transactions) + } + + // Give time to process new blocks + time.Sleep(time.Millisecond * 500) + + // The with chanID2 should not be in the database anymore, since it is + // not confirmed on the longest chain. chanID1 should still be. + _, _, has, err = ctx.graph.HasChannelEdge(chanID1) + if err != nil { + t.Fatalf("error looking for edge: %v", chanID1) + } + if !has { + t.Fatalf("did not find edge in graph") + } + + _, _, has, err = ctx.graph.HasChannelEdge(chanID2) + if err != nil { + t.Fatalf("error looking for edge: %v", chanID2) + } + if has { + t.Fatalf("found edge in graph") + } + +}