diff --git a/channeldb/graph.go b/channeldb/graph.go index 01f79b04..c945238a 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -1330,6 +1330,47 @@ func (c *ChannelGraph) FetchChannelEdgesByID(chanID uint64) (*ChannelEdgeInfo, * return edgeInfo, policy1, policy2, nil } +// ChannelView returns the verifiable edge information for each active channel +// within the known channel graph. The set of UTXO's returned are the ones that +// need to be watched on chain to detect channel closes on the resident +// blockchain. +func (c *ChannelGraph) ChannelView() ([]wire.OutPoint, error) { + var chanPoints []wire.OutPoint + if err := c.db.View(func(tx *bolt.Tx) error { + // We're going to iterate over the entire channel index, so + // we'll need to fetch the edgeBucket to get to the index as + // it's a sub-bucket. + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrGraphNoEdgesFound + } + chanIndex := edges.Bucket(channelPointBucket) + if chanIndex == nil { + return ErrGraphNoEdgesFound + } + + // Once we have the proper bucket, we'll range over each key + // (which is the channel point for the channel) and decode it, + // accumulating each entry. + return chanIndex.ForEach(func(chanPointBytes, _ []byte) error { + chanPointReader := bytes.NewReader(chanPointBytes) + + var chanPoint wire.OutPoint + err := readOutpoint(chanPointReader, &chanPoint) + if err != nil { + return err + } + + chanPoints = append(chanPoints, chanPoint) + return nil + }) + }); err != nil { + return nil, err + } + + return chanPoints, nil +} + // NewChannelEdgePolicy returns a new blank ChannelEdgePolicy. func (c *ChannelGraph) NewChannelEdgePolicy() *ChannelEdgePolicy { return &ChannelEdgePolicy{db: c.db} diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go index 71a4beda..07206d46 100644 --- a/channeldb/graph_test.go +++ b/channeldb/graph_test.go @@ -683,7 +683,28 @@ func asserNumChans(t *testing.T, graph *ChannelGraph, n int) { } if numChans != n { _, _, line, _ := runtime.Caller(1) - t.Fatalf("line %v: expected %v chans instead have %v", line, n, numChans) + t.Fatalf("line %v: expected %v chans instead have %v", line, + n, numChans) + } +} + +func assertChanViewEqual(t *testing.T, a []wire.OutPoint, b []*wire.OutPoint) { + if len(a) != len(b) { + _, _, line, _ := runtime.Caller(1) + t.Fatalf("line %v: chan views dont match", line) + } + + chanViewSet := make(map[wire.OutPoint]struct{}) + for _, op := range a { + chanViewSet[op] = struct{}{} + } + + for _, op := range b { + if _, ok := chanViewSet[*op]; !ok { + _, _, line, _ := runtime.Caller(1) + t.Fatalf("line %v: chanPoint(%v) not found in first view", + line, op) + } } } @@ -768,6 +789,14 @@ func TestGraphPruning(t *testing.T) { } } + // With all the channel points added, we'll consult the graph to ensure + // it has the same channel view as the one we just constructed. + channelView, err := graph.ChannelView() + if err != nil { + t.Fatalf("unable to get graph channel view: %v", err) + } + assertChanViewEqual(t, channelView, channelPoints) + // Now with our test graph created, we can test the pruning // capabilities of the channel graph. @@ -793,6 +822,13 @@ func TestGraphPruning(t *testing.T) { // should be remaining. asserNumChans(t, graph, 2) + // Those channels should also be missing from the channel view. + channelView, err = graph.ChannelView() + if err != nil { + t.Fatalf("unable to get graph channel view: %v", err) + } + assertChanViewEqual(t, channelView, channelPoints[2:]) + // Next we'll create a block that doesn't close any channels within the // graph to test the negative error case. fakeHash := sha256.Sum256([]byte("test prune")) @@ -833,12 +869,22 @@ func TestGraphPruning(t *testing.T) { 2, len(prunedChans)) } - // TODO(roasbeef): asser that proper chans have been closed - // The prune tip should be updated, and no channels should be found // within the current graph. assertPruneTip(t, graph, &blockHash, blockHeight) asserNumChans(t, graph, 0) + + // Finally, the channel view at this point in the graph should now be + // completely empty. + // Those channels should also be missing from the channel view. + channelView, err = graph.ChannelView() + if err != nil { + t.Fatalf("unable to get graph channel view: %v", err) + } + if len(channelView) != 0 { + t.Fatalf("channel view should be empty, instead have: %v", + channelView) + } } // compareNodes is used to compare two LightningNodes while excluding the