diff --git a/channeldb/graph.go b/channeldb/graph.go index 842443ea..21f687f0 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -33,6 +33,14 @@ var ( // maps: source -> selfPubKey nodeBucket = []byte("graph-node") + // nodeUpdateIndexBucket is a sub-bucket of the nodeBucket. This bucket + // will be used to quickly look up the "freshness" of a node's last + // update to the network. The bucket only contains keys, and no values, + // it's mapping: + // + // maps: updateTime || nodeID -> nil + nodeUpdateIndexBucket = []byte("graph-node-update-index") + // sourceKey is a special key that resides within the nodeBucket. The // sourceKey maps a key to the public key of the "self node". sourceKey = []byte("source") @@ -74,6 +82,13 @@ var ( // maps: chanID -> pubKey1 || pubKey2 || restofEdgeInfo edgeIndexBucket = []byte("edge-index") + // edgeUpdateIndexBucket is a sub-bucket of the main edgeBucket. This + // bucket contains an index which allows us to gauge the "freshness" of + // a channel's last updates. + // + // maps: updateTime || chanID -> nil + edgeUpdateIndexBucket = []byte("edge-update-index") + // channelPointBucket maps a channel's full outpoint (txid:index) to // its short 8-byte channel ID. This bucket resides within the // edgeBucket above, and can be used to quickly remove an edge due to @@ -169,44 +184,13 @@ func (c *ChannelGraph) ForEachChannel(cb func(*ChannelEdgeInfo, *ChannelEdgePoli return err } - // The first node is contained within the first half of - // the edge information. - node1Pub := edgeInfoBytes[:33] - edge1, err := fetchChanEdgePolicy(edges, chanID, node1Pub, nodes) - if err != nil && err != ErrEdgeNotFound && - err != ErrGraphNodeNotFound { + edge1, edge2, err := fetchChanEdgePolicies( + edgeIndex, edges, nodes, chanID, c.db, + ) + if err != nil { return err } - // The targeted edge may have not been advertised - // within the network, so we ensure it's non-nil before - // dereferencing its attributes. - if edge1 != nil { - edge1.db = c.db - if edge1.Node != nil { - edge1.Node.db = c.db - } - } - - // Similarly, the second node is contained within the - // latter half of the edge information. - node2Pub := edgeInfoBytes[33:] - edge2, err := fetchChanEdgePolicy(edges, chanID, node2Pub, nodes) - if err != nil && err != ErrEdgeNotFound && - err != ErrGraphNodeNotFound { - return err - } - - // The targeted edge may have not been advertised - // within the network, so we ensure it's non-nil before - // dereferencing its attributes. - if edge2 != nil { - edge2.db = c.db - if edge2.Node != nil { - edge2.Node.db = c.db - } - } - // With both edges read, execute the call back. IF this // function returns an error then the transaction will // be aborted. @@ -356,7 +340,14 @@ func addLightningNode(tx *bolt.Tx, node *LightningNode) error { return err } - return putLightningNode(nodes, aliases, node) + updateIndex, err := nodes.CreateBucketIfNotExists( + nodeUpdateIndexBucket, + ) + if err != nil { + return err + } + + return putLightningNode(nodes, aliases, updateIndex, node) } // LookupAlias attempts to return the alias as advertised by the target node. @@ -721,8 +712,9 @@ func (c *ChannelGraph) DisconnectBlockAtHeight(height uint32) ([]*ChannelEdgeInf if err != nil { return err } - err = delChannelByEdge(edges, edgeIndex, chanIndex, - &edgeInfo.ChannelPoint) + err = delChannelByEdge( + edges, edgeIndex, chanIndex, &edgeInfo.ChannelPoint, + ) if err != nil && err != ErrEdgeNotFound { return err } @@ -872,6 +864,378 @@ func (c *ChannelGraph) ChannelID(chanPoint *wire.OutPoint) (uint64, error) { return chanID, nil } +// TODO(roasbeef): allow updates to use Batch? + +// HighestChanID returns the "highest" known channel ID in the channel graph. +// This represents the "newest" channel from the PoV of the chain. This method +// can be used by peers to quickly determine if they're graphs are in sync. +func (c *ChannelGraph) HighestChanID() (uint64, error) { + var cid uint64 + + err := c.db.View(func(tx *bolt.Tx) error { + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrGraphNoEdgesFound + } + edgeIndex := edges.Bucket(edgeIndexBucket) + if edgeIndex == nil { + return ErrGraphNoEdgesFound + } + + // In order to find the highest chan ID, we'll fetch a cursor + // and use that to seek to the "end" of our known rage. + cidCursor := edgeIndex.Cursor() + + lastChanID, _ := cidCursor.Last() + + // If there's no key, then this means that we don't actually + // know of any channels, so we'll return a predicable error. + if lastChanID == nil { + return ErrGraphNoEdgesFound + } + + // Otherwise, we'll de serialize the channel ID and return it + // to the caller. + cid = byteOrder.Uint64(lastChanID) + return nil + }) + if err != nil && err != ErrGraphNoEdgesFound { + return 0, err + } + + return cid, nil +} + +// ChannelEdge represents the complete set of information for a channel edge in +// the known channel graph. This struct couples the core information of the +// edge as well as each of the known advertised edge policies. +type ChannelEdge struct { + // Info contains all the static information describing the channel. + Info *ChannelEdgeInfo + + // Policy1 points to the "first" edge policy of the channel containing + // the dynamic information required to properly route through the edge. + Policy1 *ChannelEdgePolicy + + // Policy2 points to the "second" edge policy of the channel containing + // the dynamic information required to properly route through the edge. + Policy2 *ChannelEdgePolicy +} + +// ChanUpdatesInHorizon returns all the known channel edges which have at least +// one edge that has an update timestamp within the specified horizon. +func (c *ChannelGraph) ChanUpdatesInHorizon(startTime, endTime time.Time) ([]ChannelEdge, error) { + var edgesInHorizon []ChannelEdge + + err := c.db.View(func(tx *bolt.Tx) error { + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrGraphNoEdgesFound + } + edgeIndex := edges.Bucket(edgeIndexBucket) + if edgeIndex == nil { + return ErrGraphNoEdgesFound + } + edgeUpdateIndex := edges.Bucket(edgeUpdateIndexBucket) + if edgeUpdateIndex == nil { + return ErrGraphNoEdgesFound + } + + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNodesNotFound + } + + // We'll now obtain a cursor to perform a range query within + // the index to find all channels within the horizon. + updateCursor := edgeUpdateIndex.Cursor() + + var startTimeBytes, endTimeBytes [8 + 8]byte + byteOrder.PutUint64( + startTimeBytes[:8], uint64(startTime.Unix()), + ) + byteOrder.PutUint64( + endTimeBytes[:8], uint64(endTime.Unix()), + ) + + // With our start and end times constructed, we'll step through + // the index collecting the info and policy of each update of + // each channel that has a last update within the time range. + for indexKey, _ := updateCursor.Seek(startTimeBytes[:]); indexKey != nil && + bytes.Compare(indexKey, endTimeBytes[:]) <= 0; indexKey, _ = updateCursor.Next() { + + // We have a new eligible entry, so we'll slice of the + // chan ID so we can query it in the DB. + chanID := indexKey[8:] + + // First, we'll fetch the static edge information. + edgeInfo, err := fetchChanEdgeInfo(edgeIndex, chanID) + if err != nil { + return err + } + + // With the static information obtained, we'll now + // fetch the dynamic policy info. + edge1, edge2, err := fetchChanEdgePolicies( + edgeIndex, edges, nodes, chanID, c.db, + ) + if err != nil { + return err + } + + // Finally, we'll collate this edge with the rest of + // edges to be returned. + edgesInHorizon = append(edgesInHorizon, ChannelEdge{ + Info: &edgeInfo, + Policy1: edge1, + Policy2: edge2, + }) + } + + return nil + }) + switch { + case err == ErrGraphNoEdgesFound: + fallthrough + case err == ErrGraphNodesNotFound: + break + + case err != nil: + return nil, err + } + + return edgesInHorizon, nil +} + +// NodeUpdatesInHorizon returns all the known lightning node which have an +// update timestamp within the passed range. This method can be used by two +// nodes to quickly determine if they have the same set of up to date node +// announcements. +func (c *ChannelGraph) NodeUpdatesInHorizon(startTime, endTime time.Time) ([]LightningNode, error) { + var nodesInHorizon []LightningNode + + err := c.db.View(func(tx *bolt.Tx) error { + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNodesNotFound + } + + nodeUpdateIndex := nodes.Bucket(nodeUpdateIndexBucket) + if nodeUpdateIndex == nil { + return ErrGraphNodesNotFound + } + + // We'll now obtain a cursor to perform a range query within + // the index to find all node announcements within the horizon. + updateCursor := nodeUpdateIndex.Cursor() + + var startTimeBytes, endTimeBytes [8 + 33]byte + byteOrder.PutUint64( + startTimeBytes[:8], uint64(startTime.Unix()), + ) + byteOrder.PutUint64( + endTimeBytes[:8], uint64(endTime.Unix()), + ) + + // With our start and end times constructed, we'll step through + // the index collecting info for each node within the time + // range. + for indexKey, _ := updateCursor.Seek(startTimeBytes[:]); indexKey != nil && + bytes.Compare(indexKey, endTimeBytes[:]) <= 0; indexKey, _ = updateCursor.Next() { + + nodePub := indexKey[8:] + node, err := fetchLightningNode(nodes, nodePub) + if err != nil { + return err + } + node.db = c.db + + nodesInHorizon = append(nodesInHorizon, node) + } + + return nil + }) + switch { + case err == ErrGraphNoEdgesFound: + fallthrough + case err == ErrGraphNodesNotFound: + break + + case err != nil: + return nil, err + } + + return nodesInHorizon, nil +} + +// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan +// ID's that we don't know of in the passed set. In other words, we perform a +// set difference of our set of chan ID's and the ones passed in. This method +// can be used by callers to determine the set of channels ta peer knows of +// that we don't. +func (c *ChannelGraph) FilterKnownChanIDs(chanIDs []uint64) ([]uint64, error) { + var newChanIDs []uint64 + + err := c.db.View(func(tx *bolt.Tx) error { + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrGraphNoEdgesFound + } + edgeIndex := edges.Bucket(edgeIndexBucket) + if edgeIndex == nil { + return ErrGraphNoEdgesFound + } + + // We'll run through the set of chanIDs and collate only the + // set of channel that are unable to be found within our db. + var cidBytes [8]byte + for _, cid := range chanIDs { + byteOrder.PutUint64(cidBytes[:], cid) + + if v := edgeIndex.Get(cidBytes[:]); v == nil { + newChanIDs = append(newChanIDs, cid) + } + } + + return nil + }) + switch { + // If we don't know of any edges yet, then we'll return the entire set + // of chan IDs specified. + case err == ErrGraphNoEdgesFound: + return chanIDs, nil + + case err != nil: + return nil, err + } + + return newChanIDs, nil +} + +// FilterChannelRange returns the channel ID's of all known channels which were +// mined in a block height within the passed range. This method can be used to +// quickly share with a peer the set of channels we know of within a particular +// range to catch them up after a period of time offline. +func (c *ChannelGraph) FilterChannelRange(startHeight, endHeight uint32) ([]uint64, error) { + var chanIDs []uint64 + + startChanID := &lnwire.ShortChannelID{ + BlockHeight: startHeight, + } + + endChanID := lnwire.ShortChannelID{ + BlockHeight: endHeight, + TxIndex: math.MaxUint32 & 0x00ffffff, + TxPosition: math.MaxUint16, + } + + // As we need to perform a range scan, we'll convert the starting and + // ending height to their corresponding values when encoded using short + // channel ID's. + var chanIDStart, chanIDEnd [8]byte + byteOrder.PutUint64(chanIDStart[:], startChanID.ToUint64()) + byteOrder.PutUint64(chanIDEnd[:], endChanID.ToUint64()) + + err := c.db.View(func(tx *bolt.Tx) error { + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrGraphNoEdgesFound + } + edgeIndex := edges.Bucket(edgeIndexBucket) + if edgeIndex == nil { + return ErrGraphNoEdgesFound + } + + cursor := edgeIndex.Cursor() + + // We'll now iterate through the database, and find each + // channel ID that resides within the specified range. + var cid uint64 + for k, _ := cursor.Seek(chanIDStart[:]); k != nil && + bytes.Compare(k, chanIDEnd[:]) <= 0; k, _ = cursor.Next() { + + // This channel ID rests within the target range, so + // we'll convert it into an integer and add it to our + // returned set. + cid = byteOrder.Uint64(k) + chanIDs = append(chanIDs, cid) + } + + return nil + }) + switch { + // If we don't know of any channels yet, then there's nothing to + // filter, so we'll return an empty slice. + case err == ErrGraphNoEdgesFound: + return chanIDs, nil + + case err != nil: + return nil, err + } + + return chanIDs, nil +} + +// FetchChanInfos returns the set of channel edges that correspond to the +// passed channel ID's. This can be used to respond to peer queries that are +// seeking to fill in gaps in their view of the channel graph. +func (c *ChannelGraph) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) { + // TODO(roasbeef): sort cids? + + var ( + chanEdges []ChannelEdge + cidBytes [8]byte + ) + + err := c.db.View(func(tx *bolt.Tx) error { + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrGraphNoEdgesFound + } + edgeIndex := edges.Bucket(edgeIndexBucket) + if edgeIndex == nil { + return ErrGraphNoEdgesFound + } + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNotFound + } + + for _, cid := range chanIDs { + byteOrder.PutUint64(cidBytes[:], cid) + + // First, we'll fetch the static edge information. + edgeInfo, err := fetchChanEdgeInfo( + edgeIndex, cidBytes[:], + ) + if err != nil { + return err + } + + // With the static information obtained, we'll now + // fetch the dynamic policy info. + edge1, edge2, err := fetchChanEdgePolicies( + edgeIndex, edges, nodes, cidBytes[:], c.db, + ) + if err != nil { + return err + } + + chanEdges = append(chanEdges, ChannelEdge{ + Info: &edgeInfo, + Policy1: edge1, + Policy2: edge2, + }) + } + return nil + }) + if err != nil { + return nil, err + } + + return chanEdges, nil +} + func delChannelByEdge(edges *bolt.Bucket, edgeIndex *bolt.Bucket, chanIndex *bolt.Bucket, chanPoint *wire.OutPoint) error { var b bytes.Buffer @@ -1776,7 +2140,9 @@ func (c *ChannelGraph) NewChannelEdgePolicy() *ChannelEdgePolicy { return &ChannelEdgePolicy{db: c.db} } -func putLightningNode(nodeBucket *bolt.Bucket, aliasBucket *bolt.Bucket, node *LightningNode) error { +func putLightningNode(nodeBucket *bolt.Bucket, aliasBucket *bolt.Bucket, + updateIndex *bolt.Bucket, node *LightningNode) error { + var ( scratch [16]byte b bytes.Buffer @@ -1803,8 +2169,8 @@ func putLightningNode(nodeBucket *bolt.Bucket, aliasBucket *bolt.Bucket, node *L return err } - // If we got a node announcement for this node, we will have the rest of - // the data available. If not we don't have more data to write. + // If we got a node announcement for this node, we will have the rest + // of the data available. If not we don't have more data to write. if !node.HaveNodeAnnouncement { // Write HaveNodeAnnouncement=0. byteOrder.PutUint16(scratch[:2], 0) @@ -1860,8 +2226,33 @@ func putLightningNode(nodeBucket *bolt.Bucket, aliasBucket *bolt.Bucket, node *L return err } - return nodeBucket.Put(nodePub, b.Bytes()) + // With the alias bucket updated, we'll now update the index that + // tracks the time series of node updates. + var indexKey [8 + 33]byte + byteOrder.PutUint64(indexKey[:8], updateUnix) + copy(indexKey[8:], nodePub) + // If there was already an old index entry for this node, then we'll + // delete the old one before we write the new entry. + if nodeBytes := nodeBucket.Get(nodePub); nodeBytes != nil { + // Extract out the old update time to we can reconstruct the + // prior index key to delete it from the index. + oldUpdateTime := nodeBytes[:8] + + var oldIndexKey [8 + 33]byte + copy(oldIndexKey[:8], oldUpdateTime) + copy(oldIndexKey[8:], nodePub) + + if err := updateIndex.Delete(oldIndexKey[:]); err != nil { + return err + } + } + + if err := updateIndex.Put(indexKey[:], nil); err != nil { + return err + } + + return nodeBucket.Put(nodePub, b.Bytes()) } func fetchLightningNode(nodeBucket *bolt.Bucket, @@ -2136,6 +2527,44 @@ func putChanEdgePolicy(edges *bolt.Bucket, edge *ChannelEdgePolicy, from, to []b return err } + // Before we write out the new edge, we'll create a new entry in the + // update index in order to keep it fresh. + var indexKey [8 + 8]byte + copy(indexKey[:], scratch[:]) + byteOrder.PutUint64(indexKey[8:], edge.ChannelID) + + updateIndex, err := edges.CreateBucketIfNotExists(edgeUpdateIndexBucket) + if err != nil { + return err + } + + // If there was already an entry for this edge, then we'll need to + // delete the old one to ensure we don't leave around any after-images. + if edgeBytes := edges.Get(edgeKey[:]); edgeBytes != nil { + // In order to delete the old entry, we'll need to obtain the + // *prior* update time in order to delete it. To do this, we'll + // create an offset to slice in. Starting backwards, we'll + // create an offset than puts us right after the flags + // variable: + // + // * pubkeySize + fee+policySize + timelockSize + flagSize + updateEnd := 33 + (8 * 3) + 2 + 1 + updateStart := updateEnd - 8 + oldUpdateTime := edgeBytes[updateStart:updateEnd] + + var oldIndexKey [8 + 8]byte + copy(oldIndexKey[:], oldUpdateTime) + byteOrder.PutUint64(oldIndexKey[8:], edge.ChannelID) + + if err := updateIndex.Delete(oldIndexKey[:]); err != nil { + return err + } + } + + if err := updateIndex.Put(indexKey[:], nil); err != nil { + return err + } + return edges.Put(edgeKey[:], b.Bytes()[:]) } diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go index 3e3327ea..591811be 100644 --- a/channeldb/graph_test.go +++ b/channeldb/graph_test.go @@ -373,6 +373,42 @@ func TestEdgeInsertionDeletion(t *testing.T) { } } +func createEdge(height, txIndex uint32, txPosition uint16, outPointIndex uint32, + node1, node2 *LightningNode) (ChannelEdgeInfo, lnwire.ShortChannelID) { + + shortChanID := lnwire.ShortChannelID{ + BlockHeight: height, + TxIndex: txIndex, + TxPosition: txPosition, + } + outpoint := wire.OutPoint{ + Hash: rev, + Index: outPointIndex, + } + + node1Pub, _ := node1.PubKey() + node2Pub, _ := node2.PubKey() + edgeInfo := ChannelEdgeInfo{ + ChannelID: shortChanID.ToUint64(), + ChainHash: key, + AuthProof: &ChannelAuthProof{ + NodeSig1Bytes: testSig.Serialize(), + NodeSig2Bytes: testSig.Serialize(), + BitcoinSig1Bytes: testSig.Serialize(), + BitcoinSig2Bytes: testSig.Serialize(), + }, + ChannelPoint: outpoint, + Capacity: 9000, + } + + copy(edgeInfo.NodeKey1Bytes[:], node1Pub.SerializeCompressed()) + copy(edgeInfo.NodeKey2Bytes[:], node2Pub.SerializeCompressed()) + copy(edgeInfo.BitcoinKey1Bytes[:], node1Pub.SerializeCompressed()) + copy(edgeInfo.BitcoinKey2Bytes[:], node2Pub.SerializeCompressed()) + + return edgeInfo, shortChanID +} + // TestDisconnectBlockAtHeight checks that the pruned state of the channel // database is what we expect after calling DisconnectBlockAtHeight. func TestDisconnectBlockAtHeight(t *testing.T) { @@ -419,54 +455,22 @@ func TestDisconnectBlockAtHeight(t *testing.T) { // We'll create 3 almost identical edges, so first create a helper // method containing all logic for doing so. - createEdge := func(height uint32, txIndex uint32, txPosition uint16, - outPointIndex uint32) ChannelEdgeInfo { - shortChanID := lnwire.ShortChannelID{ - BlockHeight: height, - TxIndex: txIndex, - TxPosition: txPosition, - } - outpoint := wire.OutPoint{ - Hash: rev, - Index: outPointIndex, - } - - node1Pub, _ := node1.PubKey() - node2Pub, _ := node2.PubKey() - edgeInfo := ChannelEdgeInfo{ - ChannelID: shortChanID.ToUint64(), - ChainHash: key, - AuthProof: &ChannelAuthProof{ - NodeSig1Bytes: testSig.Serialize(), - NodeSig2Bytes: testSig.Serialize(), - BitcoinSig1Bytes: testSig.Serialize(), - BitcoinSig2Bytes: testSig.Serialize(), - }, - ChannelPoint: outpoint, - Capacity: 9000, - } - - copy(edgeInfo.NodeKey1Bytes[:], node1Pub.SerializeCompressed()) - copy(edgeInfo.NodeKey2Bytes[:], node2Pub.SerializeCompressed()) - copy(edgeInfo.BitcoinKey1Bytes[:], node1Pub.SerializeCompressed()) - copy(edgeInfo.BitcoinKey2Bytes[:], node2Pub.SerializeCompressed()) - - return edgeInfo - } // Create an edge which has its block height at 156. height := uint32(156) - edgeInfo := createEdge(height, 0, 0, 0) + edgeInfo, _ := createEdge(height, 0, 0, 0, node1, node2) // Create an edge with block height 157. We give it // maximum values for tx index and position, to make // sure our database range scan get edges from the // entire range. - edgeInfo2 := createEdge(height+1, math.MaxUint32&0x00ffffff, - math.MaxUint16, 1) + edgeInfo2, _ := createEdge( + height+1, math.MaxUint32&0x00ffffff, math.MaxUint16, 1, + node1, node2, + ) // Create a third edge, this with a block height of 155. - edgeInfo3 := createEdge(height-1, 0, 0, 2) + edgeInfo3, _ := createEdge(height-1, 0, 0, 2, node1, node2) // Now add all these new edges to the database. if err := graph.AddChannelEdge(&edgeInfo); err != nil { @@ -754,9 +758,15 @@ func TestEdgeInfoUpdates(t *testing.T) { func randEdgePolicy(chanID uint64, op wire.OutPoint, db *DB) *ChannelEdgePolicy { update := prand.Int63() + return newEdgePolicy(chanID, op, db, update) +} + +func newEdgePolicy(chanID uint64, op wire.OutPoint, db *DB, + updateTime int64) *ChannelEdgePolicy { + return &ChannelEdgePolicy{ ChannelID: chanID, - LastUpdate: time.Unix(update, 0), + LastUpdate: time.Unix(updateTime, 0), TimeLockDelta: uint16(prand.Int63()), MinHTLC: lnwire.MilliSatoshi(prand.Int63()), FeeBaseMSat: lnwire.MilliSatoshi(prand.Int63()), @@ -1164,13 +1174,711 @@ func TestGraphPruning(t *testing.T) { } } +// TestHighestChanID tests that we're able to properly retrieve the highest +// known channel ID in the database. +func TestHighestChanID(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + // If we don't yet have any channels in the database, then we should + // get a channel ID of zero if we ask for the highest channel ID. + bestID, err := graph.HighestChanID() + if err != nil { + t.Fatalf("unable to get highest ID: %v", err) + } + if bestID != 0 { + t.Fatalf("best ID w/ no chan should be zero, is instead: %v", + bestID) + } + + // Next, we'll insert two channels into the database, with each channel + // connecting the same two nodes. + node1, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + node2, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + + // The first channel with be at height 10, while the other will be at + // height 100. + edge1, _ := createEdge(10, 0, 0, 0, node1, node2) + edge2, chanID2 := createEdge(100, 0, 0, 0, node1, node2) + + if err := graph.AddChannelEdge(&edge1); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + if err := graph.AddChannelEdge(&edge2); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + + // Now that the edges has been inserted, we'll query for the highest + // known channel ID in the database. + bestID, err = graph.HighestChanID() + if err != nil { + t.Fatalf("unable to get highest ID: %v", err) + } + + if bestID != chanID2.ToUint64() { + t.Fatalf("expected %v got %v for best chan ID: ", + chanID2.ToUint64(), bestID) + } + + // If we add another edge, then the current best chan ID should be + // updated as well. + edge3, chanID3 := createEdge(1000, 0, 0, 0, node1, node2) + if err := graph.AddChannelEdge(&edge3); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + bestID, err = graph.HighestChanID() + if err != nil { + t.Fatalf("unable to get highest ID: %v", err) + } + + if bestID != chanID3.ToUint64() { + t.Fatalf("expected %v got %v for best chan ID: ", + chanID3.ToUint64(), bestID) + } +} + +// TestChanUpdatesInHorizon tests the we're able to properly retrieve all known +// channel updates within a specific time horizon. It also tests that upon +// insertion of a new edge, the edge update index is updated properly. +func TestChanUpdatesInHorizon(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + // If we issue an arbitrary query before any channel updates are + // inserted in the database, we should get zero results. + chanUpdates, err := graph.ChanUpdatesInHorizon( + time.Unix(999, 0), time.Unix(9999, 0), + ) + if err != nil { + t.Fatalf("unable to updates for updates: %v", err) + } + if len(chanUpdates) != 0 { + t.Fatalf("expected 0 chan updates, instead got %v", + len(chanUpdates)) + } + + // We'll start by creating two nodes which will seed our test graph. + node1, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node1); err != nil { + t.Fatalf("unable to add node: %v", err) + } + node2, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node2); err != nil { + t.Fatalf("unable to add node: %v", err) + } + + // We'll now create 10 channels between the two nodes, with update + // times 10 seconds after each other. + const numChans = 10 + startTime := time.Unix(1234, 0) + endTime := startTime + edges := make([]ChannelEdge, 0, numChans) + for i := 0; i < numChans; i++ { + txHash := sha256.Sum256([]byte{byte(i)}) + op := wire.OutPoint{ + Hash: txHash, + Index: 0, + } + + channel, chanID := createEdge( + uint32(i*10), 0, 0, 0, node1, node2, + ) + + if err := graph.AddChannelEdge(&channel); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + + updateTime := endTime + endTime = updateTime.Add(time.Second * 10) + + edge1 := newEdgePolicy( + chanID.ToUint64(), op, db, updateTime.Unix(), + ) + edge1.Flags = 0 + edge1.Node = node2 + edge1.SigBytes = testSig.Serialize() + if err := graph.UpdateEdgePolicy(edge1); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + + edge2 := newEdgePolicy( + chanID.ToUint64(), op, db, updateTime.Unix(), + ) + edge2.Flags = 1 + edge2.Node = node1 + edge2.SigBytes = testSig.Serialize() + if err := graph.UpdateEdgePolicy(edge2); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + + edges = append(edges, ChannelEdge{ + Info: &channel, + Policy1: edge1, + Policy2: edge2, + }) + } + + // With our channels loaded, we'll now start our series of queries. + queryCases := []struct { + start time.Time + end time.Time + + resp []ChannelEdge + }{ + // If we query for a time range that's strictly below our set + // of updates, then we'll get an empty result back. + { + start: time.Unix(100, 0), + end: time.Unix(200, 0), + }, + + // If we query for a time range that's well beyond our set of + // updates, we should get an empty set of results back. + { + start: time.Unix(99999, 0), + end: time.Unix(999999, 0), + }, + + // If we query for the start time, and 10 seconds directly + // after it, we should only get a single update, that first + // one. + { + start: time.Unix(1234, 0), + end: startTime.Add(time.Second * 10), + + resp: []ChannelEdge{edges[0]}, + }, + + // If we add 10 seconds past the first update, and then + // subtract 10 from the last update, then we should only get + // the 8 edges in the middle. + { + start: startTime.Add(time.Second * 10), + end: endTime.Add(-time.Second * 10), + + resp: edges[1:9], + }, + + // If we use the start and end time as is, we should get the + // entire range. + { + start: startTime, + end: endTime, + + resp: edges, + }, + } + for _, queryCase := range queryCases { + resp, err := graph.ChanUpdatesInHorizon( + queryCase.start, queryCase.end, + ) + if err != nil { + t.Fatalf("unable to query for updates: %v", err) + } + + if len(resp) != len(queryCase.resp) { + t.Fatalf("expected %v chans, got %v chans", + len(queryCase.resp), len(resp)) + + } + + for i := 0; i < len(resp); i++ { + chanExp := queryCase.resp[i] + chanRet := resp[i] + + assertEdgeInfoEqual(t, chanExp.Info, chanRet.Info) + + err := compareEdgePolicies(chanExp.Policy1, chanRet.Policy1) + if err != nil { + t.Fatal(err) + } + compareEdgePolicies(chanExp.Policy2, chanRet.Policy2) + if err != nil { + t.Fatal(err) + } + } + } +} + +// TestNodeUpdatesInHorizon tests that we're able to properly scan and retrieve +// the most recent node updates within a particular time horizon. +func TestNodeUpdatesInHorizon(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + startTime := time.Unix(1234, 0) + endTime := startTime + + // If we issue an arbitrary query before we insert any nodes into the + // database, then we shouldn't get any results back. + nodeUpdates, err := graph.NodeUpdatesInHorizon( + time.Unix(999, 0), time.Unix(9999, 0), + ) + if err != nil { + t.Fatalf("unable to query for node updates: %v", err) + } + if len(nodeUpdates) != 0 { + t.Fatalf("expected 0 node updates, instead got %v", + len(nodeUpdates)) + } + + // We'll create 10 node announcements, each with an update timestmap 10 + // seconds after the other. + const numNodes = 10 + nodeAnns := make([]LightningNode, 0, numNodes) + for i := 0; i < numNodes; i++ { + nodeAnn, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test vertex: %v", err) + } + + // The node ann will use the current end time as its last + // update them, then we'll add 10 seconds in order to create + // the proper update time for the next node announcement. + updateTime := endTime + endTime = updateTime.Add(time.Second * 10) + + nodeAnn.LastUpdate = updateTime + + nodeAnns = append(nodeAnns, *nodeAnn) + + if err := graph.AddLightningNode(nodeAnn); err != nil { + t.Fatalf("unable to add lightning node: %v", err) + } + } + + queryCases := []struct { + start time.Time + end time.Time + + resp []LightningNode + }{ + // If we query for a time range that's strictly below our set + // of updates, then we'll get an empty result back. + { + start: time.Unix(100, 0), + end: time.Unix(200, 0), + }, + + // If we query for a time range that's well beyond our set of + // updates, we should get an empty set of results back. + { + start: time.Unix(99999, 0), + end: time.Unix(999999, 0), + }, + + // If we skip he first time epoch with out start time, then we + // should get back every now but the first. + { + start: startTime.Add(time.Second * 10), + end: endTime, + + resp: nodeAnns[1:], + }, + + // If we query for the range as is, we should get all 10 + // announcements back. + { + start: startTime, + end: endTime, + + resp: nodeAnns, + }, + + // If we reduce the ending time by 10 seconds, then we should + // get all but the last node we inserted. + { + start: startTime, + end: endTime.Add(-time.Second * 10), + + resp: nodeAnns[:9], + }, + } + for _, queryCase := range queryCases { + resp, err := graph.NodeUpdatesInHorizon(queryCase.start, queryCase.end) + if err != nil { + t.Fatalf("unable to query for nodes: %v", err) + } + + if len(resp) != len(queryCase.resp) { + t.Fatalf("expected %v nodes, got %v nodes", + len(queryCase.resp), len(resp)) + + } + + for i := 0; i < len(resp); i++ { + err := compareNodes(&queryCase.resp[i], &resp[i]) + if err != nil { + t.Fatal(err) + } + } + } +} + +// TestFilterKnownChanIDs tests that we're able to properly perform the set +// differences of an incoming set of channel ID's, and those that we already +// know of on disk. +func TestFilterKnownChanIDs(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + // If we try to filter out a set of channel ID's before we even know of + // any channels, then we should get the entire set back. + preChanIDs := []uint64{1, 2, 3, 4} + filteredIDs, err := graph.FilterKnownChanIDs(preChanIDs) + if err != nil { + t.Fatalf("unable to filter chan IDs: %v", err) + } + if !reflect.DeepEqual(preChanIDs, filteredIDs) { + t.Fatalf("chan IDs shouldn't have been filtered!") + } + + // We'll start by creating two nodes which will seed our test graph. + node1, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node1); err != nil { + t.Fatalf("unable to add node: %v", err) + } + node2, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node2); err != nil { + t.Fatalf("unable to add node: %v", err) + } + + // Next, we'll add 5 channel ID's to the graph, each of them having a + // block height 10 blocks after the previous. + const numChans = 5 + chanIDs := make([]uint64, 0, numChans) + for i := 0; i < numChans; i++ { + channel, chanID := createEdge( + uint32(i*10), 0, 0, 0, node1, node2, + ) + + if err := graph.AddChannelEdge(&channel); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + + chanIDs = append(chanIDs, chanID.ToUint64()) + } + + queryCases := []struct { + queryIDs []uint64 + + resp []uint64 + }{ + // If we attempt to filter out all chanIDs we know of, the + // response should be the empty set. + { + queryIDs: chanIDs, + }, + + // If we query for a set of ID's that we didn't insert, we + // should get the same set back. + { + queryIDs: []uint64{99, 100}, + resp: []uint64{99, 100}, + }, + + // If we query for a super-set of our the chan ID's inserted, + // we should only get those new chanIDs back. + { + queryIDs: append(chanIDs, []uint64{99, 101}...), + resp: []uint64{99, 101}, + }, + } + + for _, queryCase := range queryCases { + resp, err := graph.FilterKnownChanIDs(queryCase.queryIDs) + if err != nil { + t.Fatalf("unable to filter chan IDs: %v", err) + } + + if !reflect.DeepEqual(resp, queryCase.resp) { + t.Fatalf("expected %v, got %v", spew.Sdump(queryCase.resp), + spew.Sdump(resp)) + } + } +} + +// TestFilterChannelRange tests that we're able to properly retrieve the full +// set of short channel ID's for a given block range. +func TestFilterChannelRange(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + // We'll first populate our graph with two nodes. All channels created + // below will be made between these two nodes. + node1, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node1); err != nil { + t.Fatalf("unable to add node: %v", err) + } + node2, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node2); err != nil { + t.Fatalf("unable to add node: %v", err) + } + + // If we try to filter a channel range before we have any channels + // inserted, we should get an empty slice of results. + resp, err := graph.FilterChannelRange(10, 100) + if err != nil { + t.Fatalf("unable to filter channels: %v", err) + } + if len(resp) != 0 { + t.Fatalf("expected zero chans, instead got %v", len(resp)) + } + + // To start, we'll create a set of channels, each mined in a block 10 + // blocks after the prior one. + startHeight := uint32(100) + endHeight := startHeight + const numChans = 10 + chanIDs := make([]uint64, 0, numChans) + for i := 0; i < numChans; i++ { + chanHeight := endHeight + channel, chanID := createEdge( + uint32(chanHeight), uint32(i+1), 0, 0, node1, node2, + ) + + if err := graph.AddChannelEdge(&channel); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + + chanIDs = append(chanIDs, chanID.ToUint64()) + + endHeight += 10 + } + + // With our channels inserted, we'll construct a series of queries that + // we'll execute below in order to exercise the features of the + // FilterKnownChanIDs method. + queryCases := []struct { + startHeight uint32 + endHeight uint32 + + resp []uint64 + }{ + // If we query for the entire range, then we should get the same + // set of short channel IDs back. + { + startHeight: startHeight, + endHeight: endHeight, + + resp: chanIDs, + }, + + // If we query for a range of channels right before our range, we + // shouldn't get any results back. + { + startHeight: 0, + endHeight: 10, + }, + + // If we only query for the last height (range wise), we should + // only get that last channel. + { + startHeight: endHeight - 10, + endHeight: endHeight - 10, + + resp: chanIDs[9:], + }, + + // If we query for just the first height, we should only get a + // single channel back (the first one). + { + startHeight: startHeight, + endHeight: startHeight, + + resp: chanIDs[:1], + }, + } + for i, queryCase := range queryCases { + resp, err := graph.FilterChannelRange( + queryCase.startHeight, queryCase.endHeight, + ) + if err != nil { + t.Fatalf("unable to issue range query: %v", err) + } + + if !reflect.DeepEqual(resp, queryCase.resp) { + t.Fatalf("case #%v: expected %v, got %v", i, + queryCase.resp, resp) + } + } +} + +// TestFetchChanInfos tests that we're able to properly retrieve the full set +// of ChannelEdge structs for a given set of short channel ID's. +func TestFetchChanInfos(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + // We'll first populate our graph with two nodes. All channels created + // below will be made between these two nodes. + node1, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node1); err != nil { + t.Fatalf("unable to add node: %v", err) + } + node2, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node2); err != nil { + t.Fatalf("unable to add node: %v", err) + } + + // We'll make 5 test channels, ensuring we keep track of which channel + // ID corresponds to a particular ChannelEdge. + const numChans = 5 + startTime := time.Unix(1234, 0) + endTime := startTime + edges := make([]ChannelEdge, 0, numChans) + edgeQuery := make([]uint64, 0, numChans) + for i := 0; i < numChans; i++ { + txHash := sha256.Sum256([]byte{byte(i)}) + op := wire.OutPoint{ + Hash: txHash, + Index: 0, + } + + channel, chanID := createEdge( + uint32(i*10), 0, 0, 0, node1, node2, + ) + + if err := graph.AddChannelEdge(&channel); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + + updateTime := endTime + endTime = updateTime.Add(time.Second * 10) + + edge1 := newEdgePolicy( + chanID.ToUint64(), op, db, updateTime.Unix(), + ) + edge1.Flags = 0 + edge1.Node = node2 + edge1.SigBytes = testSig.Serialize() + if err := graph.UpdateEdgePolicy(edge1); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + + edge2 := newEdgePolicy( + chanID.ToUint64(), op, db, updateTime.Unix(), + ) + edge2.Flags = 1 + edge2.Node = node1 + edge2.SigBytes = testSig.Serialize() + if err := graph.UpdateEdgePolicy(edge2); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + + edges = append(edges, ChannelEdge{ + Info: &channel, + Policy1: edge1, + Policy2: edge2, + }) + + edgeQuery = append(edgeQuery, chanID.ToUint64()) + } + + // We'll now attempt to query for the range of channel ID's we just + // inserted into the database. We should get the exact same set of + // edges back. + resp, err := graph.FetchChanInfos(edgeQuery) + if err != nil { + t.Fatalf("unable to fetch chan edges: %v", err) + } + if len(resp) != len(edges) { + t.Fatalf("expected %v edges, instead got %v", len(edges), + len(resp)) + } + + for i := 0; i < len(resp); i++ { + err := compareEdgePolicies(resp[i].Policy1, edges[i].Policy1) + if err != nil { + t.Fatalf("edge doesn't match: %v", err) + } + err = compareEdgePolicies(resp[i].Policy2, edges[i].Policy2) + if err != nil { + t.Fatalf("edge doesn't match: %v", err) + } + assertEdgeInfoEqual(t, resp[i].Info, edges[i].Info) + } +} + // compareNodes is used to compare two LightningNodes while excluding the // Features struct, which cannot be compared as the semantics for reserializing // the featuresMap have not been defined. func compareNodes(a, b *LightningNode) error { - if !reflect.DeepEqual(a.LastUpdate, b.LastUpdate) { - return fmt.Errorf("LastUpdate doesn't match: expected %#v, \n"+ - "got %#v", a.LastUpdate, b.LastUpdate) + if a.LastUpdate != b.LastUpdate { + return fmt.Errorf("node LastUpdate doesn't match: expected %v, \n"+ + "got %v", a.LastUpdate, b.LastUpdate) } if !reflect.DeepEqual(a.Addresses, b.Addresses) { return fmt.Errorf("Addresses doesn't match: expected %#v, \n "+ @@ -1208,7 +1916,7 @@ func compareEdgePolicies(a, b *ChannelEdgePolicy) error { "got %v", a.ChannelID, b.ChannelID) } if !reflect.DeepEqual(a.LastUpdate, b.LastUpdate) { - return fmt.Errorf("LastUpdate doesn't match: expected %#v, \n "+ + return fmt.Errorf("edge LastUpdate doesn't match: expected %#v, \n "+ "got %#v", a.LastUpdate, b.LastUpdate) } if a.Flags != b.Flags {