diff --git a/channeldb/db.go b/channeldb/db.go
index 8e27b021..c4262ffa 100644
--- a/channeldb/db.go
+++ b/channeldb/db.go
@@ -8,6 +8,7 @@ import (
 	"os"
 	"path/filepath"
 	"sync"
+	"time"
 
 	"github.com/btcsuite/btcd/btcec"
 	"github.com/btcsuite/btcd/wire"
@@ -869,6 +870,112 @@ func (d *DB) PruneLinkNodes() error {
 	})
 }
 
+// ChannelShell is a shell of a channel that is meant to be used for channel
+// recovery purposes. It contains a minimal OpenChannel instance along with
+// addresses for the target node.
+type ChannelShell struct {
+	// NodeAddrs is the set of addresses that this node is known to have
+	// been reachable at in the past.
+	NodeAddrs []net.Addr
+
+	// Chan is a shell of an OpenChannel; it contains only the items
+	// required to restore the channel on disk.
+	Chan *OpenChannel
+}
+
+// RestoreChannelShells is a method that allows the caller to reconstruct the
+// state of an OpenChannel from the ChannelShell. We'll attempt to write the
+// new channel to disk, create a LinkNode instance with the passed node
+// addresses, and finally create an edge within the graph for the channel as
+// well. This method is idempotent, so repeated calls with the same set of
+// channel shells won't modify the database after the initial call.
+func (d *DB) RestoreChannelShells(channelShells ...*ChannelShell) error {
+	chanGraph := ChannelGraph{d}
+
+	return d.Update(func(tx *bbolt.Tx) error {
+		for _, channelShell := range channelShells {
+			channel := channelShell.Chan
+
+			// First, we'll attempt to create a new open channel
+			// and link node for this channel. If the channel
+			// already exists, then in order to ensure this method
+			// is idempotent, we'll continue to the next step.
+			channel.Db = d
+			err := syncNewChannel(
+				tx, channel, channelShell.NodeAddrs,
+			)
+			if err != nil {
+				return err
+			}
+
+			// Next, we'll create an active edge in the graph
+			// database for this channel in order to restore our
+			// partial view of the network.
+			//
+			// TODO(roasbeef): if we restore *after* the channel
+			// has been closed on chain, then need to inform the
+			// router that it should try and prune these values as
+			// we can detect them
+			edgeInfo := ChannelEdgeInfo{
+				ChannelID:    channel.ShortChannelID.ToUint64(),
+				ChainHash:    channel.ChainHash,
+				ChannelPoint: channel.FundingOutpoint,
+			}
+
+			nodes := tx.Bucket(nodeBucket)
+			if nodes == nil {
+				return ErrGraphNotFound
+			}
+			selfNode, err := chanGraph.sourceNode(nodes)
+			if err != nil {
+				return err
+			}
+
+			// Depending on which pub key is smaller, we'll assign
+			// our roles as "node1" and "node2".
+			chanPeer := channel.IdentityPub.SerializeCompressed()
+			selfIsSmaller := bytes.Compare(
+				selfNode.PubKeyBytes[:], chanPeer,
+			) == -1
+			if selfIsSmaller {
+				copy(edgeInfo.NodeKey1Bytes[:], selfNode.PubKeyBytes[:])
+				copy(edgeInfo.NodeKey2Bytes[:], chanPeer)
+			} else {
+				copy(edgeInfo.NodeKey1Bytes[:], chanPeer)
+				copy(edgeInfo.NodeKey2Bytes[:], selfNode.PubKeyBytes[:])
+			}
+
+			// With the edge info shell constructed, we'll now add
+			// it to the graph.
+			err = chanGraph.addChannelEdge(tx, &edgeInfo)
+			if err != nil {
+				return err
+			}
+
+			// Similarly, we'll construct a channel edge shell and
+			// add that itself to the graph.
+			chanEdge := ChannelEdgePolicy{
+				ChannelID:  edgeInfo.ChannelID,
+				LastUpdate: time.Now(),
+			}
+
+			// If their pubkey is larger, then we'll flip the
+			// direction bit to indicate that we, the "second"
+			// node, are updating our policy.
+			if !selfIsSmaller {
+				chanEdge.ChannelFlags |= lnwire.ChanUpdateDirection
+			}
+
+			err = updateEdgePolicy(tx, &chanEdge)
+			if err != nil {
+				return err
+			}
+		}
+
+		return nil
+	})
+}
+
 // AddrsForNode consults the graph and channel database for all addresses known
 // to the passed node public key.
 func (d *DB) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, error) {
diff --git a/channeldb/db_test.go b/channeldb/db_test.go
index 17018333..261e0b36 100644
--- a/channeldb/db_test.go
+++ b/channeldb/db_test.go
@@ -2,16 +2,22 @@ package channeldb
 
 import (
 	"io/ioutil"
+	"math"
+	"math/rand"
 	"net"
 	"os"
 	"path/filepath"
 	"reflect"
 	"testing"
 
+	"github.com/btcsuite/btcd/btcec"
+	"github.com/btcsuite/btcd/chaincfg/chainhash"
 	"github.com/btcsuite/btcd/wire"
 	"github.com/btcsuite/btcutil"
 	"github.com/davecgh/go-spew/spew"
+	"github.com/lightningnetwork/lnd/keychain"
 	"github.com/lightningnetwork/lnd/lnwire"
+	"github.com/lightningnetwork/lnd/shachain"
 )
 
 func TestOpenWithCreate(t *testing.T) {
@@ -271,3 +277,174 @@ func TestFetchChannel(t *testing.T) {
 		t.Fatalf("expected query to fail")
 	}
 }
+
+func genRandomChannelShell() (*ChannelShell, error) {
+	var testPriv [32]byte
+	if _, err := rand.Read(testPriv[:]); err != nil {
+		return nil, err
+	}
+
+	_, pub := btcec.PrivKeyFromBytes(btcec.S256(), testPriv[:])
+
+	var chanPoint wire.OutPoint
+	if _, err := rand.Read(chanPoint.Hash[:]); err != nil {
+		return nil, err
+	}
+
+	pub.Curve = nil
+
+	chanPoint.Index = uint32(rand.Intn(math.MaxUint16))
+
+	chanStatus := ChanStatusDefault | ChanStatusRestored
+
+	var shaChainPriv [32]byte
+	if _, err := rand.Read(shaChainPriv[:]); err != nil {
+		return nil, err
+	}
+	revRoot, err := chainhash.NewHash(shaChainPriv[:])
+	if err != nil {
+		return nil, err
+	}
+	shaChainProducer := shachain.NewRevocationProducer(*revRoot)
+
+	return &ChannelShell{
+		NodeAddrs: []net.Addr{&net.TCPAddr{
+			IP:   net.ParseIP("127.0.0.1"),
+			Port: 18555,
+		}},
+		Chan: &OpenChannel{
+			chanStatus:      chanStatus,
+			ChainHash:       rev,
+			FundingOutpoint: chanPoint,
+			ShortChannelID: lnwire.NewShortChanIDFromInt(
+				uint64(rand.Int63()),
+			),
+			IdentityPub: pub,
+			LocalChanCfg: ChannelConfig{
+				ChannelConstraints: ChannelConstraints{
+					CsvDelay: uint16(rand.Int63()),
+				},
+				PaymentBasePoint: keychain.KeyDescriptor{
+					KeyLocator: keychain.KeyLocator{
+						Family: keychain.KeyFamily(rand.Int63()),
+						Index:  uint32(rand.Int63()),
+					},
+				},
+			},
+			RemoteCurrentRevocation: pub,
+			IsPending:               false,
+			RevocationStore:         shachain.NewRevocationStore(),
+			RevocationProducer:      shaChainProducer,
+		},
+	}, nil
+}
+
+// TestRestoreChannelShells tests that we're able to insert a partially
+// populated channel to disk. This is useful for channel recovery purposes. We
+// should find the new channel shell on disk, and also the db should be
+// populated with an edge for that channel.
+func TestRestoreChannelShells(t *testing.T) {
+	t.Parallel()
+
+	cdb, cleanUp, err := makeTestDB()
+	if err != nil {
+		t.Fatalf("unable to make test database: %v", err)
+	}
+	defer cleanUp()
+
+	// First, we'll make our channel shell; it will only have the minimal
+	// amount of information required for us to initiate the data loss
+	// protection feature.
+	channelShell, err := genRandomChannelShell()
+	if err != nil {
+		t.Fatalf("unable to gen channel shell: %v", err)
+	}
+
+	graph := cdb.ChannelGraph()
+
+	// Before we can restore the channel, we'll need to make a source node
+	// in the graph as the channel edge we create will need to have an
+	// origin.
+	testNode, err := createTestVertex(cdb)
+	if err != nil {
+		t.Fatalf("unable to create test node: %v", err)
+	}
+	if err := graph.SetSourceNode(testNode); err != nil {
+		t.Fatalf("unable to set source node: %v", err)
+	}
+
+	// With the channel shell constructed, we'll now insert it into the
+	// database with the restoration method.
+	if err := cdb.RestoreChannelShells(channelShell); err != nil {
+		t.Fatalf("unable to restore channel shell: %v", err)
+	}
+
+	// Now that the channel has been inserted, we'll attempt to query for
+	// it to ensure we can properly locate it via various means.
+	//
+	// First, we'll attempt to query for all channels that we have with the
+	// node public key that was restored.
+	nodeChans, err := cdb.FetchOpenChannels(channelShell.Chan.IdentityPub)
+	if err != nil {
+		t.Fatalf("unable to find channel: %v", err)
+	}
+
+	// We should now find a single channel from the database.
+	if len(nodeChans) != 1 {
+		t.Fatalf("unable to find restored channel by node pubkey, "+
+			"got %v channels", len(nodeChans))
+	}
+
+	// That single channel should have the proper channel point, and also
+	// the expected set of flags to indicate that it was a restored
+	// channel.
+	if nodeChans[0].FundingOutpoint != channelShell.Chan.FundingOutpoint {
+		t.Fatalf("wrong funding outpoint: expected %v, got %v",
+			channelShell.Chan.FundingOutpoint,
+			nodeChans[0].FundingOutpoint)
+	}
+	if !nodeChans[0].HasChanStatus(ChanStatusRestored) {
+		t.Fatalf("node has wrong status flags: %v",
+			nodeChans[0].chanStatus)
+	}
+
+	// We should also be able to find the channel if we query for it
+	// directly.
+	_, err = cdb.FetchChannel(channelShell.Chan.FundingOutpoint)
+	if err != nil {
+		t.Fatalf("unable to fetch channel: %v", err)
+	}
+
+	// We should also be able to find the link node that was inserted by
+	// its public key.
+	linkNode, err := cdb.FetchLinkNode(channelShell.Chan.IdentityPub)
+	if err != nil {
+		t.Fatalf("unable to fetch link node: %v", err)
+	}
+
+	// The node should have the same address, as specified in the channel
+	// shell.
+	if !reflect.DeepEqual(linkNode.Addresses, channelShell.NodeAddrs) {
+		t.Fatalf("addr mismatch: expected %v, got %v",
+			channelShell.NodeAddrs, linkNode.Addresses)
+	}
+
+	// Finally, we'll ensure that the edge for the channel was properly
+	// inserted.
+	chanInfos, err := graph.FetchChanInfos(
+		[]uint64{channelShell.Chan.ShortChannelID.ToUint64()},
+	)
+	if err != nil {
+		t.Fatalf("unable to find edges: %v", err)
+	}
+
+	if len(chanInfos) != 1 {
+		t.Fatalf("wrong amount of chan infos: expected %v got %v",
+			1, len(chanInfos))
+	}
+
+	// We should only find a single edge policy inserted for the channel.
+	if chanInfos[0].Policy1 != nil && chanInfos[0].Policy2 != nil {
+		t.Fatalf("only a single edge should be inserted")
+	}
+}
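
For reference, and not part of the patch itself: a minimal sketch of how the new restoration path is driven end to end. It assumes code running inside the channeldb package and reuses the helpers already present in this diff and its test files (makeTestDB, createTestVertex, genRandomChannelShell); the helper function name restoreOneShell is hypothetical.

// restoreOneShell is a hypothetical helper illustrating the call sequence the
// test above exercises: register a graph source node, build a channel shell,
// and hand it to RestoreChannelShells.
func restoreOneShell(cdb *DB) error {
	// The restored edge needs an origin, so a source node must exist in
	// the graph before RestoreChannelShells is invoked.
	node, err := createTestVertex(cdb)
	if err != nil {
		return err
	}
	if err := cdb.ChannelGraph().SetSourceNode(node); err != nil {
		return err
	}

	// Build a shell carrying only the minimal channel state plus the
	// peer's last known addresses.
	shell, err := genRandomChannelShell()
	if err != nil {
		return err
	}

	// Per the method's contract, RestoreChannelShells is idempotent, so
	// repeating the call with the same shell leaves the database
	// unchanged after the first write.
	if err := cdb.RestoreChannelShells(shell); err != nil {
		return err
	}
	return cdb.RestoreChannelShells(shell)
}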