diff --git a/chan_series.go b/chan_series.go new file mode 100644 index 00000000..b13db5da --- /dev/null +++ b/chan_series.go @@ -0,0 +1,312 @@ +package main + +import ( + "time" + + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/discovery" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing" + "github.com/roasbeef/btcd/chaincfg/chainhash" +) + +// chanSeries is an implementation of the discovery.ChannelGraphTimeSeries +// interface backed by the channeldb ChannelGraph database. We'll provide this +// implementation to the AuthenticatedGossiper so it can properly use the +// in-protocol channel range queries to quickly and efficiently synchronize our +// channel state with all peers. +type chanSeries struct { + graph *channeldb.ChannelGraph +} + +// HighestChanID should return is the channel ID of the channel we know of +// that's furthest in the target chain. This channel will have a block height +// that's close to the current tip of the main chain as we know it. We'll use +// this to start our QueryChannelRange dance with the remote node. +// +// NOTE: This is part of the discovery.ChannelGraphTimeSeries interface. +func (c *chanSeries) HighestChanID(chain chainhash.Hash) (*lnwire.ShortChannelID, error) { + chanID, err := c.graph.HighestChanID() + if err != nil { + return nil, err + } + + shortChanID := lnwire.NewShortChanIDFromInt(chanID) + return &shortChanID, nil +} + +// UpdatesInHorizon returns all known channel and node updates with an update +// timestamp between the start time and end time. We'll use this to catch up a +// remote node to the set of channel updates that they may have missed out on +// within the target chain. +// +// NOTE: This is part of the discovery.ChannelGraphTimeSeries interface. +func (c *chanSeries) UpdatesInHorizon(chain chainhash.Hash, + startTime time.Time, endTime time.Time) ([]lnwire.Message, error) { + + var updates []lnwire.Message + + // First, we'll query for all the set of channels that have an update + // that falls within the specified horizon. + chansInHorizon, err := c.graph.ChanUpdatesInHorizon( + startTime, endTime, + ) + if err != nil { + return nil, err + } + for _, channel := range chansInHorizon { + // If the channel hasn't been fully advertised yet, or is a + // private channel, then we'll skip it as we can't construct a + // full authentication proof if one is requested. + if channel.Info.AuthProof == nil { + continue + } + + chanAnn, edge1, edge2, err := discovery.CreateChanAnnouncement( + channel.Info.AuthProof, channel.Info, channel.Policy1, + channel.Policy2, + ) + if err != nil { + return nil, err + } + + updates = append(updates, chanAnn) + if edge1 != nil { + updates = append(updates, edge1) + } + if edge2 != nil { + updates = append(updates, edge2) + } + } + + // Next, we'll send out all the node announcements that have an update + // within the horizon as well. We send these second to ensure that they + // follow any active channels they have. + nodeAnnsInHorizon, err := c.graph.NodeUpdatesInHorizon( + startTime, endTime, + ) + if err != nil { + return nil, err + } + for _, nodeAnn := range nodeAnnsInHorizon { + nodeUpdate, err := makeNodeAnn(&nodeAnn) + if err != nil { + return nil, err + } + + updates = append(updates, nodeUpdate) + } + + return updates, nil +} + +// FilterKnownChanIDs takes a target chain, and a set of channel ID's, and +// returns a filtered set of chan ID's. This filtered set of chan ID's +// represents the ID's that we don't know of which were in the passed superSet. +// +// NOTE: This is part of the discovery.ChannelGraphTimeSeries interface. +func (c *chanSeries) FilterKnownChanIDs(chain chainhash.Hash, + superSet []lnwire.ShortChannelID) ([]lnwire.ShortChannelID, error) { + + chanIDs := make([]uint64, 0, len(superSet)) + for _, chanID := range superSet { + chanIDs = append(chanIDs, chanID.ToUint64()) + } + + newChanIDs, err := c.graph.FilterKnownChanIDs(chanIDs) + if err != nil { + return nil, err + } + + filteredIDs := make([]lnwire.ShortChannelID, 0, len(newChanIDs)) + for _, chanID := range newChanIDs { + filteredIDs = append( + filteredIDs, lnwire.NewShortChanIDFromInt(chanID), + ) + } + + return filteredIDs, nil +} + +// FilterChannelRange returns the set of channels that we created between the +// start height and the end height. We'll use this respond to a remote peer's +// QueryChannelRange message. +// +// NOTE: This is part of the discovery.ChannelGraphTimeSeries interface. +func (c *chanSeries) FilterChannelRange(chain chainhash.Hash, + startHeight, endHeight uint32) ([]lnwire.ShortChannelID, error) { + + chansInRange, err := c.graph.FilterChannelRange(startHeight, endHeight) + if err != nil { + return nil, err + } + + chanResp := make([]lnwire.ShortChannelID, 0, len(chansInRange)) + for _, chanID := range chansInRange { + chanResp = append( + chanResp, lnwire.NewShortChanIDFromInt(chanID), + ) + } + + return chanResp, nil +} + +func makeNodeAnn(n *channeldb.LightningNode) (*lnwire.NodeAnnouncement, error) { + alias, _ := lnwire.NewNodeAlias(n.Alias) + + wireSig, err := lnwire.NewSigFromRawSignature(n.AuthSigBytes) + if err != nil { + return nil, err + } + return &lnwire.NodeAnnouncement{ + Signature: wireSig, + Timestamp: uint32(n.LastUpdate.Unix()), + Addresses: n.Addresses, + NodeID: n.PubKeyBytes, + Features: n.Features.RawFeatureVector, + RGBColor: n.Color, + Alias: alias, + }, nil +} + +// FetchChanAnns returns a full set of channel announcements as well as their +// updates that match the set of specified short channel ID's. We'll use this +// to reply to a QueryShortChanIDs message sent by a remote peer. The response +// will contain a unique set of ChannelAnnouncements, the latest ChannelUpdate +// for each of the announcements, and a unique set of NodeAnnouncements. +// +// NOTE: This is part of the discovery.ChannelGraphTimeSeries interface. +func (c *chanSeries) FetchChanAnns(chain chainhash.Hash, + shortChanIDs []lnwire.ShortChannelID) ([]lnwire.Message, error) { + + chanIDs := make([]uint64, 0, len(shortChanIDs)) + for _, chanID := range shortChanIDs { + chanIDs = append(chanIDs, chanID.ToUint64()) + } + + channels, err := c.graph.FetchChanInfos(chanIDs) + if err != nil { + return nil, err + } + + // We'll use this map to ensure we don't send the same node + // announcement more than one time as one node may have many channel + // anns we'll need to send. + nodePubsSent := make(map[routing.Vertex]struct{}) + + chanAnns := make([]lnwire.Message, 0, len(channels)*3) + for _, channel := range channels { + // If the channel doesn't have an authentication proof, then we + // won't send it over as it may not yet be finalized, or be a + // non-advertised channel. + if channel.Info.AuthProof == nil { + continue + } + + chanAnn, edge1, edge2, err := discovery.CreateChanAnnouncement( + channel.Info.AuthProof, channel.Info, channel.Policy1, + channel.Policy2, + ) + if err != nil { + return nil, err + } + + chanAnns = append(chanAnns, chanAnn) + if edge1 != nil { + chanAnns = append(chanAnns, edge1) + + // If this edge has a validated node announcement, that + // we haven't yet sent, then we'll send that as well. + nodePub := channel.Policy1.Node.PubKeyBytes + hasNodeAnn := channel.Policy1.Node.HaveNodeAnnouncement + if _, ok := nodePubsSent[nodePub]; !ok && hasNodeAnn { + nodeAnn, err := makeNodeAnn(channel.Policy1.Node) + if err != nil { + return nil, err + } + + chanAnns = append(chanAnns, nodeAnn) + nodePubsSent[nodePub] = struct{}{} + } + } + if edge2 != nil { + chanAnns = append(chanAnns, edge2) + + // If this edge has a validated node announcement, that + // we haven't yet sent, then we'll send that as well. + nodePub := channel.Policy2.Node.PubKeyBytes + hasNodeAnn := channel.Policy2.Node.HaveNodeAnnouncement + if _, ok := nodePubsSent[nodePub]; !ok && hasNodeAnn { + nodeAnn, err := makeNodeAnn(channel.Policy2.Node) + if err != nil { + return nil, err + } + + chanAnns = append(chanAnns, nodeAnn) + nodePubsSent[nodePub] = struct{}{} + } + } + } + + return chanAnns, nil +} + +// FetchChanUpdates returns the latest channel update messages for the +// specified short channel ID. If no channel updates are known for the channel, +// then an empty slice will be returned. +// +// NOTE: This is part of the discovery.ChannelGraphTimeSeries interface. +func (c *chanSeries) FetchChanUpdates(chain chainhash.Hash, + shortChanID lnwire.ShortChannelID) ([]*lnwire.ChannelUpdate, error) { + + chanInfo, e1, e2, err := c.graph.FetchChannelEdgesByID( + shortChanID.ToUint64(), + ) + if err != nil { + return nil, err + } + + chanUpdates := make([]*lnwire.ChannelUpdate, 0, 2) + if e1 != nil { + chanUpdate := &lnwire.ChannelUpdate{ + ChainHash: chanInfo.ChainHash, + ShortChannelID: shortChanID, + Timestamp: uint32(e1.LastUpdate.Unix()), + Flags: e1.Flags, + TimeLockDelta: e1.TimeLockDelta, + HtlcMinimumMsat: e1.MinHTLC, + BaseFee: uint32(e1.FeeBaseMSat), + FeeRate: uint32(e1.FeeProportionalMillionths), + } + chanUpdate.Signature, err = lnwire.NewSigFromRawSignature(e1.SigBytes) + if err != nil { + return nil, err + } + + chanUpdates = append(chanUpdates, chanUpdate) + } + if e2 != nil { + chanUpdate := &lnwire.ChannelUpdate{ + ChainHash: chanInfo.ChainHash, + ShortChannelID: shortChanID, + Timestamp: uint32(e2.LastUpdate.Unix()), + Flags: e2.Flags, + TimeLockDelta: e2.TimeLockDelta, + HtlcMinimumMsat: e2.MinHTLC, + BaseFee: uint32(e2.FeeBaseMSat), + FeeRate: uint32(e2.FeeProportionalMillionths), + } + chanUpdate.Signature, err = lnwire.NewSigFromRawSignature(e2.SigBytes) + if err != nil { + return nil, err + } + + chanUpdates = append(chanUpdates, chanUpdate) + } + + return chanUpdates, nil +} + +// A compile-time assertion to ensure that chanSeries meets the +// discovery.ChannelGraphTimeSeries interface. +var _ discovery.ChannelGraphTimeSeries = (*chanSeries)(nil) diff --git a/channeldb/db.go b/channeldb/db.go index 8292febb..076b1c2c 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -40,6 +40,13 @@ var ( number: 0, migration: nil, }, + { + // The version of the database where two new indexes + // for the update time of node and channel updates were + // added. + number: 1, + migration: migrateNodeAndEdgeUpdateIndex, + }, } // Big endian is the preferred byte order, due to cursor scans over @@ -523,8 +530,9 @@ func (d *DB) FetchClosedChannel(chanID *wire.OutPoint) (*ChannelCloseSummary, er // MarkChanFullyClosed marks a channel as fully closed within the database. A // channel should be marked as fully closed if the channel was initially -// cooperatively closed and it's reached a single confirmation, or after all the -// pending funds in a channel that has been forcibly closed have been swept. +// cooperatively closed and it's reached a single confirmation, or after all +// the pending funds in a channel that has been forcibly closed have been +// swept. func (d *DB) MarkChanFullyClosed(chanPoint *wire.OutPoint) error { return d.Update(func(tx *bolt.Tx) error { var b bytes.Buffer @@ -594,8 +602,9 @@ func (d *DB) syncVersions(versions []version) error { // Otherwise, we fetch the migrations which need to applied, and // execute them serially within a single database transaction to ensure // the migration is atomic. - migrations, migrationVersions := getMigrationsToApply(versions, - meta.DbVersionNumber) + migrations, migrationVersions := getMigrationsToApply( + versions, meta.DbVersionNumber, + ) return d.Update(func(tx *bolt.Tx) error { for i, migration := range migrations { if migration == nil { diff --git a/channeldb/graph.go b/channeldb/graph.go index 842443ea..2695645a 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -33,6 +33,14 @@ var ( // maps: source -> selfPubKey nodeBucket = []byte("graph-node") + // nodeUpdateIndexBucket is a sub-bucket of the nodeBucket. This bucket + // will be used to quickly look up the "freshness" of a node's last + // update to the network. The bucket only contains keys, and no values, + // it's mapping: + // + // maps: updateTime || nodeID -> nil + nodeUpdateIndexBucket = []byte("graph-node-update-index") + // sourceKey is a special key that resides within the nodeBucket. The // sourceKey maps a key to the public key of the "self node". sourceKey = []byte("source") @@ -74,6 +82,13 @@ var ( // maps: chanID -> pubKey1 || pubKey2 || restofEdgeInfo edgeIndexBucket = []byte("edge-index") + // edgeUpdateIndexBucket is a sub-bucket of the main edgeBucket. This + // bucket contains an index which allows us to gauge the "freshness" of + // a channel's last updates. + // + // maps: updateTime || chanID -> nil + edgeUpdateIndexBucket = []byte("edge-update-index") + // channelPointBucket maps a channel's full outpoint (txid:index) to // its short 8-byte channel ID. This bucket resides within the // edgeBucket above, and can be used to quickly remove an edge due to @@ -169,44 +184,13 @@ func (c *ChannelGraph) ForEachChannel(cb func(*ChannelEdgeInfo, *ChannelEdgePoli return err } - // The first node is contained within the first half of - // the edge information. - node1Pub := edgeInfoBytes[:33] - edge1, err := fetchChanEdgePolicy(edges, chanID, node1Pub, nodes) - if err != nil && err != ErrEdgeNotFound && - err != ErrGraphNodeNotFound { + edge1, edge2, err := fetchChanEdgePolicies( + edgeIndex, edges, nodes, chanID, c.db, + ) + if err != nil { return err } - // The targeted edge may have not been advertised - // within the network, so we ensure it's non-nil before - // dereferencing its attributes. - if edge1 != nil { - edge1.db = c.db - if edge1.Node != nil { - edge1.Node.db = c.db - } - } - - // Similarly, the second node is contained within the - // latter half of the edge information. - node2Pub := edgeInfoBytes[33:] - edge2, err := fetchChanEdgePolicy(edges, chanID, node2Pub, nodes) - if err != nil && err != ErrEdgeNotFound && - err != ErrGraphNodeNotFound { - return err - } - - // The targeted edge may have not been advertised - // within the network, so we ensure it's non-nil before - // dereferencing its attributes. - if edge2 != nil { - edge2.db = c.db - if edge2.Node != nil { - edge2.Node.db = c.db - } - } - // With both edges read, execute the call back. IF this // function returns an error then the transaction will // be aborted. @@ -356,7 +340,14 @@ func addLightningNode(tx *bolt.Tx, node *LightningNode) error { return err } - return putLightningNode(nodes, aliases, node) + updateIndex, err := nodes.CreateBucketIfNotExists( + nodeUpdateIndexBucket, + ) + if err != nil { + return err + } + + return putLightningNode(nodes, aliases, updateIndex, node) } // LookupAlias attempts to return the alias as advertised by the target node. @@ -595,6 +586,10 @@ func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint, if err != nil { return err } + nodes, err := tx.CreateBucketIfNotExists(nodeBucket) + if err != nil { + return err + } // For each of the outpoints that have been spent within the // block, we attempt to delete them from the graph as if that @@ -628,8 +623,9 @@ func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint, // will be returned if that outpoint isn't known to be // a channel. If no error is returned, then a channel // was successfully pruned. - err = delChannelByEdge(edges, edgeIndex, chanIndex, - chanPoint) + err = delChannelByEdge( + edges, edgeIndex, chanIndex, nodes, chanPoint, + ) if err != nil && err != ErrEdgeNotFound { return err } @@ -699,16 +695,18 @@ func (c *ChannelGraph) DisconnectBlockAtHeight(height uint32) ([]*ChannelEdgeInf if err != nil { return err } - edgeIndex, err := edges.CreateBucketIfNotExists(edgeIndexBucket) if err != nil { return err } - chanIndex, err := edges.CreateBucketIfNotExists(channelPointBucket) if err != nil { return err } + nodes, err := tx.CreateBucketIfNotExists(nodeBucket) + if err != nil { + return err + } // Scan from chanIDStart to chanIDEnd, deleting every // found edge. @@ -721,8 +719,9 @@ func (c *ChannelGraph) DisconnectBlockAtHeight(height uint32) ([]*ChannelEdgeInf if err != nil { return err } - err = delChannelByEdge(edges, edgeIndex, chanIndex, - &edgeInfo.ChannelPoint) + err = delChannelByEdge( + edges, edgeIndex, chanIndex, nodes, &edgeInfo.ChannelPoint, + ) if err != nil && err != ErrEdgeNotFound { return err } @@ -831,8 +830,12 @@ func (c *ChannelGraph) DeleteChannelEdge(chanPoint *wire.OutPoint) error { if err != nil { return err } + nodes, err := tx.CreateBucketIfNotExists(nodeBucket) + if err != nil { + return err + } - return delChannelByEdge(edges, edgeIndex, chanIndex, chanPoint) + return delChannelByEdge(edges, edgeIndex, chanIndex, nodes, chanPoint) }) } @@ -872,38 +875,448 @@ func (c *ChannelGraph) ChannelID(chanPoint *wire.OutPoint) (uint64, error) { return chanID, nil } +// TODO(roasbeef): allow updates to use Batch? + +// HighestChanID returns the "highest" known channel ID in the channel graph. +// This represents the "newest" channel from the PoV of the chain. This method +// can be used by peers to quickly determine if they're graphs are in sync. +func (c *ChannelGraph) HighestChanID() (uint64, error) { + var cid uint64 + + err := c.db.View(func(tx *bolt.Tx) error { + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrGraphNoEdgesFound + } + edgeIndex := edges.Bucket(edgeIndexBucket) + if edgeIndex == nil { + return ErrGraphNoEdgesFound + } + + // In order to find the highest chan ID, we'll fetch a cursor + // and use that to seek to the "end" of our known rage. + cidCursor := edgeIndex.Cursor() + + lastChanID, _ := cidCursor.Last() + + // If there's no key, then this means that we don't actually + // know of any channels, so we'll return a predicable error. + if lastChanID == nil { + return ErrGraphNoEdgesFound + } + + // Otherwise, we'll de serialize the channel ID and return it + // to the caller. + cid = byteOrder.Uint64(lastChanID) + return nil + }) + if err != nil && err != ErrGraphNoEdgesFound { + return 0, err + } + + return cid, nil +} + +// ChannelEdge represents the complete set of information for a channel edge in +// the known channel graph. This struct couples the core information of the +// edge as well as each of the known advertised edge policies. +type ChannelEdge struct { + // Info contains all the static information describing the channel. + Info *ChannelEdgeInfo + + // Policy1 points to the "first" edge policy of the channel containing + // the dynamic information required to properly route through the edge. + Policy1 *ChannelEdgePolicy + + // Policy2 points to the "second" edge policy of the channel containing + // the dynamic information required to properly route through the edge. + Policy2 *ChannelEdgePolicy +} + +// ChanUpdatesInHorizon returns all the known channel edges which have at least +// one edge that has an update timestamp within the specified horizon. +func (c *ChannelGraph) ChanUpdatesInHorizon(startTime, endTime time.Time) ([]ChannelEdge, error) { + var edgesInHorizon []ChannelEdge + + err := c.db.View(func(tx *bolt.Tx) error { + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrGraphNoEdgesFound + } + edgeIndex := edges.Bucket(edgeIndexBucket) + if edgeIndex == nil { + return ErrGraphNoEdgesFound + } + edgeUpdateIndex := edges.Bucket(edgeUpdateIndexBucket) + if edgeUpdateIndex == nil { + return ErrGraphNoEdgesFound + } + + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNodesNotFound + } + + // We'll now obtain a cursor to perform a range query within + // the index to find all channels within the horizon. + updateCursor := edgeUpdateIndex.Cursor() + + var startTimeBytes, endTimeBytes [8 + 8]byte + byteOrder.PutUint64( + startTimeBytes[:8], uint64(startTime.Unix()), + ) + byteOrder.PutUint64( + endTimeBytes[:8], uint64(endTime.Unix()), + ) + + // With our start and end times constructed, we'll step through + // the index collecting the info and policy of each update of + // each channel that has a last update within the time range. + for indexKey, _ := updateCursor.Seek(startTimeBytes[:]); indexKey != nil && + bytes.Compare(indexKey, endTimeBytes[:]) <= 0; indexKey, _ = updateCursor.Next() { + + // We have a new eligible entry, so we'll slice of the + // chan ID so we can query it in the DB. + chanID := indexKey[8:] + + // First, we'll fetch the static edge information. + edgeInfo, err := fetchChanEdgeInfo(edgeIndex, chanID) + if err != nil { + return err + } + + // With the static information obtained, we'll now + // fetch the dynamic policy info. + edge1, edge2, err := fetchChanEdgePolicies( + edgeIndex, edges, nodes, chanID, c.db, + ) + if err != nil { + return err + } + + // Finally, we'll collate this edge with the rest of + // edges to be returned. + edgesInHorizon = append(edgesInHorizon, ChannelEdge{ + Info: &edgeInfo, + Policy1: edge1, + Policy2: edge2, + }) + } + + return nil + }) + switch { + case err == ErrGraphNoEdgesFound: + fallthrough + case err == ErrGraphNodesNotFound: + break + + case err != nil: + return nil, err + } + + return edgesInHorizon, nil +} + +// NodeUpdatesInHorizon returns all the known lightning node which have an +// update timestamp within the passed range. This method can be used by two +// nodes to quickly determine if they have the same set of up to date node +// announcements. +func (c *ChannelGraph) NodeUpdatesInHorizon(startTime, endTime time.Time) ([]LightningNode, error) { + var nodesInHorizon []LightningNode + + err := c.db.View(func(tx *bolt.Tx) error { + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNodesNotFound + } + + nodeUpdateIndex := nodes.Bucket(nodeUpdateIndexBucket) + if nodeUpdateIndex == nil { + return ErrGraphNodesNotFound + } + + // We'll now obtain a cursor to perform a range query within + // the index to find all node announcements within the horizon. + updateCursor := nodeUpdateIndex.Cursor() + + var startTimeBytes, endTimeBytes [8 + 33]byte + byteOrder.PutUint64( + startTimeBytes[:8], uint64(startTime.Unix()), + ) + byteOrder.PutUint64( + endTimeBytes[:8], uint64(endTime.Unix()), + ) + + // With our start and end times constructed, we'll step through + // the index collecting info for each node within the time + // range. + for indexKey, _ := updateCursor.Seek(startTimeBytes[:]); indexKey != nil && + bytes.Compare(indexKey, endTimeBytes[:]) <= 0; indexKey, _ = updateCursor.Next() { + + nodePub := indexKey[8:] + node, err := fetchLightningNode(nodes, nodePub) + if err != nil { + return err + } + node.db = c.db + + nodesInHorizon = append(nodesInHorizon, node) + } + + return nil + }) + switch { + case err == ErrGraphNoEdgesFound: + fallthrough + case err == ErrGraphNodesNotFound: + break + + case err != nil: + return nil, err + } + + return nodesInHorizon, nil +} + +// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan +// ID's that we don't know of in the passed set. In other words, we perform a +// set difference of our set of chan ID's and the ones passed in. This method +// can be used by callers to determine the set of channels ta peer knows of +// that we don't. +func (c *ChannelGraph) FilterKnownChanIDs(chanIDs []uint64) ([]uint64, error) { + var newChanIDs []uint64 + + err := c.db.View(func(tx *bolt.Tx) error { + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrGraphNoEdgesFound + } + edgeIndex := edges.Bucket(edgeIndexBucket) + if edgeIndex == nil { + return ErrGraphNoEdgesFound + } + + // We'll run through the set of chanIDs and collate only the + // set of channel that are unable to be found within our db. + var cidBytes [8]byte + for _, cid := range chanIDs { + byteOrder.PutUint64(cidBytes[:], cid) + + if v := edgeIndex.Get(cidBytes[:]); v == nil { + newChanIDs = append(newChanIDs, cid) + } + } + + return nil + }) + switch { + // If we don't know of any edges yet, then we'll return the entire set + // of chan IDs specified. + case err == ErrGraphNoEdgesFound: + return chanIDs, nil + + case err != nil: + return nil, err + } + + return newChanIDs, nil +} + +// FilterChannelRange returns the channel ID's of all known channels which were +// mined in a block height within the passed range. This method can be used to +// quickly share with a peer the set of channels we know of within a particular +// range to catch them up after a period of time offline. +func (c *ChannelGraph) FilterChannelRange(startHeight, endHeight uint32) ([]uint64, error) { + var chanIDs []uint64 + + startChanID := &lnwire.ShortChannelID{ + BlockHeight: startHeight, + } + + endChanID := lnwire.ShortChannelID{ + BlockHeight: endHeight, + TxIndex: math.MaxUint32 & 0x00ffffff, + TxPosition: math.MaxUint16, + } + + // As we need to perform a range scan, we'll convert the starting and + // ending height to their corresponding values when encoded using short + // channel ID's. + var chanIDStart, chanIDEnd [8]byte + byteOrder.PutUint64(chanIDStart[:], startChanID.ToUint64()) + byteOrder.PutUint64(chanIDEnd[:], endChanID.ToUint64()) + + err := c.db.View(func(tx *bolt.Tx) error { + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrGraphNoEdgesFound + } + edgeIndex := edges.Bucket(edgeIndexBucket) + if edgeIndex == nil { + return ErrGraphNoEdgesFound + } + + cursor := edgeIndex.Cursor() + + // We'll now iterate through the database, and find each + // channel ID that resides within the specified range. + var cid uint64 + for k, _ := cursor.Seek(chanIDStart[:]); k != nil && + bytes.Compare(k, chanIDEnd[:]) <= 0; k, _ = cursor.Next() { + + // This channel ID rests within the target range, so + // we'll convert it into an integer and add it to our + // returned set. + cid = byteOrder.Uint64(k) + chanIDs = append(chanIDs, cid) + } + + return nil + }) + switch { + // If we don't know of any channels yet, then there's nothing to + // filter, so we'll return an empty slice. + case err == ErrGraphNoEdgesFound: + return chanIDs, nil + + case err != nil: + return nil, err + } + + return chanIDs, nil +} + +// FetchChanInfos returns the set of channel edges that correspond to the +// passed channel ID's. This can be used to respond to peer queries that are +// seeking to fill in gaps in their view of the channel graph. +func (c *ChannelGraph) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) { + // TODO(roasbeef): sort cids? + + var ( + chanEdges []ChannelEdge + cidBytes [8]byte + ) + + err := c.db.View(func(tx *bolt.Tx) error { + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrGraphNoEdgesFound + } + edgeIndex := edges.Bucket(edgeIndexBucket) + if edgeIndex == nil { + return ErrGraphNoEdgesFound + } + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNotFound + } + + for _, cid := range chanIDs { + byteOrder.PutUint64(cidBytes[:], cid) + + // First, we'll fetch the static edge information. + edgeInfo, err := fetchChanEdgeInfo( + edgeIndex, cidBytes[:], + ) + if err != nil { + return err + } + + // With the static information obtained, we'll now + // fetch the dynamic policy info. + edge1, edge2, err := fetchChanEdgePolicies( + edgeIndex, edges, nodes, cidBytes[:], c.db, + ) + if err != nil { + return err + } + + chanEdges = append(chanEdges, ChannelEdge{ + Info: &edgeInfo, + Policy1: edge1, + Policy2: edge2, + }) + } + return nil + }) + if err != nil { + return nil, err + } + + return chanEdges, nil +} + +func delEdgeUpdateIndexEntry(edgesBucket *bolt.Bucket, chanID uint64, + edge1, edge2 *ChannelEdgePolicy) error { + + // First, we'll fetch the edge update index bucket which currently + // stores an entry for the channel we're about to delete. + updateIndex, err := edgesBucket.CreateBucketIfNotExists( + edgeUpdateIndexBucket, + ) + if err != nil { + return err + } + + // Now that we have the bucket, we'll attempt to construct a template + // for the index key: updateTime || chanid. + var indexKey [8 + 8]byte + byteOrder.PutUint64(indexKey[8:], chanID) + + // With the template constructed, we'll attempt to delete an entry that + // would have been created by both edges: we'll alternate the update + // times, as one may had overridden the other. + if edge1 != nil { + byteOrder.PutUint64(indexKey[:8], uint64(edge1.LastUpdate.Unix())) + if err := updateIndex.Delete(indexKey[:]); err != nil { + return err + } + } + + // We'll also attempt to delete the entry that may have been created by + // the second edge. + if edge2 != nil { + byteOrder.PutUint64(indexKey[:8], uint64(edge2.LastUpdate.Unix())) + if err := updateIndex.Delete(indexKey[:]); err != nil { + return err + } + } + + return nil +} + func delChannelByEdge(edges *bolt.Bucket, edgeIndex *bolt.Bucket, - chanIndex *bolt.Bucket, chanPoint *wire.OutPoint) error { + chanIndex *bolt.Bucket, nodes *bolt.Bucket, chanPoint *wire.OutPoint) error { var b bytes.Buffer if err := writeOutpoint(&b, chanPoint); err != nil { return err } - // If the channel's outpoint doesn't exist within the outpoint - // index, then the edge does not exist. + // If the channel's outpoint doesn't exist within the outpoint index, + // then the edge does not exist. chanID := chanIndex.Get(b.Bytes()) if chanID == nil { return ErrEdgeNotFound } - // Otherwise we obtain the two public keys from the mapping: - // chanID -> pubKey1 || pubKey2. With this, we can construct - // the keys which house both of the directed edges for this - // channel. + // Otherwise we obtain the two public keys from the mapping: chanID -> + // pubKey1 || pubKey2. With this, we can construct the keys which house + // both of the directed edges for this channel. nodeKeys := edgeIndex.Get(chanID) if nodeKeys == nil { return fmt.Errorf("could not find nodekeys for chanID %v", chanID) } - // The edge key is of the format pubKey || chanID. First we - // construct the latter half, populating the channel ID. + // The edge key is of the format pubKey || chanID. First we construct + // the latter half, populating the channel ID. var edgeKey [33 + 8]byte copy(edgeKey[33:], chanID) - // With the latter half constructed, copy over the first public - // key to delete the edge in this direction, then the second to - // delete the edge in the opposite direction. + // With the latter half constructed, copy over the first public key to + // delete the edge in this direction, then the second to delete the + // edge in the opposite direction. copy(edgeKey[:33], nodeKeys[:33]) if edges.Get(edgeKey[:]) != nil { if err := edges.Delete(edgeKey[:]); err != nil { @@ -917,8 +1330,21 @@ func delChannelByEdge(edges *bolt.Bucket, edgeIndex *bolt.Bucket, } } - // Finally, with the edge data deleted, we can purge the - // information from the two edge indexes. + // We'll also remove the entry in the edge update index bucket. + cid := byteOrder.Uint64(chanID) + edge1, edge2, err := fetchChanEdgePolicies( + edgeIndex, edges, nodes, chanID, nil, + ) + if err != nil { + return err + } + err = delEdgeUpdateIndexEntry(edges, cid, edge1, edge2) + if err != nil { + return err + } + + // Finally, with the edge data deleted, we can purge the information + // from the two edge indexes. if err := edgeIndex.Delete(chanID); err != nil { return err } @@ -1776,7 +2202,9 @@ func (c *ChannelGraph) NewChannelEdgePolicy() *ChannelEdgePolicy { return &ChannelEdgePolicy{db: c.db} } -func putLightningNode(nodeBucket *bolt.Bucket, aliasBucket *bolt.Bucket, node *LightningNode) error { +func putLightningNode(nodeBucket *bolt.Bucket, aliasBucket *bolt.Bucket, + updateIndex *bolt.Bucket, node *LightningNode) error { + var ( scratch [16]byte b bytes.Buffer @@ -1803,8 +2231,8 @@ func putLightningNode(nodeBucket *bolt.Bucket, aliasBucket *bolt.Bucket, node *L return err } - // If we got a node announcement for this node, we will have the rest of - // the data available. If not we don't have more data to write. + // If we got a node announcement for this node, we will have the rest + // of the data available. If not we don't have more data to write. if !node.HaveNodeAnnouncement { // Write HaveNodeAnnouncement=0. byteOrder.PutUint16(scratch[:2], 0) @@ -1860,8 +2288,33 @@ func putLightningNode(nodeBucket *bolt.Bucket, aliasBucket *bolt.Bucket, node *L return err } - return nodeBucket.Put(nodePub, b.Bytes()) + // With the alias bucket updated, we'll now update the index that + // tracks the time series of node updates. + var indexKey [8 + 33]byte + byteOrder.PutUint64(indexKey[:8], updateUnix) + copy(indexKey[8:], nodePub) + // If there was already an old index entry for this node, then we'll + // delete the old one before we write the new entry. + if nodeBytes := nodeBucket.Get(nodePub); nodeBytes != nil { + // Extract out the old update time to we can reconstruct the + // prior index key to delete it from the index. + oldUpdateTime := nodeBytes[:8] + + var oldIndexKey [8 + 33]byte + copy(oldIndexKey[:8], oldUpdateTime) + copy(oldIndexKey[8:], nodePub) + + if err := updateIndex.Delete(oldIndexKey[:]); err != nil { + return err + } + } + + if err := updateIndex.Put(indexKey[:], nil); err != nil { + return err + } + + return nodeBucket.Put(nodePub, b.Bytes()) } func fetchLightningNode(nodeBucket *bolt.Bucket, @@ -2136,6 +2589,44 @@ func putChanEdgePolicy(edges *bolt.Bucket, edge *ChannelEdgePolicy, from, to []b return err } + // Before we write out the new edge, we'll create a new entry in the + // update index in order to keep it fresh. + var indexKey [8 + 8]byte + copy(indexKey[:], scratch[:]) + byteOrder.PutUint64(indexKey[8:], edge.ChannelID) + + updateIndex, err := edges.CreateBucketIfNotExists(edgeUpdateIndexBucket) + if err != nil { + return err + } + + // If there was already an entry for this edge, then we'll need to + // delete the old one to ensure we don't leave around any after-images. + if edgeBytes := edges.Get(edgeKey[:]); edgeBytes != nil { + // In order to delete the old entry, we'll need to obtain the + // *prior* update time in order to delete it. To do this, we'll + // create an offset to slice in. Starting backwards, we'll + // create an offset than puts us right after the flags + // variable: + // + // * pubkeySize + fee+policySize + timelockSize + flagSize + updateEnd := 33 + (8 * 3) + 2 + 1 + updateStart := updateEnd - 8 + oldUpdateTime := edgeBytes[updateStart:updateEnd] + + var oldIndexKey [8 + 8]byte + copy(oldIndexKey[:], oldUpdateTime) + byteOrder.PutUint64(oldIndexKey[8:], edge.ChannelID) + + if err := updateIndex.Delete(oldIndexKey[:]); err != nil { + return err + } + } + + if err := updateIndex.Put(indexKey[:], nil); err != nil { + return err + } + return edges.Put(edgeKey[:], b.Bytes()[:]) } diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go index 3e3327ea..591811be 100644 --- a/channeldb/graph_test.go +++ b/channeldb/graph_test.go @@ -373,6 +373,42 @@ func TestEdgeInsertionDeletion(t *testing.T) { } } +func createEdge(height, txIndex uint32, txPosition uint16, outPointIndex uint32, + node1, node2 *LightningNode) (ChannelEdgeInfo, lnwire.ShortChannelID) { + + shortChanID := lnwire.ShortChannelID{ + BlockHeight: height, + TxIndex: txIndex, + TxPosition: txPosition, + } + outpoint := wire.OutPoint{ + Hash: rev, + Index: outPointIndex, + } + + node1Pub, _ := node1.PubKey() + node2Pub, _ := node2.PubKey() + edgeInfo := ChannelEdgeInfo{ + ChannelID: shortChanID.ToUint64(), + ChainHash: key, + AuthProof: &ChannelAuthProof{ + NodeSig1Bytes: testSig.Serialize(), + NodeSig2Bytes: testSig.Serialize(), + BitcoinSig1Bytes: testSig.Serialize(), + BitcoinSig2Bytes: testSig.Serialize(), + }, + ChannelPoint: outpoint, + Capacity: 9000, + } + + copy(edgeInfo.NodeKey1Bytes[:], node1Pub.SerializeCompressed()) + copy(edgeInfo.NodeKey2Bytes[:], node2Pub.SerializeCompressed()) + copy(edgeInfo.BitcoinKey1Bytes[:], node1Pub.SerializeCompressed()) + copy(edgeInfo.BitcoinKey2Bytes[:], node2Pub.SerializeCompressed()) + + return edgeInfo, shortChanID +} + // TestDisconnectBlockAtHeight checks that the pruned state of the channel // database is what we expect after calling DisconnectBlockAtHeight. func TestDisconnectBlockAtHeight(t *testing.T) { @@ -419,54 +455,22 @@ func TestDisconnectBlockAtHeight(t *testing.T) { // We'll create 3 almost identical edges, so first create a helper // method containing all logic for doing so. - createEdge := func(height uint32, txIndex uint32, txPosition uint16, - outPointIndex uint32) ChannelEdgeInfo { - shortChanID := lnwire.ShortChannelID{ - BlockHeight: height, - TxIndex: txIndex, - TxPosition: txPosition, - } - outpoint := wire.OutPoint{ - Hash: rev, - Index: outPointIndex, - } - - node1Pub, _ := node1.PubKey() - node2Pub, _ := node2.PubKey() - edgeInfo := ChannelEdgeInfo{ - ChannelID: shortChanID.ToUint64(), - ChainHash: key, - AuthProof: &ChannelAuthProof{ - NodeSig1Bytes: testSig.Serialize(), - NodeSig2Bytes: testSig.Serialize(), - BitcoinSig1Bytes: testSig.Serialize(), - BitcoinSig2Bytes: testSig.Serialize(), - }, - ChannelPoint: outpoint, - Capacity: 9000, - } - - copy(edgeInfo.NodeKey1Bytes[:], node1Pub.SerializeCompressed()) - copy(edgeInfo.NodeKey2Bytes[:], node2Pub.SerializeCompressed()) - copy(edgeInfo.BitcoinKey1Bytes[:], node1Pub.SerializeCompressed()) - copy(edgeInfo.BitcoinKey2Bytes[:], node2Pub.SerializeCompressed()) - - return edgeInfo - } // Create an edge which has its block height at 156. height := uint32(156) - edgeInfo := createEdge(height, 0, 0, 0) + edgeInfo, _ := createEdge(height, 0, 0, 0, node1, node2) // Create an edge with block height 157. We give it // maximum values for tx index and position, to make // sure our database range scan get edges from the // entire range. - edgeInfo2 := createEdge(height+1, math.MaxUint32&0x00ffffff, - math.MaxUint16, 1) + edgeInfo2, _ := createEdge( + height+1, math.MaxUint32&0x00ffffff, math.MaxUint16, 1, + node1, node2, + ) // Create a third edge, this with a block height of 155. - edgeInfo3 := createEdge(height-1, 0, 0, 2) + edgeInfo3, _ := createEdge(height-1, 0, 0, 2, node1, node2) // Now add all these new edges to the database. if err := graph.AddChannelEdge(&edgeInfo); err != nil { @@ -754,9 +758,15 @@ func TestEdgeInfoUpdates(t *testing.T) { func randEdgePolicy(chanID uint64, op wire.OutPoint, db *DB) *ChannelEdgePolicy { update := prand.Int63() + return newEdgePolicy(chanID, op, db, update) +} + +func newEdgePolicy(chanID uint64, op wire.OutPoint, db *DB, + updateTime int64) *ChannelEdgePolicy { + return &ChannelEdgePolicy{ ChannelID: chanID, - LastUpdate: time.Unix(update, 0), + LastUpdate: time.Unix(updateTime, 0), TimeLockDelta: uint16(prand.Int63()), MinHTLC: lnwire.MilliSatoshi(prand.Int63()), FeeBaseMSat: lnwire.MilliSatoshi(prand.Int63()), @@ -1164,13 +1174,711 @@ func TestGraphPruning(t *testing.T) { } } +// TestHighestChanID tests that we're able to properly retrieve the highest +// known channel ID in the database. +func TestHighestChanID(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + // If we don't yet have any channels in the database, then we should + // get a channel ID of zero if we ask for the highest channel ID. + bestID, err := graph.HighestChanID() + if err != nil { + t.Fatalf("unable to get highest ID: %v", err) + } + if bestID != 0 { + t.Fatalf("best ID w/ no chan should be zero, is instead: %v", + bestID) + } + + // Next, we'll insert two channels into the database, with each channel + // connecting the same two nodes. + node1, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + node2, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + + // The first channel with be at height 10, while the other will be at + // height 100. + edge1, _ := createEdge(10, 0, 0, 0, node1, node2) + edge2, chanID2 := createEdge(100, 0, 0, 0, node1, node2) + + if err := graph.AddChannelEdge(&edge1); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + if err := graph.AddChannelEdge(&edge2); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + + // Now that the edges has been inserted, we'll query for the highest + // known channel ID in the database. + bestID, err = graph.HighestChanID() + if err != nil { + t.Fatalf("unable to get highest ID: %v", err) + } + + if bestID != chanID2.ToUint64() { + t.Fatalf("expected %v got %v for best chan ID: ", + chanID2.ToUint64(), bestID) + } + + // If we add another edge, then the current best chan ID should be + // updated as well. + edge3, chanID3 := createEdge(1000, 0, 0, 0, node1, node2) + if err := graph.AddChannelEdge(&edge3); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + bestID, err = graph.HighestChanID() + if err != nil { + t.Fatalf("unable to get highest ID: %v", err) + } + + if bestID != chanID3.ToUint64() { + t.Fatalf("expected %v got %v for best chan ID: ", + chanID3.ToUint64(), bestID) + } +} + +// TestChanUpdatesInHorizon tests the we're able to properly retrieve all known +// channel updates within a specific time horizon. It also tests that upon +// insertion of a new edge, the edge update index is updated properly. +func TestChanUpdatesInHorizon(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + // If we issue an arbitrary query before any channel updates are + // inserted in the database, we should get zero results. + chanUpdates, err := graph.ChanUpdatesInHorizon( + time.Unix(999, 0), time.Unix(9999, 0), + ) + if err != nil { + t.Fatalf("unable to updates for updates: %v", err) + } + if len(chanUpdates) != 0 { + t.Fatalf("expected 0 chan updates, instead got %v", + len(chanUpdates)) + } + + // We'll start by creating two nodes which will seed our test graph. + node1, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node1); err != nil { + t.Fatalf("unable to add node: %v", err) + } + node2, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node2); err != nil { + t.Fatalf("unable to add node: %v", err) + } + + // We'll now create 10 channels between the two nodes, with update + // times 10 seconds after each other. + const numChans = 10 + startTime := time.Unix(1234, 0) + endTime := startTime + edges := make([]ChannelEdge, 0, numChans) + for i := 0; i < numChans; i++ { + txHash := sha256.Sum256([]byte{byte(i)}) + op := wire.OutPoint{ + Hash: txHash, + Index: 0, + } + + channel, chanID := createEdge( + uint32(i*10), 0, 0, 0, node1, node2, + ) + + if err := graph.AddChannelEdge(&channel); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + + updateTime := endTime + endTime = updateTime.Add(time.Second * 10) + + edge1 := newEdgePolicy( + chanID.ToUint64(), op, db, updateTime.Unix(), + ) + edge1.Flags = 0 + edge1.Node = node2 + edge1.SigBytes = testSig.Serialize() + if err := graph.UpdateEdgePolicy(edge1); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + + edge2 := newEdgePolicy( + chanID.ToUint64(), op, db, updateTime.Unix(), + ) + edge2.Flags = 1 + edge2.Node = node1 + edge2.SigBytes = testSig.Serialize() + if err := graph.UpdateEdgePolicy(edge2); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + + edges = append(edges, ChannelEdge{ + Info: &channel, + Policy1: edge1, + Policy2: edge2, + }) + } + + // With our channels loaded, we'll now start our series of queries. + queryCases := []struct { + start time.Time + end time.Time + + resp []ChannelEdge + }{ + // If we query for a time range that's strictly below our set + // of updates, then we'll get an empty result back. + { + start: time.Unix(100, 0), + end: time.Unix(200, 0), + }, + + // If we query for a time range that's well beyond our set of + // updates, we should get an empty set of results back. + { + start: time.Unix(99999, 0), + end: time.Unix(999999, 0), + }, + + // If we query for the start time, and 10 seconds directly + // after it, we should only get a single update, that first + // one. + { + start: time.Unix(1234, 0), + end: startTime.Add(time.Second * 10), + + resp: []ChannelEdge{edges[0]}, + }, + + // If we add 10 seconds past the first update, and then + // subtract 10 from the last update, then we should only get + // the 8 edges in the middle. + { + start: startTime.Add(time.Second * 10), + end: endTime.Add(-time.Second * 10), + + resp: edges[1:9], + }, + + // If we use the start and end time as is, we should get the + // entire range. + { + start: startTime, + end: endTime, + + resp: edges, + }, + } + for _, queryCase := range queryCases { + resp, err := graph.ChanUpdatesInHorizon( + queryCase.start, queryCase.end, + ) + if err != nil { + t.Fatalf("unable to query for updates: %v", err) + } + + if len(resp) != len(queryCase.resp) { + t.Fatalf("expected %v chans, got %v chans", + len(queryCase.resp), len(resp)) + + } + + for i := 0; i < len(resp); i++ { + chanExp := queryCase.resp[i] + chanRet := resp[i] + + assertEdgeInfoEqual(t, chanExp.Info, chanRet.Info) + + err := compareEdgePolicies(chanExp.Policy1, chanRet.Policy1) + if err != nil { + t.Fatal(err) + } + compareEdgePolicies(chanExp.Policy2, chanRet.Policy2) + if err != nil { + t.Fatal(err) + } + } + } +} + +// TestNodeUpdatesInHorizon tests that we're able to properly scan and retrieve +// the most recent node updates within a particular time horizon. +func TestNodeUpdatesInHorizon(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + startTime := time.Unix(1234, 0) + endTime := startTime + + // If we issue an arbitrary query before we insert any nodes into the + // database, then we shouldn't get any results back. + nodeUpdates, err := graph.NodeUpdatesInHorizon( + time.Unix(999, 0), time.Unix(9999, 0), + ) + if err != nil { + t.Fatalf("unable to query for node updates: %v", err) + } + if len(nodeUpdates) != 0 { + t.Fatalf("expected 0 node updates, instead got %v", + len(nodeUpdates)) + } + + // We'll create 10 node announcements, each with an update timestmap 10 + // seconds after the other. + const numNodes = 10 + nodeAnns := make([]LightningNode, 0, numNodes) + for i := 0; i < numNodes; i++ { + nodeAnn, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test vertex: %v", err) + } + + // The node ann will use the current end time as its last + // update them, then we'll add 10 seconds in order to create + // the proper update time for the next node announcement. + updateTime := endTime + endTime = updateTime.Add(time.Second * 10) + + nodeAnn.LastUpdate = updateTime + + nodeAnns = append(nodeAnns, *nodeAnn) + + if err := graph.AddLightningNode(nodeAnn); err != nil { + t.Fatalf("unable to add lightning node: %v", err) + } + } + + queryCases := []struct { + start time.Time + end time.Time + + resp []LightningNode + }{ + // If we query for a time range that's strictly below our set + // of updates, then we'll get an empty result back. + { + start: time.Unix(100, 0), + end: time.Unix(200, 0), + }, + + // If we query for a time range that's well beyond our set of + // updates, we should get an empty set of results back. + { + start: time.Unix(99999, 0), + end: time.Unix(999999, 0), + }, + + // If we skip he first time epoch with out start time, then we + // should get back every now but the first. + { + start: startTime.Add(time.Second * 10), + end: endTime, + + resp: nodeAnns[1:], + }, + + // If we query for the range as is, we should get all 10 + // announcements back. + { + start: startTime, + end: endTime, + + resp: nodeAnns, + }, + + // If we reduce the ending time by 10 seconds, then we should + // get all but the last node we inserted. + { + start: startTime, + end: endTime.Add(-time.Second * 10), + + resp: nodeAnns[:9], + }, + } + for _, queryCase := range queryCases { + resp, err := graph.NodeUpdatesInHorizon(queryCase.start, queryCase.end) + if err != nil { + t.Fatalf("unable to query for nodes: %v", err) + } + + if len(resp) != len(queryCase.resp) { + t.Fatalf("expected %v nodes, got %v nodes", + len(queryCase.resp), len(resp)) + + } + + for i := 0; i < len(resp); i++ { + err := compareNodes(&queryCase.resp[i], &resp[i]) + if err != nil { + t.Fatal(err) + } + } + } +} + +// TestFilterKnownChanIDs tests that we're able to properly perform the set +// differences of an incoming set of channel ID's, and those that we already +// know of on disk. +func TestFilterKnownChanIDs(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + // If we try to filter out a set of channel ID's before we even know of + // any channels, then we should get the entire set back. + preChanIDs := []uint64{1, 2, 3, 4} + filteredIDs, err := graph.FilterKnownChanIDs(preChanIDs) + if err != nil { + t.Fatalf("unable to filter chan IDs: %v", err) + } + if !reflect.DeepEqual(preChanIDs, filteredIDs) { + t.Fatalf("chan IDs shouldn't have been filtered!") + } + + // We'll start by creating two nodes which will seed our test graph. + node1, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node1); err != nil { + t.Fatalf("unable to add node: %v", err) + } + node2, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node2); err != nil { + t.Fatalf("unable to add node: %v", err) + } + + // Next, we'll add 5 channel ID's to the graph, each of them having a + // block height 10 blocks after the previous. + const numChans = 5 + chanIDs := make([]uint64, 0, numChans) + for i := 0; i < numChans; i++ { + channel, chanID := createEdge( + uint32(i*10), 0, 0, 0, node1, node2, + ) + + if err := graph.AddChannelEdge(&channel); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + + chanIDs = append(chanIDs, chanID.ToUint64()) + } + + queryCases := []struct { + queryIDs []uint64 + + resp []uint64 + }{ + // If we attempt to filter out all chanIDs we know of, the + // response should be the empty set. + { + queryIDs: chanIDs, + }, + + // If we query for a set of ID's that we didn't insert, we + // should get the same set back. + { + queryIDs: []uint64{99, 100}, + resp: []uint64{99, 100}, + }, + + // If we query for a super-set of our the chan ID's inserted, + // we should only get those new chanIDs back. + { + queryIDs: append(chanIDs, []uint64{99, 101}...), + resp: []uint64{99, 101}, + }, + } + + for _, queryCase := range queryCases { + resp, err := graph.FilterKnownChanIDs(queryCase.queryIDs) + if err != nil { + t.Fatalf("unable to filter chan IDs: %v", err) + } + + if !reflect.DeepEqual(resp, queryCase.resp) { + t.Fatalf("expected %v, got %v", spew.Sdump(queryCase.resp), + spew.Sdump(resp)) + } + } +} + +// TestFilterChannelRange tests that we're able to properly retrieve the full +// set of short channel ID's for a given block range. +func TestFilterChannelRange(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + // We'll first populate our graph with two nodes. All channels created + // below will be made between these two nodes. + node1, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node1); err != nil { + t.Fatalf("unable to add node: %v", err) + } + node2, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node2); err != nil { + t.Fatalf("unable to add node: %v", err) + } + + // If we try to filter a channel range before we have any channels + // inserted, we should get an empty slice of results. + resp, err := graph.FilterChannelRange(10, 100) + if err != nil { + t.Fatalf("unable to filter channels: %v", err) + } + if len(resp) != 0 { + t.Fatalf("expected zero chans, instead got %v", len(resp)) + } + + // To start, we'll create a set of channels, each mined in a block 10 + // blocks after the prior one. + startHeight := uint32(100) + endHeight := startHeight + const numChans = 10 + chanIDs := make([]uint64, 0, numChans) + for i := 0; i < numChans; i++ { + chanHeight := endHeight + channel, chanID := createEdge( + uint32(chanHeight), uint32(i+1), 0, 0, node1, node2, + ) + + if err := graph.AddChannelEdge(&channel); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + + chanIDs = append(chanIDs, chanID.ToUint64()) + + endHeight += 10 + } + + // With our channels inserted, we'll construct a series of queries that + // we'll execute below in order to exercise the features of the + // FilterKnownChanIDs method. + queryCases := []struct { + startHeight uint32 + endHeight uint32 + + resp []uint64 + }{ + // If we query for the entire range, then we should get the same + // set of short channel IDs back. + { + startHeight: startHeight, + endHeight: endHeight, + + resp: chanIDs, + }, + + // If we query for a range of channels right before our range, we + // shouldn't get any results back. + { + startHeight: 0, + endHeight: 10, + }, + + // If we only query for the last height (range wise), we should + // only get that last channel. + { + startHeight: endHeight - 10, + endHeight: endHeight - 10, + + resp: chanIDs[9:], + }, + + // If we query for just the first height, we should only get a + // single channel back (the first one). + { + startHeight: startHeight, + endHeight: startHeight, + + resp: chanIDs[:1], + }, + } + for i, queryCase := range queryCases { + resp, err := graph.FilterChannelRange( + queryCase.startHeight, queryCase.endHeight, + ) + if err != nil { + t.Fatalf("unable to issue range query: %v", err) + } + + if !reflect.DeepEqual(resp, queryCase.resp) { + t.Fatalf("case #%v: expected %v, got %v", i, + queryCase.resp, resp) + } + } +} + +// TestFetchChanInfos tests that we're able to properly retrieve the full set +// of ChannelEdge structs for a given set of short channel ID's. +func TestFetchChanInfos(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + // We'll first populate our graph with two nodes. All channels created + // below will be made between these two nodes. + node1, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node1); err != nil { + t.Fatalf("unable to add node: %v", err) + } + node2, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node2); err != nil { + t.Fatalf("unable to add node: %v", err) + } + + // We'll make 5 test channels, ensuring we keep track of which channel + // ID corresponds to a particular ChannelEdge. + const numChans = 5 + startTime := time.Unix(1234, 0) + endTime := startTime + edges := make([]ChannelEdge, 0, numChans) + edgeQuery := make([]uint64, 0, numChans) + for i := 0; i < numChans; i++ { + txHash := sha256.Sum256([]byte{byte(i)}) + op := wire.OutPoint{ + Hash: txHash, + Index: 0, + } + + channel, chanID := createEdge( + uint32(i*10), 0, 0, 0, node1, node2, + ) + + if err := graph.AddChannelEdge(&channel); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + + updateTime := endTime + endTime = updateTime.Add(time.Second * 10) + + edge1 := newEdgePolicy( + chanID.ToUint64(), op, db, updateTime.Unix(), + ) + edge1.Flags = 0 + edge1.Node = node2 + edge1.SigBytes = testSig.Serialize() + if err := graph.UpdateEdgePolicy(edge1); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + + edge2 := newEdgePolicy( + chanID.ToUint64(), op, db, updateTime.Unix(), + ) + edge2.Flags = 1 + edge2.Node = node1 + edge2.SigBytes = testSig.Serialize() + if err := graph.UpdateEdgePolicy(edge2); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + + edges = append(edges, ChannelEdge{ + Info: &channel, + Policy1: edge1, + Policy2: edge2, + }) + + edgeQuery = append(edgeQuery, chanID.ToUint64()) + } + + // We'll now attempt to query for the range of channel ID's we just + // inserted into the database. We should get the exact same set of + // edges back. + resp, err := graph.FetchChanInfos(edgeQuery) + if err != nil { + t.Fatalf("unable to fetch chan edges: %v", err) + } + if len(resp) != len(edges) { + t.Fatalf("expected %v edges, instead got %v", len(edges), + len(resp)) + } + + for i := 0; i < len(resp); i++ { + err := compareEdgePolicies(resp[i].Policy1, edges[i].Policy1) + if err != nil { + t.Fatalf("edge doesn't match: %v", err) + } + err = compareEdgePolicies(resp[i].Policy2, edges[i].Policy2) + if err != nil { + t.Fatalf("edge doesn't match: %v", err) + } + assertEdgeInfoEqual(t, resp[i].Info, edges[i].Info) + } +} + // compareNodes is used to compare two LightningNodes while excluding the // Features struct, which cannot be compared as the semantics for reserializing // the featuresMap have not been defined. func compareNodes(a, b *LightningNode) error { - if !reflect.DeepEqual(a.LastUpdate, b.LastUpdate) { - return fmt.Errorf("LastUpdate doesn't match: expected %#v, \n"+ - "got %#v", a.LastUpdate, b.LastUpdate) + if a.LastUpdate != b.LastUpdate { + return fmt.Errorf("node LastUpdate doesn't match: expected %v, \n"+ + "got %v", a.LastUpdate, b.LastUpdate) } if !reflect.DeepEqual(a.Addresses, b.Addresses) { return fmt.Errorf("Addresses doesn't match: expected %#v, \n "+ @@ -1208,7 +1916,7 @@ func compareEdgePolicies(a, b *ChannelEdgePolicy) error { "got %v", a.ChannelID, b.ChannelID) } if !reflect.DeepEqual(a.LastUpdate, b.LastUpdate) { - return fmt.Errorf("LastUpdate doesn't match: expected %#v, \n "+ + return fmt.Errorf("edge LastUpdate doesn't match: expected %#v, \n "+ "got %#v", a.LastUpdate, b.LastUpdate) } if a.Flags != b.Flags { diff --git a/channeldb/migrations.go b/channeldb/migrations.go index d03b3406..50ceec79 100644 --- a/channeldb/migrations.go +++ b/channeldb/migrations.go @@ -1 +1,114 @@ package channeldb + +import ( + "bytes" + "fmt" + + "github.com/coreos/bbolt" +) + +// migrateNodeAndEdgeUpdateIndex is a migration function that will update the +// database from version 0 to version 1. In version 1, we add two new indexes +// (one for nodes and one for edges) to keep track of the last time a node or +// edge was updated on the network. These new indexes allow us to implement the +// new graph sync protocol added. +func migrateNodeAndEdgeUpdateIndex(tx *bolt.Tx) error { + // First, we'll populating the node portion of the new index. Before we + // can add new values to the index, we'll first create the new bucket + // where these items will be housed. + nodes, err := tx.CreateBucketIfNotExists(nodeBucket) + if err != nil { + return fmt.Errorf("unable to create node bucket: %v", err) + } + nodeUpdateIndex, err := nodes.CreateBucketIfNotExists( + nodeUpdateIndexBucket, + ) + if err != nil { + return fmt.Errorf("unable to create node update index: %v", err) + } + + log.Infof("Populating new node update index bucket") + + // Now that we know the bucket has been created, we'll iterate over the + // entire node bucket so we can add the (updateTime || nodePub) key + // into the node update index. + err = nodes.ForEach(func(nodePub, nodeInfo []byte) error { + if len(nodePub) != 33 { + return nil + } + + log.Tracef("Adding %x to node update index", nodePub) + + // The first 8 bytes of a node's serialize data is the update + // time, so we can extract that without decoding the entire + // structure. + updateTime := nodeInfo[:8] + + // Now that we have the update time, we can construct the key + // to insert into the index. + var indexKey [8 + 33]byte + copy(indexKey[:8], updateTime) + copy(indexKey[8:], nodePub) + + return nodeUpdateIndex.Put(indexKey[:], nil) + }) + if err != nil { + return fmt.Errorf("unable to update node indexes: %v", err) + } + + log.Infof("Populating new edge update index bucket") + + // With the set of nodes updated, we'll now update all edges to have a + // corresponding entry in the edge update index. + edges, err := tx.CreateBucketIfNotExists(edgeBucket) + if err != nil { + return fmt.Errorf("unable to create edge bucket: %v", err) + } + edgeUpdateIndex, err := edges.CreateBucketIfNotExists( + edgeUpdateIndexBucket, + ) + if err != nil { + return fmt.Errorf("unable to create edge update index: %v", err) + } + + // We'll now run through each edge policy in the database, and update + // the index to ensure each edge has the proper record. + err = edges.ForEach(func(edgeKey, edgePolicyBytes []byte) error { + if len(edgeKey) != 41 { + return nil + } + + // Now that we know this is the proper record, we'll grab the + // channel ID (last 8 bytes of the key), and then decode the + // edge policy so we can access the update time. + chanID := edgeKey[33:] + edgePolicyReader := bytes.NewReader(edgePolicyBytes) + + edgePolicy, err := deserializeChanEdgePolicy( + edgePolicyReader, nodes, + ) + if err != nil { + return err + } + + log.Tracef("Adding chan_id=%v to edge update index", + edgePolicy.ChannelID) + + // We'll now construct the index key using the channel ID, and + // the last time it was updated: (updateTime || chanID). + var indexKey [8 + 8]byte + byteOrder.PutUint64( + indexKey[:], uint64(edgePolicy.LastUpdate.Unix()), + ) + copy(indexKey[8:], chanID) + + return edgeUpdateIndex.Put(indexKey[:], nil) + }) + if err != nil { + return fmt.Errorf("unable to update edge indexes: %v", err) + } + + log.Infof("Migration to node and edge update indexes complete!") + + return nil +} diff --git a/config.go b/config.go index 4ed9fed9..4cfbb848 100644 --- a/config.go +++ b/config.go @@ -204,6 +204,8 @@ type config struct { Color string `long:"color" description:"The color of the node in hex format (i.e. '#3399FF'). Used to customize node appearance in intelligence services"` MinChanSize int64 `long:"minchansize" description:"The smallest channel size (in satoshis) that we should accept. Incoming channels smaller than this will be rejected"` + NoChanUpdates bool `long:"nochanupdates" description:"If specified, lnd will not request real-time channel updates from connected peers. This option should be used by routing nodes to save bandwidth."` + net torsvc.Net } diff --git a/discovery/gossiper.go b/discovery/gossiper.go index 90b84590..878a9cd3 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -72,6 +72,12 @@ type Config struct { // order to be included in the LN graph. Router routing.ChannelGraphSource + // ChanSeries is an interfaces that provides access to a time series + // view of the current known channel graph. Each gossipSyncer enabled + // peer will utilize this in order to create and respond to channel + // graph time series queries. + ChanSeries ChannelGraphTimeSeries + // Notifier is used for receiving notifications of incoming blocks. // With each new incoming block found we process previously premature // announcements. @@ -196,6 +202,14 @@ type AuthenticatedGossiper struct { rejectMtx sync.RWMutex recentRejects map[uint64]struct{} + // peerSyncers keeps track of all the gossip syncers we're maintain for + // peers that understand this mode of operation. When we go to send out + // new updates, for all peers in the map, we'll send the messages + // directly to their gossiper, rather than broadcasting them. With this + // change, we ensure we filter out all updates properly. + syncerMtx sync.RWMutex + peerSyncers map[routing.Vertex]*gossipSyncer + sync.Mutex } @@ -218,6 +232,7 @@ func New(cfg Config, selfKey *btcec.PublicKey) (*AuthenticatedGossiper, error) { waitingProofs: storage, channelMtx: multimutex.NewMutex(), recentRejects: make(map[uint64]struct{}), + peerSyncers: make(map[routing.Vertex]*gossipSyncer), }, nil } @@ -254,6 +269,11 @@ func (d *AuthenticatedGossiper) SynchronizeNode(pub *btcec.PublicKey) error { }, nil } + // We'll use this map to ensure we don't send the same node + // announcement more than one time as one node may have many channel + // anns we'll need to send. + nodePubsSent := make(map[routing.Vertex]struct{}) + // As peers are expecting channel announcements before node // announcements, we first retrieve the initial announcement, as well as // the latest channel update announcement for both of the directed edges @@ -271,7 +291,7 @@ func (d *AuthenticatedGossiper) SynchronizeNode(pub *btcec.PublicKey) error { // also has known validated nodes, then we'll send that as // well. if chanInfo.AuthProof != nil { - chanAnn, e1Ann, e2Ann, err := createChanAnnouncement( + chanAnn, e1Ann, e2Ann, err := CreateChanAnnouncement( chanInfo.AuthProof, chanInfo, e1, e2, ) if err != nil { @@ -283,15 +303,21 @@ func (d *AuthenticatedGossiper) SynchronizeNode(pub *btcec.PublicKey) error { announceMessages = append(announceMessages, e1Ann) // If this edge has a validated node - // announcement, then we'll send that as well. - if e1.Node.HaveNodeAnnouncement { + // announcement, that we haven't yet sent, then + // we'll send that as well. + nodePub := e1.Node.PubKeyBytes + hasNodeAnn := e1.Node.HaveNodeAnnouncement + if _, ok := nodePubsSent[nodePub]; !ok && hasNodeAnn { nodeAnn, err := makeNodeAnn(e1.Node) if err != nil { return err } + announceMessages = append( announceMessages, nodeAnn, ) + nodePubsSent[nodePub] = struct{}{} + numNodes++ } } @@ -299,15 +325,21 @@ func (d *AuthenticatedGossiper) SynchronizeNode(pub *btcec.PublicKey) error { announceMessages = append(announceMessages, e2Ann) // If this edge has a validated node - // announcement, then we'll send that as well. - if e2.Node.HaveNodeAnnouncement { + // announcement, that we haven't yet sent, then + // we'll send that as well. + nodePub := e2.Node.PubKeyBytes + hasNodeAnn := e2.Node.HaveNodeAnnouncement + if _, ok := nodePubsSent[nodePub]; !ok && hasNodeAnn { nodeAnn, err := makeNodeAnn(e2.Node) if err != nil { return err } + announceMessages = append( announceMessages, nodeAnn, ) + nodePubsSent[nodePub] = struct{}{} + numNodes++ } } @@ -400,10 +432,19 @@ func (d *AuthenticatedGossiper) Stop() { log.Info("Authenticated Gossiper is stopping") + d.syncerMtx.RLock() + for _, syncer := range d.peerSyncers { + syncer.Stop() + } + d.syncerMtx.RUnlock() + close(d.quit) d.wg.Wait() } +// TODO(roasbeef): need method to get current gossip timestamp? +// * using mtx, check time rotate forward is needed? + // ProcessRemoteAnnouncement sends a new remote announcement message along with // the peer that sent the routing message. The announcement will be processed // then added to a queue for batched trickled announcement to all connected @@ -480,6 +521,16 @@ type msgWithSenders struct { senders map[routing.Vertex]struct{} } +// mergeSyncerMap is used to merge the set of senders of a particular message +// with peers that we have an active gossipSyncer with. We do this to ensure +// that we don't broadcast messages to any peers that we have active gossip +// syncers for. +func (m *msgWithSenders) mergeSyncerMap(syncers map[routing.Vertex]struct{}) { + for peerPub := range syncers { + m.senders[peerPub] = struct{}{} + } +} + // deDupedAnnouncements de-duplicates announcements that have been added to the // batch. Internally, announcements are stored in three maps // (one each for channel announcements, channel updates, and node @@ -693,12 +744,11 @@ func (d *deDupedAnnouncements) Emit() []msgWithSenders { return msgs } -// resendAnnounceSignatures will inspect the messageStore database -// bucket for AnnounceSignatures messages that we recently tried -// to send to a peer. If the associated channels still not have the -// full channel proofs assembled, we will try to resend them. If -// we have the full proof, we can safely delete the message from -// the messageStore. +// resendAnnounceSignatures will inspect the messageStore database bucket for +// AnnounceSignatures messages that we recently tried to send to a peer. If the +// associated channels still not have the full channel proofs assembled, we +// will try to resend them. If we have the full proof, we can safely delete the +// message from the messageStore. func (d *AuthenticatedGossiper) resendAnnounceSignatures() error { type msgTuple struct { peer *btcec.PublicKey @@ -706,8 +756,9 @@ func (d *AuthenticatedGossiper) resendAnnounceSignatures() error { dbKey []byte } - // Fetch all the AnnounceSignatures messages that was added - // to the database. + // Fetch all the AnnounceSignatures messages that was added to the + // database. + // // TODO(halseth): database access should be abstracted // behind interface. var msgsResend []msgTuple @@ -717,7 +768,6 @@ func (d *AuthenticatedGossiper) resendAnnounceSignatures() error { return nil } - // Iterate over each message added to the database. if err := bucket.ForEach(func(k, v []byte) error { // The database value represents the encoded // AnnounceSignatures message. @@ -727,17 +777,16 @@ func (d *AuthenticatedGossiper) resendAnnounceSignatures() error { return err } - // The first 33 bytes of the database key is - // the peer's public key. + // The first 33 bytes of the database key is the peer's + // public key. peer, err := btcec.ParsePubKey(k[:33], btcec.S256()) if err != nil { return err } t := msgTuple{peer, msg, k} - // Add the message to the slice, such that we - // can resend it after the database transaction - // is over. + // Add the message to the slice, such that we can + // resend it after the database transaction is over. msgsResend = append(msgsResend, t) return nil }); err != nil { @@ -748,8 +797,8 @@ func (d *AuthenticatedGossiper) resendAnnounceSignatures() error { return err } - // deleteMsg removes the message associated with the passed - // msgTuple from the messageStore. + // deleteMsg removes the message associated with the passed msgTuple + // from the messageStore. deleteMsg := func(t msgTuple) error { log.Debugf("Deleting message for chanID=%v from "+ "messageStore", t.msg.ChannelID) @@ -768,16 +817,16 @@ func (d *AuthenticatedGossiper) resendAnnounceSignatures() error { return nil } - // We now iterate over these messages, resending those that we - // don't have the full proof for, deleting the rest. + // We now iterate over these messages, resending those that we don't + // have the full proof for, deleting the rest. for _, t := range msgsResend { // Check if the full channel proof exists in our graph. chanInfo, _, _, err := d.cfg.Router.GetChannelByID( t.msg.ShortChannelID) if err != nil { - // If the channel cannot be found, it is most likely - // a leftover message for a channel that was closed. - // In this case we delete it from the message store. + // If the channel cannot be found, it is most likely a + // leftover message for a channel that was closed. In + // this case we delete it from the message store. log.Warnf("unable to fetch channel info for "+ "chanID=%v from graph: %v. Will delete local"+ "proof from database", @@ -788,13 +837,12 @@ func (d *AuthenticatedGossiper) resendAnnounceSignatures() error { continue } - // 1. If the full proof does not exist in the graph, - // it means that we haven't received the remote proof - // yet (or that we crashed before able to assemble the - // full proof). Since the remote node might think they - // have delivered their proof to us, we will resend - // _our_ proof to trigger a resend on their part: - // they will then be able to assemble and send us the + // 1. If the full proof does not exist in the graph, it means + // that we haven't received the remote proof yet (or that we + // crashed before able to assemble the full proof). Since the + // remote node might think they have delivered their proof to + // us, we will resend _our_ proof to trigger a resend on their + // part: they will then be able to assemble and send us the // full proof. if chanInfo.AuthProof == nil { err := d.sendAnnSigReliably(t.msg, t.peer) @@ -805,13 +853,12 @@ func (d *AuthenticatedGossiper) resendAnnounceSignatures() error { } // 2. If the proof does exist in the graph, we have - // successfully received the remote proof and assembled - // the full proof. In this case we can safely delete the - // local proof from the database. In case the remote - // hasn't been able to assemble the full proof yet - // (maybe because of a crash), we will send them the full - // proof if we notice that they retry sending their half - // proof. + // successfully received the remote proof and assembled the + // full proof. In this case we can safely delete the local + // proof from the database. In case the remote hasn't been able + // to assemble the full proof yet (maybe because of a crash), + // we will send them the full proof if we notice that they + // retry sending their half proof. if chanInfo.AuthProof != nil { log.Debugf("Deleting message for chanID=%v from "+ "messageStore", t.msg.ChannelID) @@ -823,6 +870,52 @@ func (d *AuthenticatedGossiper) resendAnnounceSignatures() error { return nil } +// findGossipSyncer is a utility method used by the gossiper to locate the +// gossip syncer for an inbound message so we can properly dispatch the +// incoming message. If a gossip syncer isn't found, then one will be created +// for the target peer. +func (d *AuthenticatedGossiper) findGossipSyncer(pub *btcec.PublicKey) *gossipSyncer { + target := routing.NewVertex(pub) + + // First, we'll try to find an existing gossiper for this peer. + d.syncerMtx.RLock() + syncer, ok := d.peerSyncers[target] + d.syncerMtx.RUnlock() + + // If one exists, then we'll return it directly. + if ok { + return syncer + } + + // Otherwise, we'll obtain the mutex, then check again if a gossiper + // was added after we dropped the read mutex. + d.syncerMtx.Lock() + syncer, ok = d.peerSyncers[target] + if ok { + d.syncerMtx.Unlock() + return syncer + } + + // At this point, a syncer doesn't yet exist, so we'll create a new one + // for the peer and return it to the caller. + syncer = newGossiperSyncer(gossipSyncerCfg{ + chainHash: d.cfg.ChainHash, + syncChanUpdates: true, + channelSeries: d.cfg.ChanSeries, + encodingType: lnwire.EncodingSortedPlain, + sendToPeer: func(msgs ...lnwire.Message) error { + return d.cfg.SendToPeer(pub, msgs...) + }, + }) + copy(syncer.peerPub[:], pub.SerializeCompressed()) + d.peerSyncers[target] = syncer + syncer.Start() + + d.syncerMtx.Unlock() + + return syncer +} + // networkHandler is the primary goroutine that drives this service. The roles // of this goroutine includes answering queries related to the state of the // network, syncing up newly connected peers, and also periodically @@ -880,9 +973,10 @@ func (d *AuthenticatedGossiper) networkHandler() { policyUpdate.errResp <- nil case announcement := <-d.networkMsgs: - // Channel announcement signatures are the only message - // that we'll process serially. - if _, ok := announcement.msg.(*lnwire.AnnounceSignatures); ok { + switch msg := announcement.msg.(type) { + // Channel announcement signatures are amongst the only + // messages that we'll process serially. + case *lnwire.AnnounceSignatures: emittedAnnouncements := d.processNetworkAnnouncement( announcement, ) @@ -892,6 +986,35 @@ func (d *AuthenticatedGossiper) networkHandler() { ) } continue + + // If a peer is updating its current update horizon, + // then we'll dispatch that directly to the proper + // gossipSyncer. + case *lnwire.GossipTimestampRange: + syncer := d.findGossipSyncer(announcement.peer) + + // If we've found the message target, then + // we'll dispatch the message directly to it. + err := syncer.ApplyGossipFilter(msg) + if err != nil { + log.Warnf("unable to apply gossip "+ + "filter for peer=%x: %v", + announcement.peer.SerializeCompressed(), err) + } + continue + + // For messages in the known set of channel series + // queries, we'll dispatch the message directly to the + // peer, and skip the main processing loop. + case *lnwire.QueryShortChanIDs, + *lnwire.QueryChannelRange, + *lnwire.ReplyChannelRange, + *lnwire.ReplyShortChanIDsEnd: + + syncer := d.findGossipSyncer(announcement.peer) + + syncer.ProcessQueryMsg(announcement.msg) + continue } // If this message was recently rejected, then we won't @@ -1003,12 +1126,37 @@ func (d *AuthenticatedGossiper) networkHandler() { continue } + // For the set of peers that have an active gossip + // syncers, we'll collect their pubkeys so we can avoid + // sending them the full message blast below. + d.syncerMtx.RLock() + syncerPeers := map[routing.Vertex]struct{}{} + for peerPub := range d.peerSyncers { + syncerPeers[peerPub] = struct{}{} + } + d.syncerMtx.RUnlock() + log.Infof("Broadcasting batch of %v new announcements", len(announcementBatch)) - // If we have new things to announce then broadcast - // them to all our immediately connected peers. + // We'll first attempt to filter out this new message + // for all peers that have active gossip syncers + // active. + d.syncerMtx.RLock() + for _, syncer := range d.peerSyncers { + syncer.FilterGossipMsgs(announcementBatch...) + } + d.syncerMtx.RUnlock() + + // Next, If we have new things to announce then + // broadcast them to all our immediately connected + // peers. for _, msgChunk := range announcementBatch { + // With the syncers taken care of, we'll merge + // the sender map with the set of syncers, so + // we don't send out duplicate messages. + msgChunk.mergeSyncerMap(syncerPeers) + err := d.cfg.Broadcast( msgChunk.senders, msgChunk.msg, ) @@ -1038,6 +1186,67 @@ func (d *AuthenticatedGossiper) networkHandler() { } } +// TODO(roasbeef): d/c peers that send uupdates not on our chain + +// InitPeerSyncState is called by outside sub-systems when a connection is +// established to a new peer that understands how to perform channel range +// queries. We'll allocate a new gossip syncer for it, and start any goroutines +// needed to handle new queries. The recvUpdates bool indicates if we should +// continue to receive real-time updates from the remote peer once we've synced +// channel state. +func (d *AuthenticatedGossiper) InitSyncState(peer *btcec.PublicKey, recvUpdates bool) { + d.syncerMtx.Lock() + defer d.syncerMtx.Unlock() + + // If we already have a syncer, then we'll exit early as we don't want + // to override it. + nodeID := routing.NewVertex(peer) + if _, ok := d.peerSyncers[nodeID]; ok { + return + } + + log.Infof("Creating new gossipSyncer for peer=%x", + peer.SerializeCompressed()) + + syncer := newGossiperSyncer(gossipSyncerCfg{ + chainHash: d.cfg.ChainHash, + syncChanUpdates: recvUpdates, + channelSeries: d.cfg.ChanSeries, + encodingType: lnwire.EncodingSortedPlain, + sendToPeer: func(msgs ...lnwire.Message) error { + return d.cfg.SendToPeer(peer, msgs...) + }, + }) + copy(syncer.peerPub[:], peer.SerializeCompressed()) + d.peerSyncers[nodeID] = syncer + + syncer.Start() +} + +// PruneSyncState is called by outside sub-systems once a peer that we were +// previously connected to has been disconnected. In this case we can stop the +// existing gossipSyncer assigned to the peer and free up resources. +func (d *AuthenticatedGossiper) PruneSyncState(peer *btcec.PublicKey) { + d.syncerMtx.Lock() + defer d.syncerMtx.Unlock() + + log.Infof("Removing gossipSyncer for peer=%x", + peer.SerializeCompressed()) + + vertex := routing.NewVertex(peer) + + syncer, ok := d.peerSyncers[routing.NewVertex(peer)] + if !ok { + return + } + + syncer.Stop() + + delete(d.peerSyncers, vertex) + + return +} + // isRecentlyRejectedMsg returns true if we recently rejected a message, and // false otherwise, This avoids expensive reprocessing of the message. func (d *AuthenticatedGossiper) isRecentlyRejectedMsg(msg lnwire.Message) bool { @@ -1265,7 +1474,7 @@ func (d *AuthenticatedGossiper) processRejectedEdge(chanAnnMsg *lnwire.ChannelAn // We'll then create then validate the new fully assembled // announcement. - chanAnn, e1Ann, e2Ann, err := createChanAnnouncement( + chanAnn, e1Ann, e2Ann, err := CreateChanAnnouncement( proof, chanInfo, e1, e2, ) if err != nil { @@ -1686,12 +1895,28 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(nMsg *networkMsg) []n d.pChanUpdMtx.Lock() d.prematureChannelUpdates[shortChanID] = append( d.prematureChannelUpdates[shortChanID], - nMsg) + nMsg, + ) d.pChanUpdMtx.Unlock() + log.Debugf("Got ChannelUpdate for edge not "+ "found in graph(shortChanID=%v), "+ "saving for reprocessing later", shortChanID) + + // If the node supports it, we may try to + // request the chan ann from it. + go func() { + reqErr := d.maybeRequestChanAnn( + msg.ShortChannelID, + ) + if reqErr != nil { + log.Errorf("unable to request ann "+ + "for chan_id=%v: %v", shortChanID, + reqErr) + } + }() + nMsg.err <- nil return nil default: @@ -1921,7 +2146,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(nMsg *networkMsg) []n msg.ChannelID, peerID) - chanAnn, _, _, err := createChanAnnouncement( + chanAnn, _, _, err := CreateChanAnnouncement( chanInfo.AuthProof, chanInfo, e1, e2, ) if err != nil { @@ -1996,7 +2221,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(nMsg *networkMsg) []n dbProof.BitcoinSig1Bytes = oppositeProof.BitcoinSignature.ToSignatureBytes() dbProof.BitcoinSig2Bytes = msg.BitcoinSignature.ToSignatureBytes() } - chanAnn, e1Ann, e2Ann, err := createChanAnnouncement(&dbProof, chanInfo, e1, e2) + chanAnn, e1Ann, e2Ann, err := CreateChanAnnouncement(&dbProof, chanInfo, e1, e2) if err != nil { log.Error(err) nMsg.err <- err @@ -2079,6 +2304,7 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(nMsg *networkMsg) []n // that the caller knows that the message will be delivered at one point. func (d *AuthenticatedGossiper) sendAnnSigReliably( msg *lnwire.AnnounceSignatures, remotePeer *btcec.PublicKey) error { + // We first add this message to the database, such that in case // we do not succeed in sending it to the peer, we'll fetch it // from the DB next time we start, and retry. We use the peer ID @@ -2257,3 +2483,36 @@ func (d *AuthenticatedGossiper) updateChannel(info *channeldb.ChannelEdgeInfo, return chanAnn, chanUpdate, err } + +// maybeRequestChanAnn will attempt to request the full channel announcement +// for a particular short chan ID. We do this in the case that we get a channel +// update, yet don't already have a channel announcement for it. +func (d *AuthenticatedGossiper) maybeRequestChanAnn(cid lnwire.ShortChannelID) error { + d.syncerMtx.Lock() + defer d.syncerMtx.Unlock() + + for nodeID, syncer := range d.peerSyncers { + // If this syncer is already at the terminal state, then we'll + // chose it to request the fully channel update. + if syncer.SyncState() == chansSynced { + pub, err := btcec.ParsePubKey(nodeID[:], btcec.S256()) + if err != nil { + return err + } + + log.Debugf("attempting to request chan ann for "+ + "chan_id=%v from node=%x", cid, nodeID[:]) + + return d.cfg.SendToPeer(pub, &lnwire.QueryShortChanIDs{ + ChainHash: d.cfg.ChainHash, + EncodingType: lnwire.EncodingSortedPlain, + ShortChanIDs: []lnwire.ShortChannelID{cid}, + }) + } + } + + log.Debugf("unable to find peer to request chan ann for chan_id=%v "+ + "from", cid) + + return nil +} diff --git a/discovery/syncer.go b/discovery/syncer.go new file mode 100644 index 00000000..85e356bc --- /dev/null +++ b/discovery/syncer.go @@ -0,0 +1,915 @@ +package discovery + +import ( + "fmt" + "math" + "sync" + "sync/atomic" + "time" + + "github.com/lightningnetwork/lnd/lnwire" + "github.com/roasbeef/btcd/chaincfg/chainhash" +) + +// syncerState is an enum that represents the current state of the +// gossipSyncer. As the syncer is a state machine, we'll gate our actions +// based off of the current state and the next incoming message. +type syncerState uint32 + +const ( + // syncingChans is the default state of the gossipSyncer. We start in + // this state when a new peer first connects and we don't yet know if + // we're fully synchronized. + syncingChans syncerState = iota + + // waitingQueryRangeReply is the second main phase of the gossipSyncer. + // We enter this state after we send out our first QueryChannelRange + // reply. We'll stay in this state until the remote party sends us a + // ReplyShortChanIDsEnd message that indicates they've responded to our + // query entirely. After this state, we'll transition to + // waitingQueryChanReply after we send out requests for all the new + // chan ID's to us. + waitingQueryRangeReply + + // queryNewChannels is the third main phase of the gossipSyncer. In + // this phase we'll send out all of our QueryShortChanIDs messages in + // response to the new channels that we don't yet know about. + queryNewChannels + + // waitingQueryChanReply is the fourth main phase of the gossipSyncer. + // We enter this phase once we've sent off a query chink to the remote + // peer. We'll stay in this phase until we receive a + // ReplyShortChanIDsEnd message which indicates that the remote party + // has responded to all of our requests. + waitingQueryChanReply + + // chansSynced is the terminal stage of the gossipSyncer. Once we enter + // this phase, we'll send out our update horizon, which filters out the + // set of channel updates that we're interested in. In this state, + // we'll be able to accept any outgoing messages from the + // AuthenticatedGossiper, and decide if we should forward them to our + // target peer based on its update horizon. + chansSynced +) + +// String returns a human readable string describing the target syncerState. +func (s syncerState) String() string { + switch s { + case syncingChans: + return "syncingChans" + + case waitingQueryRangeReply: + return "waitingQueryRangeReply" + + case queryNewChannels: + return "queryNewChannels" + + case waitingQueryChanReply: + return "waitingQueryChanReply" + + case chansSynced: + return "chansSynced" + + default: + return "UNKNOWN STATE" + } +} + +var ( + // encodingTypeToChunkSize maps an encoding type, to the max number of + // short chan ID's using the encoding type that we can fit into a + // single message safely. + encodingTypeToChunkSize = map[lnwire.ShortChanIDEncoding]int32{ + lnwire.EncodingSortedPlain: 8000, + } +) + +const ( + // chanRangeQueryBuffer is the number of blocks back that we'll go when + // asking the remote peer for their any channels they know of beyond + // our highest known channel ID. + chanRangeQueryBuffer = 144 +) + +// ChannelGraphTimeSeries is an interface that provides time and block based +// querying into our view of the channel graph. New channels will have +// monotonically increasing block heights, and new channel updates will have +// increasing timestamps. Once we connect to a peer, we'll use the methods in +// this interface to determine if we're already in sync, or need to request +// some new information from them. +type ChannelGraphTimeSeries interface { + // HighestChanID should return the channel ID of the channel we know of + // that's furthest in the target chain. This channel will have a block + // height that's close to the current tip of the main chain as we + // know it. We'll use this to start our QueryChannelRange dance with + // the remote node. + HighestChanID(chain chainhash.Hash) (*lnwire.ShortChannelID, error) + + // UpdatesInHorizon returns all known channel and node updates with an + // update timestamp between the start time and end time. We'll use this + // to catch up a remote node to the set of channel updates that they + // may have missed out on within the target chain. + UpdatesInHorizon(chain chainhash.Hash, + startTime time.Time, endTime time.Time) ([]lnwire.Message, error) + + // FilterKnownChanIDs takes a target chain, and a set of channel ID's, + // and returns a filtered set of chan ID's. This filtered set of chan + // ID's represents the ID's that we don't know of which were in the + // passed superSet. + FilterKnownChanIDs(chain chainhash.Hash, + superSet []lnwire.ShortChannelID) ([]lnwire.ShortChannelID, error) + + // FilterChannelRange returns the set of channels that we created + // between the start height and the end height. We'll use this to to a + // remote peer's QueryChannelRange message. + FilterChannelRange(chain chainhash.Hash, + startHeight, endHeight uint32) ([]lnwire.ShortChannelID, error) + + // FetchChanAnns returns a full set of channel announcements as well as + // their updates that match the set of specified short channel ID's. + // We'll use this to reply to a QueryShortChanIDs message sent by a + // remote peer. The response will contain a unique set of + // ChannelAnnouncements, the latest ChannelUpdate for each of the + // announcements, and a unique set of NodeAnnouncements. + FetchChanAnns(chain chainhash.Hash, + shortChanIDs []lnwire.ShortChannelID) ([]lnwire.Message, error) + + // FetchChanUpdates returns the latest channel update messages for the + // specified short channel ID. If no channel updates are known for the + // channel, then an empty slice will be returned. + FetchChanUpdates(chain chainhash.Hash, + shortChanID lnwire.ShortChannelID) ([]*lnwire.ChannelUpdate, error) +} + +// gossipSyncerCfg is a struct that packages all the information a gossipSyncer +// needs to carry out its duties. +type gossipSyncerCfg struct { + // chainHash is the chain that this syncer is responsible for. + chainHash chainhash.Hash + + // syncChanUpdates is a bool that indicates if we should request a + // continual channel update stream or not. + syncChanUpdates bool + + // channelSeries is the primary interface that we'll use to generate + // our queries and respond to the queries of the remote peer. + channelSeries ChannelGraphTimeSeries + + // encodingType is the current encoding type we're aware of. Requests + // with different encoding types will be rejected. + encodingType lnwire.ShortChanIDEncoding + + // sendToPeer is a function closure that should send the set of + // targeted messages to the peer we've been assigned to sync the graph + // state from. + sendToPeer func(...lnwire.Message) error +} + +// gossipSyncer is a struct that handles synchronizing the channel graph state +// with a remote peer. The gossipSyncer implements a state machine that will +// progressively ensure we're synchronized with the channel state of the remote +// node. Once both nodes have been synchronized, we'll use an update filter to +// filter out which messages should be sent to a remote peer based on their +// update horizon. If the update horizon isn't specified, then we won't send +// them any channel updates at all. +// +// TODO(roasbeef): modify to only sync from one peer at a time? +type gossipSyncer struct { + // remoteUpdateHorizon is the update horizon of the remote peer. We'll + // use this to properly filter out any messages. + remoteUpdateHorizon *lnwire.GossipTimestampRange + + // localUpdateHorizon is our local update horizon, we'll use this to + // determine if we've already sent out our update. + localUpdateHorizon *lnwire.GossipTimestampRange + + // state is the current state of the gossipSyncer. + // + // NOTE: This variable MUST be used atomically. + state uint32 + + // gossipMsgs is a channel that all messages from the target peer will + // be sent over. + gossipMsgs chan lnwire.Message + + // bufferedChanRangeReplies is used in the waitingQueryChanReply to + // buffer all the chunked response to our query. + bufferedChanRangeReplies []lnwire.ShortChannelID + + // newChansToQuery is used to pass the set of channels we should query + // for from the waitingQueryChanReply state to the queryNewChannels + // state. + newChansToQuery []lnwire.ShortChannelID + + // peerPub is the public key of the peer we're syncing with, serialized + // in compressed format. + peerPub [33]byte + + cfg gossipSyncerCfg + + sync.Mutex + + quit chan struct{} + wg sync.WaitGroup +} + +// newGossiperSyncer returns a new instance of the gossipSyncer populated using +// the passed config. +func newGossiperSyncer(cfg gossipSyncerCfg) *gossipSyncer { + return &gossipSyncer{ + cfg: cfg, + gossipMsgs: make(chan lnwire.Message, 100), + quit: make(chan struct{}), + } +} + +// Start starts the gossipSyncer and any goroutines that it needs to carry out +// its duties. +func (g *gossipSyncer) Start() error { + log.Debugf("Starting gossipSyncer(%x)", g.peerPub[:]) + + g.wg.Add(1) + go g.channelGraphSyncer() + + return nil +} + +// Stop signals the gossipSyncer for a graceful exit, then waits until it has +// exited. +func (g *gossipSyncer) Stop() error { + close(g.quit) + + g.wg.Wait() + + return nil +} + +// channelGraphSyncer is the main goroutine responsible for ensuring that we +// properly channel graph state with the remote peer, and also that we only +// send them messages which actually pass their defined update horizon. +func (g *gossipSyncer) channelGraphSyncer() { + defer g.wg.Done() + + // TODO(roasbeef): also add ability to force transition back to syncing + // chans + // * needed if we want to sync chan state very few blocks? + + for { + state := atomic.LoadUint32(&g.state) + log.Debugf("gossipSyncer(%x): state=%v", g.peerPub[:], + syncerState(state)) + + switch syncerState(state) { + // When we're in this state, we're trying to synchronize our + // view of the network with the remote peer. We'll kick off + // this sync by asking them for the set of channels they + // understand, as we'll as responding to any other queries by + // them. + case syncingChans: + // If we're in this state, then we'll send the remote + // peer our opening QueryChannelRange message. + queryRangeMsg, err := g.genChanRangeQuery() + if err != nil { + log.Errorf("unable to gen chan range "+ + "query: %v", err) + return + } + + err = g.cfg.sendToPeer(queryRangeMsg) + if err != nil { + log.Errorf("unable to send chan range "+ + "query: %v", err) + return + } + + // With the message sent successfully, we'll transition + // into the next state where we wait for their reply. + atomic.StoreUint32(&g.state, uint32(waitingQueryRangeReply)) + + // In this state, we've sent out our initial channel range + // query and are waiting for the final response from the remote + // peer before we perform a diff to see with channels they know + // of that we don't. + case waitingQueryRangeReply: + // We'll wait to either process a new message from the + // remote party, or exit due to the gossiper exiting, + // or us being signalled to do so. + select { + case msg := <-g.gossipMsgs: + // The remote peer is sending a response to our + // initial query, we'll collate this response, + // and see if it's the final one in the series. + // If so, we can then transition to querying + // for the new channels. + queryReply, ok := msg.(*lnwire.ReplyChannelRange) + if ok { + err := g.processChanRangeReply(queryReply) + if err != nil { + log.Errorf("unable to "+ + "process chan range "+ + "query: %v", err) + return + } + + continue + } + + // Otherwise, it's the remote peer performing a + // query, which we'll attempt to reply to. + err := g.replyPeerQueries(msg) + if err != nil { + log.Errorf("unable to reply to peer "+ + "query: %v", err) + } + + case <-g.quit: + return + } + + // We'll enter this state once we've discovered which channels + // the remote party knows of that we don't yet know of + // ourselves. + case queryNewChannels: + // First, we'll attempt to continue our channel + // synchronization by continuing to send off another + // query chunk. + done, err := g.synchronizeChanIDs() + if err != nil { + log.Errorf("unable to sync chan IDs: %v", err) + } + + // If this wasn't our last query, then we'll need to + // transition to our waiting state. + if !done { + atomic.StoreUint32(&g.state, uint32(waitingQueryChanReply)) + continue + } + + // If we're fully synchronized, then we can transition + // to our terminal state. + atomic.StoreUint32(&g.state, uint32(chansSynced)) + + // In this state, we've just sent off a new query for channels + // that we don't yet know of. We'll remain in this state until + // the remote party signals they've responded to our query in + // totality. + case waitingQueryChanReply: + // Once we've sent off our query, we'll wait for either + // an ending reply, or just another query from the + // remote peer. + select { + case msg := <-g.gossipMsgs: + // If this is the final reply to one of our + // queries, then we'll loop back into our query + // state to send of the remaining query chunks. + _, ok := msg.(*lnwire.ReplyShortChanIDsEnd) + if ok { + atomic.StoreUint32(&g.state, uint32(queryNewChannels)) + continue + } + + // Otherwise, it's the remote peer performing a + // query, which we'll attempt to deploy to. + err := g.replyPeerQueries(msg) + if err != nil { + log.Errorf("unable to reply to peer "+ + "query: %v", err) + } + + case <-g.quit: + return + } + + // This is our final terminal state where we'll only reply to + // any further queries by the remote peer. + case chansSynced: + // If we haven't yet sent out our update horizon, and + // we want to receive real-time channel updates, we'll + // do so now. + if g.localUpdateHorizon == nil && g.cfg.syncChanUpdates { + // TODO(roasbeef): query DB for most recent + // update? + + // We'll give an hours room in our update + // horizon to ensure we don't miss any newer + // items. + updateHorizon := time.Now().Add(-time.Hour * 1) + log.Infof("gossipSyncer(%x): applying "+ + "gossipFilter(start=%v)", g.peerPub[:], + updateHorizon) + + g.localUpdateHorizon = &lnwire.GossipTimestampRange{ + ChainHash: g.cfg.chainHash, + FirstTimestamp: uint32(updateHorizon.Unix()), + TimestampRange: math.MaxUint32, + } + err := g.cfg.sendToPeer(g.localUpdateHorizon) + if err != nil { + log.Errorf("unable to send update "+ + "horizon: %v", err) + } + } + + // With our horizon set, we'll simply reply to any new + // message and exit if needed. + select { + case msg := <-g.gossipMsgs: + err := g.replyPeerQueries(msg) + if err != nil { + log.Errorf("unable to reply to peer "+ + "query: %v", err) + } + + case <-g.quit: + return + } + } + } +} + +// synchronizeChanIDs is called by the channelGraphSyncer when we need to query +// the remote peer for its known set of channel IDs within a particular block +// range. This method will be called continually until the entire range has +// been queried for with a response received. We'll chunk our requests as +// required to ensure they fit into a single message. We may re-renter this +// state in the case that chunking is required. +func (g *gossipSyncer) synchronizeChanIDs() (bool, error) { + // Ensure that we're able to handle queries using the specified chan + // ID. + chunkSize, ok := encodingTypeToChunkSize[g.cfg.encodingType] + if !ok { + return false, fmt.Errorf("unknown encoding type: %v", + g.cfg.encodingType) + } + + // If we're in this state yet there are no more new channels to query + // for, then we'll transition to our final synced state and return true + // to signal that we're fully synchronized. + if len(g.newChansToQuery) == 0 { + log.Infof("gossipSyncer(%x): no more chans to query", + g.peerPub[:]) + return true, nil + } + + // Otherwise, we'll issue our next chunked query to receive replies + // for. + var queryChunk []lnwire.ShortChannelID + + // If the number of channels to query for is less than the chunk size, + // then we can issue a single query. + if int32(len(g.newChansToQuery)) < chunkSize { + queryChunk = g.newChansToQuery + g.newChansToQuery = nil + + } else { + // Otherwise, we'll need to only query for the next chunk. + // We'll slice into our query chunk, then slide down our main + // pointer down by the chunk size. + queryChunk = g.newChansToQuery[:chunkSize] + g.newChansToQuery = g.newChansToQuery[chunkSize:] + } + + log.Infof("gossipSyncer(%x): querying for %v new channels", + g.peerPub[:], len(queryChunk)) + + // With our chunk obtained, we'll send over our next query, then return + // false indicating that we're net yet fully synced. + err := g.cfg.sendToPeer(&lnwire.QueryShortChanIDs{ + ChainHash: g.cfg.chainHash, + EncodingType: lnwire.EncodingSortedPlain, + ShortChanIDs: queryChunk, + }) + + return false, err +} + +// processChanRangeReply is called each time the gossipSyncer receives a new +// reply to the initial range query to discover new channels that it didn't +// previously know of. +func (g *gossipSyncer) processChanRangeReply(msg *lnwire.ReplyChannelRange) error { + g.bufferedChanRangeReplies = append( + g.bufferedChanRangeReplies, msg.ShortChanIDs..., + ) + + log.Infof("gossipSyncer(%x): buffering chan range reply of size=%v", + g.peerPub[:], len(msg.ShortChanIDs)) + + // If this isn't the last response, then we can exit as we've already + // buffered the latest portion of the streaming reply. + if msg.Complete == 0 { + return nil + } + + log.Infof("gossipSyncer(%x): filtering through %v chans", g.peerPub[:], + len(g.bufferedChanRangeReplies)) + + // Otherwise, this is the final response, so we'll now check to see + // which channels they know of that we don't. + newChans, err := g.cfg.channelSeries.FilterKnownChanIDs( + g.cfg.chainHash, g.bufferedChanRangeReplies, + ) + if err != nil { + return fmt.Errorf("unable to filter chan ids: %v", err) + } + + // As we've received the entirety of the reply, we no longer need to + // hold on to the set of buffered replies, so we'll let that be garbage + // collected now. + g.bufferedChanRangeReplies = nil + + // If there aren't any channels that we don't know of, then we can + // switch straight to our terminal state. + if len(newChans) == 0 { + log.Infof("gossipSyncer(%x): remote peer has no new chans", + g.peerPub[:]) + + atomic.StoreUint32(&g.state, uint32(chansSynced)) + return nil + } + + // Otherwise, we'll set the set of channels that we need to query for + // the next state, and also transition our state. + g.newChansToQuery = newChans + atomic.StoreUint32(&g.state, uint32(queryNewChannels)) + + log.Infof("gossipSyncer(%x): starting query for %v new chans", + g.peerPub[:], len(newChans)) + + return nil +} + +// genChanRangeQuery generates the initial message we'll send to the remote +// party when we're kicking off the channel graph synchronization upon +// connection. +func (g *gossipSyncer) genChanRangeQuery() (*lnwire.QueryChannelRange, error) { + // First, we'll query our channel graph time series for its highest + // known channel ID. + newestChan, err := g.cfg.channelSeries.HighestChanID(g.cfg.chainHash) + if err != nil { + return nil, err + } + + // Once we have the chan ID of the newest, we'll obtain the block + // height of the channel, then subtract our default horizon to ensure + // we don't miss any channels. By default, we go back 1 day from the + // newest channel. + var startHeight uint32 + switch { + case newestChan.BlockHeight <= chanRangeQueryBuffer: + fallthrough + case newestChan.BlockHeight == 0: + startHeight = 0 + + default: + startHeight = uint32(newestChan.BlockHeight - chanRangeQueryBuffer) + } + + log.Infof("gossipSyncer(%x): requesting new chans from height=%v "+ + "and %v blocks after", g.peerPub[:], startHeight, + math.MaxUint32-startHeight) + + // Finally, we'll craft the channel range query, using our starting + // height, then asking for all known channels to the foreseeable end of + // the main chain. + return &lnwire.QueryChannelRange{ + ChainHash: g.cfg.chainHash, + FirstBlockHeight: startHeight, + NumBlocks: math.MaxUint32 - startHeight, + }, nil +} + +// replyPeerQueries is called in response to any query by the remote peer. +// We'll examine our state and send back our best response. +func (g *gossipSyncer) replyPeerQueries(msg lnwire.Message) error { + switch msg := msg.(type) { + + // In this state, we'll also handle any incoming channel range queries + // from the remote peer as they're trying to sync their state as well. + case *lnwire.QueryChannelRange: + return g.replyChanRangeQuery(msg) + + // If the remote peer skips straight to requesting new channels that + // they don't know of, then we'll ensure that we also handle this case. + case *lnwire.QueryShortChanIDs: + return g.replyShortChanIDs(msg) + + default: + return fmt.Errorf("unknown message: %T", msg) + } +} + +// replyChanRangeQuery will be dispatched in response to a channel range query +// by the remote node. We'll query the channel time series for channels that +// meet the channel range, then chunk our responses to the remote node. We also +// ensure that our final fragment carries the "complete" bit to indicate the +// end of our streaming response. +func (g *gossipSyncer) replyChanRangeQuery(query *lnwire.QueryChannelRange) error { + // Using the current set encoding type, we'll determine what our chunk + // size should be. If we can't locate the chunk size, then we'll return + // an error as we can't proceed. + chunkSize, ok := encodingTypeToChunkSize[g.cfg.encodingType] + if !ok { + return fmt.Errorf("unknown encoding type: %v", g.cfg.encodingType) + } + + log.Infof("gossipSyncer(%x): filtering chan range: start_height=%v, "+ + "num_blocks=%v", g.peerPub[:], query.FirstBlockHeight, + query.NumBlocks) + + // Next, we'll consult the time series to obtain the set of known + // channel ID's that match their query. + startBlock := query.FirstBlockHeight + channelRange, err := g.cfg.channelSeries.FilterChannelRange( + query.ChainHash, startBlock, startBlock+query.NumBlocks, + ) + if err != nil { + return err + } + + // TODO(roasbeef): means can't send max uint above? + // * or make internal 64 + + numChannels := int32(len(channelRange)) + numChansSent := int32(0) + for { + // We'll send our this response in a streaming manner, + // chunk-by-chunk. We do this as there's a transport message + // size limit which we'll need to adhere to. + var channelChunk []lnwire.ShortChannelID + + // We know this is the final chunk, if the difference between + // the total number of channels, and the number of channels + // we've sent is less-than-or-equal to the chunk size. + isFinalChunk := (numChannels - numChansSent) <= chunkSize + + // If this is indeed the last chunk, then we'll send the + // remainder of the channels. + if isFinalChunk { + channelChunk = channelRange[numChansSent:] + + log.Infof("gossipSyncer(%x): sending final chan "+ + "range chunk, size=%v", g.peerPub[:], len(channelChunk)) + + } else { + // Otherwise, we'll only send off a fragment exactly + // sized to the proper chunk size. + channelChunk = channelRange[numChansSent : numChansSent+chunkSize] + + log.Infof("gossipSyncer(%x): sending range chunk of "+ + "size=%v", g.peerPub[:], len(channelChunk)) + } + + // With our chunk assembled, we'll now send to the remote peer + // the current chunk. + replyChunk := lnwire.ReplyChannelRange{ + QueryChannelRange: *query, + Complete: 0, + EncodingType: g.cfg.encodingType, + ShortChanIDs: channelChunk, + } + if isFinalChunk { + replyChunk.Complete = 1 + } + if err := g.cfg.sendToPeer(&replyChunk); err != nil { + return err + } + + // If this was the final chunk, then we'll exit now as our + // response is now complete. + if isFinalChunk { + return nil + } + + numChansSent += int32(len(channelChunk)) + } +} + +// replyShortChanIDs will be dispatched in response to a query by the remote +// node for information concerning a set of short channel ID's. Our response +// will be sent in a streaming chunked manner to ensure that we remain below +// the current transport level message size. +func (g *gossipSyncer) replyShortChanIDs(query *lnwire.QueryShortChanIDs) error { + // Before responding, we'll check to ensure that the remote peer is + // querying for the same chain that we're on. If not, we'll send back a + // response with a complete value of zero to indicate we're on a + // different chain. + if g.cfg.chainHash != query.ChainHash { + log.Warnf("Remote peer requested QueryShortChanIDs for "+ + "chain=%v, we're on chain=%v", g.cfg.chainHash, + query.ChainHash) + + return g.cfg.sendToPeer(&lnwire.ReplyShortChanIDsEnd{ + ChainHash: query.ChainHash, + Complete: 0, + }) + } + + log.Infof("gossipSyncer(%x): fetching chan anns for %v chans", + g.peerPub[:], len(query.ShortChanIDs)) + + // Now that we know we're on the same chain, we'll query the channel + // time series for the set of messages that we know of which satisfies + // the requirement of being a chan ann, chan update, or a node ann + // related to the set of queried channels. + replyMsgs, err := g.cfg.channelSeries.FetchChanAnns( + query.ChainHash, query.ShortChanIDs, + ) + if err != nil { + return err + } + + // If we didn't find any messages related to those channel ID's, then + // we'll send over a reply marking the end of our response, and exit + // early. + if len(replyMsgs) == 0 { + return g.cfg.sendToPeer(&lnwire.ReplyShortChanIDsEnd{ + ChainHash: query.ChainHash, + Complete: 1, + }) + } + + // Otherwise, we'll send over our set of messages responding to the + // query, with the ending message appended to it. + replyMsgs = append(replyMsgs, &lnwire.ReplyShortChanIDsEnd{ + ChainHash: query.ChainHash, + Complete: 1, + }) + return g.cfg.sendToPeer(replyMsgs...) +} + +// ApplyGossipFilter applies a gossiper filter sent by the remote node to the +// state machine. Once applied, we'll ensure that we don't forward any messages +// to the peer that aren't within the time range of the filter. +func (g *gossipSyncer) ApplyGossipFilter(filter *lnwire.GossipTimestampRange) error { + g.Lock() + + g.remoteUpdateHorizon = filter + + startTime := time.Unix(int64(g.remoteUpdateHorizon.FirstTimestamp), 0) + endTime := startTime.Add( + time.Duration(g.remoteUpdateHorizon.TimestampRange) * time.Second, + ) + + g.Unlock() + + // Now that the remote peer has applied their filter, we'll query the + // database for all the messages that are beyond this filter. + newUpdatestoSend, err := g.cfg.channelSeries.UpdatesInHorizon( + g.cfg.chainHash, startTime, endTime, + ) + if err != nil { + return err + } + + log.Infof("gossipSyncer(%x): applying new update horizon: start=%v, "+ + "end=%v, backlog_size=%v", g.peerPub[:], startTime, endTime, + len(newUpdatestoSend)) + + // If we don't have any to send, then we can return early. + if len(newUpdatestoSend) == 0 { + return nil + } + + // We'll conclude by launching a goroutine to send out any updates. + g.wg.Add(1) + go func() { + defer g.wg.Done() + + if err := g.cfg.sendToPeer(newUpdatestoSend...); err != nil { + log.Errorf("unable to send messages for peer catch "+ + "up: %v", err) + } + }() + + return nil +} + +// FilterGossipMsgs takes a set of gossip messages, and only send it to a peer +// iff the message is within the bounds of their set gossip filter. If the peer +// doesn't have a gossip filter set, then no messages will be forwarded. +func (g *gossipSyncer) FilterGossipMsgs(msgs ...msgWithSenders) { + // If the peer doesn't have an update horizon set, then we won't send + // it any new update messages. + if g.remoteUpdateHorizon == nil { + return + } + + // TODO(roasbeef): need to ensure that peer still online...send msg to + // gossiper on peer termination to signal peer disconnect? + + var err error + + // Before we filter out the messages, we'll construct an index over the + // set of channel announcements and channel updates. This will allow us + // to quickly check if we should forward a chan ann, based on the known + // channel updates for a channel. + chanUpdateIndex := make(map[lnwire.ShortChannelID][]*lnwire.ChannelUpdate) + for _, msg := range msgs { + chanUpdate, ok := msg.msg.(*lnwire.ChannelUpdate) + if !ok { + continue + } + + chanUpdateIndex[chanUpdate.ShortChannelID] = append( + chanUpdateIndex[chanUpdate.ShortChannelID], chanUpdate, + ) + } + + // We'll construct a helper function that we'll us below to determine + // if a given messages passes the gossip msg filter. + g.Lock() + startTime := time.Unix(int64(g.remoteUpdateHorizon.FirstTimestamp), 0) + endTime := startTime.Add( + time.Duration(g.remoteUpdateHorizon.TimestampRange) * time.Second, + ) + g.Unlock() + + passesFilter := func(timeStamp uint32) bool { + t := time.Unix(int64(timeStamp), 0) + return t.After(startTime) && t.Before(endTime) + } + + msgsToSend := make([]lnwire.Message, 0, len(msgs)) + for _, msg := range msgs { + // If the target peer is the peer that sent us this message, + // then we'll exit early as we don't need to filter this + // message. + if _, ok := msg.senders[g.peerPub]; ok { + continue + } + + switch msg := msg.msg.(type) { + + // For each channel announcement message, we'll only send this + // message if the channel updates for the channel are between + // our time range. + case *lnwire.ChannelAnnouncement: + // First, we'll check if the channel updates are in + // this message batch. + chanUpdates, ok := chanUpdateIndex[msg.ShortChannelID] + if !ok { + // If not, we'll attempt to query the database + // to see if we know of the updates. + chanUpdates, err = g.cfg.channelSeries.FetchChanUpdates( + g.cfg.chainHash, msg.ShortChannelID, + ) + if err != nil { + log.Warnf("no channel updates found for "+ + "short_chan_id=%v", + msg.ShortChannelID) + continue + } + } + + for _, chanUpdate := range chanUpdates { + if passesFilter(chanUpdate.Timestamp) { + msgsToSend = append(msgsToSend, msg) + break + } + } + + if len(chanUpdates) == 0 { + msgsToSend = append(msgsToSend, msg) + } + + // For each channel update, we'll only send if it the timestamp + // is between our time range. + case *lnwire.ChannelUpdate: + if passesFilter(msg.Timestamp) { + msgsToSend = append(msgsToSend, msg) + } + + // Similarly, we only send node announcements if the update + // timestamp ifs between our set gossip filter time range. + case *lnwire.NodeAnnouncement: + if passesFilter(msg.Timestamp) { + msgsToSend = append(msgsToSend, msg) + } + } + } + + log.Tracef("gossipSyncer(%x): filtered gossip msgs: set=%v, sent=%v", + g.peerPub[:], len(msgs), len(msgsToSend)) + + if len(msgsToSend) == 0 { + return + } + + g.cfg.sendToPeer(msgsToSend...) +} + +// ProcessQueryMsg is used by outside callers to pass new channel time series +// queries to the internal processing goroutine. +func (g *gossipSyncer) ProcessQueryMsg(msg lnwire.Message) { + select { + case g.gossipMsgs <- msg: + return + case <-g.quit: + return + } +} + +// SyncerState returns the current syncerState of the target gossipSyncer. +func (g *gossipSyncer) SyncState() syncerState { + return syncerState(atomic.LoadUint32(&g.state)) +} diff --git a/discovery/syncer_test.go b/discovery/syncer_test.go new file mode 100644 index 00000000..8aa6415c --- /dev/null +++ b/discovery/syncer_test.go @@ -0,0 +1,1577 @@ +package discovery + +import ( + "fmt" + "math" + "reflect" + "testing" + "time" + + "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/roasbeef/btcd/chaincfg" + "github.com/roasbeef/btcd/chaincfg/chainhash" +) + +type horizonQuery struct { + chain chainhash.Hash + start time.Time + end time.Time +} +type filterRangeReq struct { + startHeight, endHeight uint32 +} + +type mockChannelGraphTimeSeries struct { + highestID lnwire.ShortChannelID + + horizonReq chan horizonQuery + horizonResp chan []lnwire.Message + + filterReq chan []lnwire.ShortChannelID + filterResp chan []lnwire.ShortChannelID + + filterRangeReqs chan filterRangeReq + filterRangeResp chan []lnwire.ShortChannelID + + annReq chan []lnwire.ShortChannelID + annResp chan []lnwire.Message + + updateReq chan lnwire.ShortChannelID + updateResp chan []*lnwire.ChannelUpdate +} + +func newMockChannelGraphTimeSeries(hID lnwire.ShortChannelID) *mockChannelGraphTimeSeries { + return &mockChannelGraphTimeSeries{ + highestID: hID, + + horizonReq: make(chan horizonQuery, 1), + horizonResp: make(chan []lnwire.Message, 1), + + filterReq: make(chan []lnwire.ShortChannelID, 1), + filterResp: make(chan []lnwire.ShortChannelID, 1), + + filterRangeReqs: make(chan filterRangeReq, 1), + filterRangeResp: make(chan []lnwire.ShortChannelID, 1), + + annReq: make(chan []lnwire.ShortChannelID, 1), + annResp: make(chan []lnwire.Message, 1), + + updateReq: make(chan lnwire.ShortChannelID, 1), + updateResp: make(chan []*lnwire.ChannelUpdate, 1), + } +} + +func (m *mockChannelGraphTimeSeries) HighestChanID(chain chainhash.Hash) (*lnwire.ShortChannelID, error) { + return &m.highestID, nil +} +func (m *mockChannelGraphTimeSeries) UpdatesInHorizon(chain chainhash.Hash, + startTime time.Time, endTime time.Time) ([]lnwire.Message, error) { + + m.horizonReq <- horizonQuery{ + chain, startTime, endTime, + } + + return <-m.horizonResp, nil +} +func (m *mockChannelGraphTimeSeries) FilterKnownChanIDs(chain chainhash.Hash, + superSet []lnwire.ShortChannelID) ([]lnwire.ShortChannelID, error) { + + m.filterReq <- superSet + + return <-m.filterResp, nil +} +func (m *mockChannelGraphTimeSeries) FilterChannelRange(chain chainhash.Hash, + startHeight, endHeight uint32) ([]lnwire.ShortChannelID, error) { + + m.filterRangeReqs <- filterRangeReq{startHeight, endHeight} + + return <-m.filterRangeResp, nil +} +func (m *mockChannelGraphTimeSeries) FetchChanAnns(chain chainhash.Hash, + shortChanIDs []lnwire.ShortChannelID) ([]lnwire.Message, error) { + + m.annReq <- shortChanIDs + + return <-m.annResp, nil +} +func (m *mockChannelGraphTimeSeries) FetchChanUpdates(chain chainhash.Hash, + shortChanID lnwire.ShortChannelID) ([]*lnwire.ChannelUpdate, error) { + + m.updateReq <- shortChanID + + return <-m.updateResp, nil +} + +var _ ChannelGraphTimeSeries = (*mockChannelGraphTimeSeries)(nil) + +func newTestSyncer(hID lnwire.ShortChannelID) (chan []lnwire.Message, *gossipSyncer, *mockChannelGraphTimeSeries) { + + msgChan := make(chan []lnwire.Message, 20) + cfg := gossipSyncerCfg{ + syncChanUpdates: true, + channelSeries: newMockChannelGraphTimeSeries(hID), + encodingType: lnwire.EncodingSortedPlain, + sendToPeer: func(msgs ...lnwire.Message) error { + msgChan <- msgs + return nil + }, + } + syncer := newGossiperSyncer(cfg) + + return msgChan, syncer, cfg.channelSeries.(*mockChannelGraphTimeSeries) +} + +// TestGossipSyncerFilterGossipMsgsNoHorizon tests that if the remote peer +// doesn't have a horizon set, then we won't send any incoming messages to it. +func TestGossipSyncerFilterGossipMsgsNoHorizon(t *testing.T) { + t.Parallel() + + // First, we'll create a gossipSyncer instance with a canned sendToPeer + // message to allow us to intercept their potential sends. + msgChan, syncer, _ := newTestSyncer( + lnwire.NewShortChanIDFromInt(10), + ) + + // With the syncer created, we'll create a set of messages to filter + // through the gossiper to the target peer. + msgs := []msgWithSenders{ + { + msg: &lnwire.NodeAnnouncement{Timestamp: uint32(time.Now().Unix())}, + }, + { + msg: &lnwire.NodeAnnouncement{Timestamp: uint32(time.Now().Unix())}, + }, + } + + // We'll then attempt to filter the set of messages through the target + // peer. + syncer.FilterGossipMsgs(msgs...) + + // As the remote peer doesn't yet have a gossip timestamp set, we + // shouldn't receive any outbound messages. + select { + case msg := <-msgChan: + t.Fatalf("received message but shouldn't have: %v", + spew.Sdump(msg)) + + case <-time.After(time.Millisecond * 10): + } +} + +func unixStamp(a int64) uint32 { + t := time.Unix(a, 0) + return uint32(t.Unix()) +} + +// TestGossipSyncerFilterGossipMsgsAll tests that we're able to properly filter +// out a set of incoming messages based on the set remote update horizon for a +// peer. We tests all messages type, and all time straddling. We'll also send a +// channel ann that already has a channel update on disk. +func TestGossipSyncerFilterGossipMsgsAllInMemory(t *testing.T) { + t.Parallel() + + // First, we'll create a gossipSyncer instance with a canned sendToPeer + // message to allow us to intercept their potential sends. + msgChan, syncer, chanSeries := newTestSyncer( + lnwire.NewShortChanIDFromInt(10), + ) + + // We'll create then apply a remote horizon for the target peer with a + // set of manually selected timestamps. + remoteHorizon := &lnwire.GossipTimestampRange{ + FirstTimestamp: unixStamp(25000), + TimestampRange: uint32(1000), + } + syncer.remoteUpdateHorizon = remoteHorizon + + // With the syncer created, we'll create a set of messages to filter + // through the gossiper to the target peer. Our message will consist of + // one node announcement above the horizon, one below. Additionally, + // we'll include a chan ann with an update below the horizon, one + // with an update timestmap above the horizon, and one without any + // channel updates at all. + msgs := []msgWithSenders{ + { + // Node ann above horizon. + msg: &lnwire.NodeAnnouncement{Timestamp: unixStamp(25001)}, + }, + { + // Node ann below horizon. + msg: &lnwire.NodeAnnouncement{Timestamp: unixStamp(5)}, + }, + { + // Node ann above horizon. + msg: &lnwire.NodeAnnouncement{Timestamp: unixStamp(999999)}, + }, + { + // Ann tuple below horizon. + msg: &lnwire.ChannelAnnouncement{ + ShortChannelID: lnwire.NewShortChanIDFromInt(10), + }, + }, + { + msg: &lnwire.ChannelUpdate{ + ShortChannelID: lnwire.NewShortChanIDFromInt(10), + Timestamp: unixStamp(5), + }, + }, + { + // Ann tuple above horizon. + msg: &lnwire.ChannelAnnouncement{ + ShortChannelID: lnwire.NewShortChanIDFromInt(15), + }, + }, + { + msg: &lnwire.ChannelUpdate{ + ShortChannelID: lnwire.NewShortChanIDFromInt(15), + Timestamp: unixStamp(25002), + }, + }, + { + // Ann tuple beyond horizon. + msg: &lnwire.ChannelAnnouncement{ + ShortChannelID: lnwire.NewShortChanIDFromInt(20), + }, + }, + { + msg: &lnwire.ChannelUpdate{ + ShortChannelID: lnwire.NewShortChanIDFromInt(20), + Timestamp: unixStamp(999999), + }, + }, + { + // Ann w/o an update at all, the update in the DB will + // be below the horizon. + msg: &lnwire.ChannelAnnouncement{ + ShortChannelID: lnwire.NewShortChanIDFromInt(25), + }, + }, + } + + // Before we send off the query, we'll ensure we send the missing + // channel update for that final ann. It will be below the horizon, so + // shouldn't be sent anyway. + go func() { + select { + case <-time.After(time.Second * 15): + t.Fatalf("no query recvd") + + case query := <-chanSeries.updateReq: + + // It should be asking for the chan updates of short + // chan ID 25. + expectedID := lnwire.NewShortChanIDFromInt(25) + if expectedID != query { + t.Fatalf("wrong query id: expected %v, got %v", + expectedID, query) + } + + // If so, then we'll send back the missing update. + chanSeries.updateResp <- []*lnwire.ChannelUpdate{ + { + ShortChannelID: lnwire.NewShortChanIDFromInt(25), + Timestamp: unixStamp(5), + }, + } + } + }() + + // We'll then instruct the gossiper to filter this set of messages. + syncer.FilterGossipMsgs(msgs...) + + // Out of all the messages we sent in, we should only get 2 of them + // back. + select { + case <-time.After(time.Second * 15): + t.Fatalf("no msgs received") + + case msgs := <-msgChan: + if len(msgs) != 3 { + t.Fatalf("expected 3 messages instead got %v "+ + "messages: %v", len(msgs), spew.Sdump(msgs)) + } + } +} + +// TestGossipSyncerApplyGossipFilter tests that once a gossip filter is applied +// for the remote peer, then we send the peer all known messages which are +// within their desired time horizon. +func TestGossipSyncerApplyGossipFilter(t *testing.T) { + t.Parallel() + + // First, we'll create a gossipSyncer instance with a canned sendToPeer + // message to allow us to intercept their potential sends. + msgChan, syncer, chanSeries := newTestSyncer( + lnwire.NewShortChanIDFromInt(10), + ) + + // We'll apply this gossip horizon for the remote peer. + remoteHorizon := &lnwire.GossipTimestampRange{ + FirstTimestamp: unixStamp(25000), + TimestampRange: uint32(1000), + } + + // Before we apply the horizon, we'll dispatch a response to the query + // that the syncer will issue. + go func() { + select { + case <-time.After(time.Second * 15): + t.Fatalf("no query recvd") + + case query := <-chanSeries.horizonReq: + // The syncer should have translated the time range + // into the proper star time. + if remoteHorizon.FirstTimestamp != uint32(query.start.Unix()) { + t.Fatalf("wrong query stamp: expected %v, got %v", + remoteHorizon.FirstTimestamp, query.start) + } + + // For this first response, we'll send back an empty + // set of messages. As result, we shouldn't send any + // messages. + chanSeries.horizonResp <- []lnwire.Message{} + } + }() + + // We'll now attempt to apply the gossip filter for the remote peer. + err := syncer.ApplyGossipFilter(remoteHorizon) + if err != nil { + t.Fatalf("unable to apply filter: %v", err) + } + + // There should be no messages in the message queue as we didn't send + // the syncer and messages within the horizon. + select { + case msgs := <-msgChan: + t.Fatalf("expected no msgs, instead got %v", spew.Sdump(msgs)) + default: + } + + // If we repeat the process, but give the syncer a set of valid + // messages, then these should be sent to the remote peer. + go func() { + select { + case <-time.After(time.Second * 15): + t.Fatalf("no query recvd") + + case query := <-chanSeries.horizonReq: + // The syncer should have translated the time range + // into the proper star time. + if remoteHorizon.FirstTimestamp != uint32(query.start.Unix()) { + t.Fatalf("wrong query stamp: expected %v, got %v", + remoteHorizon.FirstTimestamp, query.start) + } + + // For this first response, we'll send back a proper + // set of messages that should be echoed back. + chanSeries.horizonResp <- []lnwire.Message{ + &lnwire.ChannelUpdate{ + ShortChannelID: lnwire.NewShortChanIDFromInt(25), + Timestamp: unixStamp(5), + }, + } + } + }() + err = syncer.ApplyGossipFilter(remoteHorizon) + if err != nil { + t.Fatalf("unable to apply filter: %v", err) + } + + // We should get back the exact same message. + select { + case <-time.After(time.Second * 15): + t.Fatalf("no msgs received") + + case msgs := <-msgChan: + if len(msgs) != 1 { + t.Fatalf("wrong messages: expected %v, got %v", + 1, len(msgs)) + } + } +} + +// TestGossipSyncerReplyShortChanIDsWrongChainHash tests that if we get a chan +// ID query for the wrong chain, then we send back only a short ID end with +// complete=0. +func TestGossipSyncerReplyShortChanIDsWrongChainHash(t *testing.T) { + t.Parallel() + + // First, we'll create a gossipSyncer instance with a canned sendToPeer + // message to allow us to intercept their potential sends. + msgChan, syncer, _ := newTestSyncer( + lnwire.NewShortChanIDFromInt(10), + ) + + // We'll now ask the syncer to reply to a chan ID query, but for a + // chain that it isn't aware of. + err := syncer.replyShortChanIDs(&lnwire.QueryShortChanIDs{ + ChainHash: *chaincfg.SimNetParams.GenesisHash, + }) + if err != nil { + t.Fatalf("unable to process short chan ID's: %v", err) + } + + select { + case <-time.After(time.Second * 15): + t.Fatalf("no msgs received") + case msgs := <-msgChan: + + // We should get back exactly one message, that's a + // ReplyShortChanIDsEnd with a matching chain hash, and a + // complete value of zero. + if len(msgs) != 1 { + t.Fatalf("wrong messages: expected %v, got %v", + 1, len(msgs)) + } + + msg, ok := msgs[0].(*lnwire.ReplyShortChanIDsEnd) + if !ok { + t.Fatalf("expected lnwire.ReplyShortChanIDsEnd "+ + "instead got %T", msg) + } + + if msg.ChainHash != *chaincfg.SimNetParams.GenesisHash { + t.Fatalf("wrong chain hash: expected %v, got %v", + msg.ChainHash, chaincfg.SimNetParams.GenesisHash) + } + if msg.Complete != 0 { + t.Fatalf("complete set incorrectly") + } + } +} + +// TestGossipSyncerReplyShortChanIDs tests that in the case of a known chain +// hash for a QueryShortChanIDs, we'll return the set of matching +// announcements, as well as an ending ReplyShortChanIDsEnd message. +func TestGossipSyncerReplyShortChanIDs(t *testing.T) { + t.Parallel() + + // First, we'll create a gossipSyncer instance with a canned sendToPeer + // message to allow us to intercept their potential sends. + msgChan, syncer, chanSeries := newTestSyncer( + lnwire.NewShortChanIDFromInt(10), + ) + + queryChanIDs := []lnwire.ShortChannelID{ + lnwire.NewShortChanIDFromInt(1), + lnwire.NewShortChanIDFromInt(2), + lnwire.NewShortChanIDFromInt(3), + } + + queryReply := []lnwire.Message{ + &lnwire.ChannelAnnouncement{ + ShortChannelID: lnwire.NewShortChanIDFromInt(20), + }, + &lnwire.ChannelUpdate{ + ShortChannelID: lnwire.NewShortChanIDFromInt(20), + Timestamp: unixStamp(999999), + }, + &lnwire.NodeAnnouncement{Timestamp: unixStamp(25001)}, + } + + // We'll then craft a reply to the upcoming query for all the matching + // channel announcements for a particular set of short channel ID's. + go func() { + select { + case <-time.After(time.Second * 15): + t.Fatalf("no query recvd") + + case chanIDs := <-chanSeries.annReq: + // The set of chan ID's should match exactly. + if !reflect.DeepEqual(chanIDs, queryChanIDs) { + t.Fatalf("wrong chan IDs: expected %v, got %v", + queryChanIDs, chanIDs) + } + + // If they do, then we'll send back a response with + // some canned messages. + chanSeries.annResp <- queryReply + } + }() + + // With our set up above complete, we'll now attempt to obtain a reply + // from the channel syncer for our target chan ID query. + err := syncer.replyShortChanIDs(&lnwire.QueryShortChanIDs{ + ShortChanIDs: queryChanIDs, + }) + if err != nil { + t.Fatalf("unable to query for chan IDs: %v", err) + } + + select { + case <-time.After(time.Second * 15): + t.Fatalf("no msgs received") + + // We should get back exactly 4 messages. The first 3 are the same + // messages we sent above, and the query end message. + case msgs := <-msgChan: + if len(msgs) != 4 { + t.Fatalf("wrong messages: expected %v, got %v", + 4, len(msgs)) + } + + if !reflect.DeepEqual(queryReply, msgs[:3]) { + t.Fatalf("wrong set of messages: expected %v, got %v", + spew.Sdump(queryReply), spew.Sdump(msgs[:3])) + } + + finalMsg, ok := msgs[3].(*lnwire.ReplyShortChanIDsEnd) + if !ok { + t.Fatalf("expected lnwire.ReplyShortChanIDsEnd "+ + "instead got %T", msgs[3]) + } + if finalMsg.Complete != 1 { + t.Fatalf("complete wasn't set") + } + } +} + +// TestGossipSyncerReplyChanRangeQueryUnknownEncodingType tests that if we +// receive a QueryChannelRange message with an unknown encoding type, then we +// return an error. +func TestGossipSyncerReplyChanRangeQueryUnknownEncodingType(t *testing.T) { + t.Parallel() + + // First, we'll create a gossipSyncer instance with a canned sendToPeer + // message to allow us to intercept their potential sends. + _, syncer, _ := newTestSyncer( + lnwire.NewShortChanIDFromInt(10), + ) + + // If we modify the syncer to expect an encoding type that is currently + // unknown, then it should fail to process the message and return an + // error. + syncer.cfg.encodingType = 99 + err := syncer.replyChanRangeQuery(&lnwire.QueryChannelRange{}) + if err == nil { + t.Fatalf("expected message fail") + } +} + +// TestGossipSyncerReplyChanRangeQuery tests that if we receive a +// QueryChannelRange message, then we'll properly send back a chunked reply to +// the remote peer. +func TestGossipSyncerReplyChanRangeQuery(t *testing.T) { + t.Parallel() + + // First, we'll modify the main map to provide e a smaller chunk size + // so we can easily test all the edge cases. + encodingTypeToChunkSize[lnwire.EncodingSortedPlain] = 2 + + // We'll now create our test gossip syncer that will shortly respond to + // our canned query. + msgChan, syncer, chanSeries := newTestSyncer( + lnwire.NewShortChanIDFromInt(10), + ) + + // Next, we'll craft a query to ask for all the new chan ID's after + // block 100. + query := &lnwire.QueryChannelRange{ + FirstBlockHeight: 100, + NumBlocks: 50, + } + + // We'll then launch a goroutine to reply to the query with a set of 5 + // responses. This will ensure we get two full chunks, and one partial + // chunk. + resp := []lnwire.ShortChannelID{ + lnwire.NewShortChanIDFromInt(1), + lnwire.NewShortChanIDFromInt(2), + lnwire.NewShortChanIDFromInt(3), + lnwire.NewShortChanIDFromInt(4), + lnwire.NewShortChanIDFromInt(5), + } + go func() { + select { + case <-time.After(time.Second * 15): + t.Fatalf("no query recvd") + + case filterReq := <-chanSeries.filterRangeReqs: + // We should be querying for block 100 to 150. + if filterReq.startHeight != 100 && filterReq.endHeight != 150 { + t.Fatalf("wrong height range: %v", spew.Sdump(filterReq)) + } + + // If the proper request was sent, then we'll respond + // with our set of short channel ID's. + chanSeries.filterRangeResp <- resp + } + }() + + // With our goroutine active, we'll now issue the query. + if err := syncer.replyChanRangeQuery(query); err != nil { + t.Fatalf("unable to issue query: %v", err) + } + + // At this point, we'll now wait for the syncer to send the chunked + // reply. We should get three sets of messages as two of them should be + // full, while the other is the final fragment. + const numExpectedChunks = 3 + respMsgs := make([]lnwire.ShortChannelID, 0, 5) + for i := 0; i < 3; i++ { + select { + case <-time.After(time.Second * 15): + t.Fatalf("no msgs received") + + case msg := <-msgChan: + resp := msg[0] + rangeResp, ok := resp.(*lnwire.ReplyChannelRange) + if !ok { + t.Fatalf("expected ReplyChannelRange instead got %T", msg) + } + + // If this is not the last chunk, then Complete should + // be set to zero. Otherwise, it should be one. + switch { + case i < 2 && rangeResp.Complete != 0: + t.Fatalf("non-final chunk should have "+ + "Complete=0: %v", spew.Sdump(rangeResp)) + + case i == 2 && rangeResp.Complete != 1: + t.Fatalf("final chunk should have "+ + "Complete=1: %v", spew.Sdump(rangeResp)) + } + + respMsgs = append(respMsgs, rangeResp.ShortChanIDs...) + } + } + + // We should get back exactly 5 short chan ID's, and they should match + // exactly the ID's we sent as a reply. + if len(respMsgs) != len(resp) { + t.Fatalf("expected %v chan ID's, instead got %v", + len(resp), spew.Sdump(respMsgs)) + } + if !reflect.DeepEqual(resp, respMsgs) { + t.Fatalf("mismatched response: expected %v, got %v", + spew.Sdump(resp), spew.Sdump(respMsgs)) + } +} + +// TestGossipSyncerReplyChanRangeQueryNoNewChans tests that if we issue a reply +// for a channel range query, and we don't have any new channels, then we send +// back a single response that signals completion. +func TestGossipSyncerReplyChanRangeQueryNoNewChans(t *testing.T) { + t.Parallel() + + // We'll now create our test gossip syncer that will shortly respond to + // our canned query. + msgChan, syncer, chanSeries := newTestSyncer( + lnwire.NewShortChanIDFromInt(10), + ) + + // Next, we'll craft a query to ask for all the new chan ID's after + // block 100. + query := &lnwire.QueryChannelRange{ + FirstBlockHeight: 100, + NumBlocks: 50, + } + + // We'll then launch a goroutine to reply to the query no new channels. + resp := []lnwire.ShortChannelID{} + go func() { + select { + case <-time.After(time.Second * 15): + t.Fatalf("no query recvd") + + case filterReq := <-chanSeries.filterRangeReqs: + // We should be querying for block 100 to 150. + if filterReq.startHeight != 100 && filterReq.endHeight != 150 { + t.Fatalf("wrong height range: %v", + spew.Sdump(filterReq)) + } + + // If the proper request was sent, then we'll respond + // with our blank set of short chan ID's. + chanSeries.filterRangeResp <- resp + } + }() + + // With our goroutine active, we'll now issue the query. + if err := syncer.replyChanRangeQuery(query); err != nil { + t.Fatalf("unable to issue query: %v", err) + } + + // We should get back exactly one message, and the message should + // indicate that this is the final in the series. + select { + case <-time.After(time.Second * 15): + t.Fatalf("no msgs received") + + case msg := <-msgChan: + resp := msg[0] + rangeResp, ok := resp.(*lnwire.ReplyChannelRange) + if !ok { + t.Fatalf("expected ReplyChannelRange instead got %T", msg) + } + + if len(rangeResp.ShortChanIDs) != 0 { + t.Fatalf("expected no chan ID's, instead "+ + "got: %v", spew.Sdump(rangeResp.ShortChanIDs)) + } + if rangeResp.Complete != 1 { + t.Fatalf("complete wasn't set") + } + } +} + +// TestGossipSyncerGenChanRangeQuery tests that given the current best known +// channel ID, we properly generate an correct initial channel range response. +func TestGossipSyncerGenChanRangeQuery(t *testing.T) { + t.Parallel() + + // First, we'll create a gossipSyncer instance with a canned sendToPeer + // message to allow us to intercept their potential sends. + const startingHeight = 200 + _, syncer, _ := newTestSyncer( + lnwire.ShortChannelID{ + BlockHeight: startingHeight, + }, + ) + + // If we now ask the syncer to generate an initial range query, it + // should return a start height that's back chanRangeQueryBuffer + // blocks. + rangeQuery, err := syncer.genChanRangeQuery() + if err != nil { + t.Fatalf("unable to resp: %v", err) + } + + firstHeight := uint32(startingHeight - chanRangeQueryBuffer) + if rangeQuery.FirstBlockHeight != firstHeight { + t.Fatalf("incorrect chan range query: expected %v, %v", + rangeQuery.FirstBlockHeight, + startingHeight-chanRangeQueryBuffer) + } + if rangeQuery.NumBlocks != math.MaxUint32-firstHeight { + t.Fatalf("wrong num blocks: expected %v, got %v", + rangeQuery.NumBlocks, math.MaxUint32-firstHeight) + } +} + +// TestGossipSyncerProcessChanRangeReply tests that we'll properly buffer +// replied channel replies until we have the complete version. If no new +// channels were discovered, then we should go directly to the chanSsSynced +// state. Otherwise, we should go to the queryNewChannels states. +func TestGossipSyncerProcessChanRangeReply(t *testing.T) { + t.Parallel() + + // First, we'll create a gossipSyncer instance with a canned sendToPeer + // message to allow us to intercept their potential sends. + _, syncer, chanSeries := newTestSyncer( + lnwire.NewShortChanIDFromInt(10), + ) + + startingState := syncer.state + + replies := []*lnwire.ReplyChannelRange{ + { + ShortChanIDs: []lnwire.ShortChannelID{ + lnwire.NewShortChanIDFromInt(10), + }, + }, + { + ShortChanIDs: []lnwire.ShortChannelID{ + lnwire.NewShortChanIDFromInt(11), + }, + }, + { + Complete: 1, + ShortChanIDs: []lnwire.ShortChannelID{ + lnwire.NewShortChanIDFromInt(12), + }, + }, + } + + // We'll begin by sending the syncer a set of non-complete channel + // range replies. + if err := syncer.processChanRangeReply(replies[0]); err != nil { + t.Fatalf("unable to process reply: %v", err) + } + if err := syncer.processChanRangeReply(replies[1]); err != nil { + t.Fatalf("unable to process reply: %v", err) + } + + // At this point, we should still be in our starting state as the query + // hasn't finished. + if syncer.state != startingState { + t.Fatalf("state should not have transitioned") + } + + expectedReq := []lnwire.ShortChannelID{ + lnwire.NewShortChanIDFromInt(10), + lnwire.NewShortChanIDFromInt(11), + lnwire.NewShortChanIDFromInt(12), + } + + // As we're about to send the final response, we'll launch a goroutine + // to respond back with a filtered set of chan ID's. + go func() { + select { + case <-time.After(time.Second * 15): + t.Fatalf("no query recvd") + + case req := <-chanSeries.filterReq: + // We should get a request for the entire range of short + // chan ID's. + if !reflect.DeepEqual(expectedReq, req) { + fmt.Printf("wrong request: expected %v, got %v\n", + expectedReq, req) + + t.Fatalf("wrong request: expected %v, got %v", + expectedReq, req) + } + + // We'll send back only the last two to simulate filtering. + chanSeries.filterResp <- expectedReq[1:] + } + }() + + // If we send the final message, then we should transition to + // queryNewChannels as we've sent a non-empty set of new channels. + if err := syncer.processChanRangeReply(replies[2]); err != nil { + t.Fatalf("unable to process reply: %v", err) + } + + if syncer.SyncState() != queryNewChannels { + t.Fatalf("wrong state: expected %v instead got %v", + queryNewChannels, syncer.state) + } + if !reflect.DeepEqual(syncer.newChansToQuery, expectedReq[1:]) { + t.Fatalf("wrong set of chans to query: expected %v, got %v", + syncer.newChansToQuery, expectedReq[1:]) + } + + // We'll repeat our final reply again, but this time we won't send any + // new channels. As a result, we should transition over to the + // chansSynced state. + go func() { + select { + case <-time.After(time.Second * 15): + t.Fatalf("no query recvd") + + case req := <-chanSeries.filterReq: + // We should get a request for the entire range of short + // chan ID's. + if !reflect.DeepEqual(expectedReq[2], req[0]) { + t.Fatalf("wrong request: expected %v, got %v", + expectedReq[2], req[0]) + } + + // We'll send back only the last two to simulate filtering. + chanSeries.filterResp <- []lnwire.ShortChannelID{} + } + }() + if err := syncer.processChanRangeReply(replies[2]); err != nil { + t.Fatalf("unable to process reply: %v", err) + } + + if syncer.SyncState() != chansSynced { + t.Fatalf("wrong state: expected %v instead got %v", + chansSynced, syncer.state) + } +} + +// TestGossipSyncerSynchronizeChanIDsUnknownEncodingType tests that if we +// attempt to query for a set of new channels using an unknown encoding type, +// then we'll get an error. +func TestGossipSyncerSynchronizeChanIDsUnknownEncodingType(t *testing.T) { + t.Parallel() + + // First, we'll create a gossipSyncer instance with a canned sendToPeer + // message to allow us to intercept their potential sends. + _, syncer, _ := newTestSyncer( + lnwire.NewShortChanIDFromInt(10), + ) + + // If we modify the syncer to expect an encoding type that is currently + // unknown, then it should fail to process the message and return an + // error. + syncer.cfg.encodingType = 101 + _, err := syncer.synchronizeChanIDs() + if err == nil { + t.Fatalf("expected message fail") + } +} + +// TestGossipSyncerSynchronizeChanIDs tests that we properly request chunks of +// the short chan ID's which were unknown to us. We'll ensure that we request +// chunk by chunk, and after the last chunk, we return true indicating that we +// can transition to the synced stage. +func TestGossipSyncerSynchronizeChanIDs(t *testing.T) { + t.Parallel() + + // First, we'll create a gossipSyncer instance with a canned sendToPeer + // message to allow us to intercept their potential sends. + msgChan, syncer, _ := newTestSyncer( + lnwire.NewShortChanIDFromInt(10), + ) + + // Next, we'll construct a set of chan ID's that we should query for, + // and set them as newChansToQuery within the state machine. + newChanIDs := []lnwire.ShortChannelID{ + lnwire.NewShortChanIDFromInt(1), + lnwire.NewShortChanIDFromInt(2), + lnwire.NewShortChanIDFromInt(3), + lnwire.NewShortChanIDFromInt(4), + lnwire.NewShortChanIDFromInt(5), + } + syncer.newChansToQuery = newChanIDs + + // We'll modify the chunk size to be a smaller value, so we can ensure + // our chunk parsing works properly. With this value we should get 3 + // queries: two full chunks, and one lingering chunk. + chunkSize := int32(2) + encodingTypeToChunkSize[lnwire.EncodingSortedPlain] = chunkSize + + for i := int32(0); i < chunkSize*2; i += 2 { + // With our set up complete, we'll request a sync of chan ID's. + done, err := syncer.synchronizeChanIDs() + if err != nil { + t.Fatalf("unable to sync chan IDs: %v", err) + } + + // At this point, we shouldn't yet be done as only 2 items + // should have been queried for. + if done { + t.Fatalf("syncer shown as done, but shouldn't be!") + } + + // We should've received a new message from the syncer. + select { + case <-time.After(time.Second * 15): + t.Fatalf("no msgs received") + + case msg := <-msgChan: + queryMsg, ok := msg[0].(*lnwire.QueryShortChanIDs) + if !ok { + t.Fatalf("expected QueryShortChanIDs instead "+ + "got %T", msg) + } + + // The query message should have queried for the first + // two chan ID's, and nothing more. + if !reflect.DeepEqual(queryMsg.ShortChanIDs, newChanIDs[i:i+chunkSize]) { + t.Fatalf("wrong query: expected %v, got %v", + spew.Sdump(newChanIDs[i:i+chunkSize]), + queryMsg.ShortChanIDs) + } + } + + // With the proper message sent out, the internal state of the + // syncer should reflect that it still has more channels to + // query for. + if !reflect.DeepEqual(syncer.newChansToQuery, newChanIDs[i+chunkSize:]) { + t.Fatalf("incorrect chans to query for: expected %v, got %v", + spew.Sdump(newChanIDs[i+chunkSize:]), + syncer.newChansToQuery) + } + } + + // At this point, only one more channel should be lingering for the + // syncer to query for. + if !reflect.DeepEqual(newChanIDs[chunkSize*2:], syncer.newChansToQuery) { + t.Fatalf("wrong chans to query: expected %v, got %v", + newChanIDs[chunkSize*2:], syncer.newChansToQuery) + } + + // If we issue another query, the syncer should tell us that it's done. + done, err := syncer.synchronizeChanIDs() + if err != nil { + t.Fatalf("unable to sync chan IDs: %v", err) + } + if done { + t.Fatalf("syncer should be finished!") + } + + select { + case <-time.After(time.Second * 15): + t.Fatalf("no msgs received") + + case msg := <-msgChan: + queryMsg, ok := msg[0].(*lnwire.QueryShortChanIDs) + if !ok { + t.Fatalf("expected QueryShortChanIDs instead "+ + "got %T", msg) + } + + // The query issued should simply be the last item. + if !reflect.DeepEqual(queryMsg.ShortChanIDs, newChanIDs[chunkSize*2:]) { + t.Fatalf("wrong query: expected %v, got %v", + spew.Sdump(newChanIDs[chunkSize*2:]), + queryMsg.ShortChanIDs) + } + + // There also should be no more channels to query. + if len(syncer.newChansToQuery) != 0 { + t.Fatalf("should be no more chans to query for, "+ + "instead have %v", + spew.Sdump(syncer.newChansToQuery)) + } + } +} + +// TestGossipSyncerRoutineSync tests all state transitions of the main syncer +// goroutine. This ensures that given an encounter with a peer that has a set +// of distinct channels, then we'll properly synchronize our channel state with +// them. +func TestGossipSyncerRoutineSync(t *testing.T) { + t.Parallel() + + // First, we'll create two gossipSyncer instances with a canned + // sendToPeer message to allow us to intercept their potential sends. + startHeight := lnwire.ShortChannelID{ + BlockHeight: 1144, + } + msgChan1, syncer1, chanSeries1 := newTestSyncer( + startHeight, + ) + syncer1.Start() + defer syncer1.Stop() + + msgChan2, syncer2, chanSeries2 := newTestSyncer( + startHeight, + ) + syncer2.Start() + defer syncer2.Stop() + + // Although both nodes are at the same height, they'll have a + // completely disjoint set of 3 chan ID's that they know of. + syncer1Chans := []lnwire.ShortChannelID{ + lnwire.NewShortChanIDFromInt(1), + lnwire.NewShortChanIDFromInt(2), + lnwire.NewShortChanIDFromInt(3), + } + syncer2Chans := []lnwire.ShortChannelID{ + lnwire.NewShortChanIDFromInt(4), + lnwire.NewShortChanIDFromInt(5), + lnwire.NewShortChanIDFromInt(6), + } + + // Before we start the test, we'll set our chunk size to 2 in order to + // make testing the chunked requests and replies easier. + chunkSize := int32(2) + encodingTypeToChunkSize[lnwire.EncodingSortedPlain] = chunkSize + + // We'll kick off the test by passing over the QueryChannelRange + // messages from one node to the other. + select { + case <-time.After(time.Second * 2): + t.Fatalf("didn't get msg from syncer1") + + case msgs := <-msgChan1: + for _, msg := range msgs { + // The message MUST be a QueryChannelRange message. + _, ok := msg.(*lnwire.QueryChannelRange) + if !ok { + t.Fatalf("wrong message: expected "+ + "QueryChannelRange for %T", msg) + } + + select { + case <-time.After(time.Second * 2): + t.Fatalf("node 2 didn't read msg") + + case syncer2.gossipMsgs <- msg: + + } + } + } + select { + case <-time.After(time.Second * 2): + t.Fatalf("didn't get msg from syncer2") + + case msgs := <-msgChan2: + for _, msg := range msgs { + // The message MUST be a QueryChannelRange message. + _, ok := msg.(*lnwire.QueryChannelRange) + if !ok { + t.Fatalf("wrong message: expected "+ + "QueryChannelRange for %T", msg) + } + + select { + case <-time.After(time.Second * 2): + t.Fatalf("node 2 didn't read msg") + + case syncer1.gossipMsgs <- msg: + + } + } + } + + // At this point, we'll need to send responses to both nodes from their + // respective channel series. Both nodes will simply request the entire + // set of channels from the other. + select { + case <-time.After(time.Second * 2): + t.Fatalf("no query recvd") + + case <-chanSeries1.filterRangeReqs: + // We'll send all the channels that it should know of. + chanSeries1.filterRangeResp <- syncer1Chans + } + select { + case <-time.After(time.Second * 2): + t.Fatalf("no query recvd") + + case <-chanSeries2.filterRangeReqs: + // We'll send back all the channels that it should know of. + chanSeries2.filterRangeResp <- syncer2Chans + } + + // At this point, we'll forward the ReplyChannelRange messages to both + // parties. Two replies are expected since the chunk size is 2, and we + // need to query for 3 channels. + for i := 0; i < 2; i++ { + select { + case <-time.After(time.Second * 2): + t.Fatalf("didn't get msg from syncer1") + + case msgs := <-msgChan1: + for _, msg := range msgs { + // The message MUST be a ReplyChannelRange message. + _, ok := msg.(*lnwire.ReplyChannelRange) + if !ok { + t.Fatalf("wrong message: expected "+ + "QueryChannelRange for %T", msg) + } + + select { + case <-time.After(time.Second * 2): + t.Fatalf("node 2 didn't read msg") + + case syncer2.gossipMsgs <- msg: + } + } + } + } + for i := 0; i < 2; i++ { + select { + case <-time.After(time.Second * 2): + t.Fatalf("didn't get msg from syncer2") + + case msgs := <-msgChan2: + for _, msg := range msgs { + // The message MUST be a ReplyChannelRange message. + _, ok := msg.(*lnwire.ReplyChannelRange) + if !ok { + t.Fatalf("wrong message: expected "+ + "QueryChannelRange for %T", msg) + } + + select { + case <-time.After(time.Second * 2): + t.Fatalf("node 2 didn't read msg") + + case syncer1.gossipMsgs <- msg: + } + } + } + } + + // We'll now send back a chunked response for both parties of the known + // short chan ID's. + select { + case <-time.After(time.Second * 2): + t.Fatalf("no query recvd") + + case <-chanSeries1.filterReq: + chanSeries1.filterResp <- syncer2Chans + } + select { + case <-time.After(time.Second * 2): + t.Fatalf("no query recvd") + + case <-chanSeries2.filterReq: + chanSeries2.filterResp <- syncer1Chans + } + + // At this point, both parties should start to send out initial + // requests to query the chan IDs of the remote party. As the chunk + // size is 3, they'll need 2 rounds in order to fully reconcile the + // state. + for i := 0; i < 2; i++ { + // Both parties should now have sent out the initial requests + // to query the chan IDs of the other party. + select { + case <-time.After(time.Second * 2): + t.Fatalf("didn't get msg from syncer1") + + case msgs := <-msgChan1: + for _, msg := range msgs { + // The message MUST be a QueryShortChanIDs message. + _, ok := msg.(*lnwire.QueryShortChanIDs) + if !ok { + t.Fatalf("wrong message: expected "+ + "QueryShortChanIDs for %T", msg) + } + + select { + case <-time.After(time.Second * 2): + t.Fatalf("node 2 didn't read msg") + + case syncer2.gossipMsgs <- msg: + + } + } + } + select { + case <-time.After(time.Second * 2): + t.Fatalf("didn't get msg from syncer2") + + case msgs := <-msgChan2: + for _, msg := range msgs { + // The message MUST be a QueryShortChanIDs message. + _, ok := msg.(*lnwire.QueryShortChanIDs) + if !ok { + t.Fatalf("wrong message: expected "+ + "QueryShortChanIDs for %T", msg) + } + + select { + case <-time.After(time.Second * 2): + t.Fatalf("node 2 didn't read msg") + + case syncer1.gossipMsgs <- msg: + + } + } + } + + // We'll then respond to both parties with an empty set of replies (as + // it doesn't affect the test). + select { + case <-time.After(time.Second * 2): + t.Fatalf("no query recvd") + + case <-chanSeries1.annReq: + chanSeries1.annResp <- []lnwire.Message{} + } + select { + case <-time.After(time.Second * 2): + t.Fatalf("no query recvd") + + case <-chanSeries2.annReq: + chanSeries2.annResp <- []lnwire.Message{} + } + + // Both sides should then receive a ReplyShortChanIDsEnd as the first + // chunk has been replied to. + select { + case <-time.After(time.Second * 2): + t.Fatalf("didn't get msg from syncer1") + + case msgs := <-msgChan1: + for _, msg := range msgs { + // The message MUST be a ReplyShortChanIDsEnd message. + _, ok := msg.(*lnwire.ReplyShortChanIDsEnd) + if !ok { + t.Fatalf("wrong message: expected "+ + "QueryChannelRange for %T", msg) + } + + select { + case <-time.After(time.Second * 2): + t.Fatalf("node 2 didn't read msg") + + case syncer2.gossipMsgs <- msg: + + } + } + } + select { + case <-time.After(time.Second * 2): + t.Fatalf("didn't get msg from syncer1") + + case msgs := <-msgChan2: + for _, msg := range msgs { + // The message MUST be a ReplyShortChanIDsEnd message. + _, ok := msg.(*lnwire.ReplyShortChanIDsEnd) + if !ok { + t.Fatalf("wrong message: expected "+ + "ReplyShortChanIDsEnd for %T", msg) + } + + select { + case <-time.After(time.Second * 2): + t.Fatalf("node 2 didn't read msg") + + case syncer1.gossipMsgs <- msg: + + } + } + } + } + + // At this stage both parties should now be sending over their initial + // GossipTimestampRange messages as they should both be fully synced. + select { + case <-time.After(time.Second * 2): + t.Fatalf("didn't get msg from syncer1") + + case msgs := <-msgChan1: + for _, msg := range msgs { + // The message MUST be a GossipTimestampRange message. + _, ok := msg.(*lnwire.GossipTimestampRange) + if !ok { + t.Fatalf("wrong message: expected "+ + "QueryChannelRange for %T", msg) + } + + select { + case <-time.After(time.Second * 2): + t.Fatalf("node 2 didn't read msg") + + case syncer2.gossipMsgs <- msg: + + } + } + } + select { + case <-time.After(time.Second * 2): + t.Fatalf("didn't get msg from syncer1") + + case msgs := <-msgChan2: + for _, msg := range msgs { + // The message MUST be a GossipTimestampRange message. + _, ok := msg.(*lnwire.GossipTimestampRange) + if !ok { + t.Fatalf("wrong message: expected "+ + "QueryChannelRange for %T", msg) + } + + select { + case <-time.After(time.Second * 2): + t.Fatalf("node 2 didn't read msg") + + case syncer1.gossipMsgs <- msg: + + } + } + } +} + +// TestGossipSyncerAlreadySynced tests that if we attempt to synchronize two +// syncers that have the exact same state, then they'll skip straight to the +// final state and not perform any channel queries. +func TestGossipSyncerAlreadySynced(t *testing.T) { + t.Parallel() + + // First, we'll create two gossipSyncer instances with a canned + // sendToPeer message to allow us to intercept their potential sends. + startHeight := lnwire.ShortChannelID{ + BlockHeight: 1144, + } + msgChan1, syncer1, chanSeries1 := newTestSyncer( + startHeight, + ) + syncer1.Start() + defer syncer1.Stop() + + msgChan2, syncer2, chanSeries2 := newTestSyncer( + startHeight, + ) + syncer2.Start() + defer syncer2.Stop() + + // Before we start the test, we'll set our chunk size to 2 in order to + // make testing the chunked requests and replies easier. + chunkSize := int32(2) + encodingTypeToChunkSize[lnwire.EncodingSortedPlain] = chunkSize + + // The channel state of both syncers will be identical. They should + // recognize this, and skip the sync phase below. + syncer1Chans := []lnwire.ShortChannelID{ + lnwire.NewShortChanIDFromInt(1), + lnwire.NewShortChanIDFromInt(2), + lnwire.NewShortChanIDFromInt(3), + } + syncer2Chans := []lnwire.ShortChannelID{ + lnwire.NewShortChanIDFromInt(1), + lnwire.NewShortChanIDFromInt(2), + lnwire.NewShortChanIDFromInt(3), + } + + // We'll now kick off the test by allowing both side to send their + // QueryChannelRange messages to each other. + select { + case <-time.After(time.Second * 2): + t.Fatalf("didn't get msg from syncer1") + + case msgs := <-msgChan1: + for _, msg := range msgs { + // The message MUST be a QueryChannelRange message. + _, ok := msg.(*lnwire.QueryChannelRange) + if !ok { + t.Fatalf("wrong message: expected "+ + "QueryChannelRange for %T", msg) + } + + select { + case <-time.After(time.Second * 2): + t.Fatalf("node 2 didn't read msg") + + case syncer2.gossipMsgs <- msg: + + } + } + } + select { + case <-time.After(time.Second * 2): + t.Fatalf("didn't get msg from syncer2") + + case msgs := <-msgChan2: + for _, msg := range msgs { + // The message MUST be a QueryChannelRange message. + _, ok := msg.(*lnwire.QueryChannelRange) + if !ok { + t.Fatalf("wrong message: expected "+ + "QueryChannelRange for %T", msg) + } + + select { + case <-time.After(time.Second * 2): + t.Fatalf("node 2 didn't read msg") + + case syncer1.gossipMsgs <- msg: + + } + } + } + + // We'll now send back the range each side should send over: the set of + // channels they already know about. + select { + case <-time.After(time.Second * 2): + t.Fatalf("no query recvd") + + case <-chanSeries1.filterRangeReqs: + // We'll send all the channels that it should know of. + chanSeries1.filterRangeResp <- syncer1Chans + } + select { + case <-time.After(time.Second * 2): + t.Fatalf("no query recvd") + + case <-chanSeries2.filterRangeReqs: + // We'll send back all the channels that it should know of. + chanSeries2.filterRangeResp <- syncer2Chans + } + + // Next, we'll thread through the replies of both parties. As the chunk + // size is 2, and they both know of 3 channels, it'll take two around + // and two chunks. + for i := 0; i < 2; i++ { + select { + case <-time.After(time.Second * 2): + t.Fatalf("didn't get msg from syncer1") + + case msgs := <-msgChan1: + for _, msg := range msgs { + // The message MUST be a ReplyChannelRange message. + _, ok := msg.(*lnwire.ReplyChannelRange) + if !ok { + t.Fatalf("wrong message: expected "+ + "QueryChannelRange for %T", msg) + } + + select { + case <-time.After(time.Second * 2): + t.Fatalf("node 2 didn't read msg") + + case syncer2.gossipMsgs <- msg: + } + } + } + } + for i := 0; i < 2; i++ { + select { + case <-time.After(time.Second * 2): + t.Fatalf("didn't get msg from syncer2") + + case msgs := <-msgChan2: + for _, msg := range msgs { + // The message MUST be a ReplyChannelRange message. + _, ok := msg.(*lnwire.ReplyChannelRange) + if !ok { + t.Fatalf("wrong message: expected "+ + "QueryChannelRange for %T", msg) + } + + select { + case <-time.After(time.Second * 2): + t.Fatalf("node 2 didn't read msg") + + case syncer1.gossipMsgs <- msg: + } + } + } + } + + // Now that both sides have the full responses, we'll send over the + // channels that they need to filter out. As both sides have the exact + // same set of channels, they should skip to the final state. + select { + case <-time.After(time.Second * 2): + t.Fatalf("no query recvd") + + case <-chanSeries1.filterReq: + chanSeries1.filterResp <- []lnwire.ShortChannelID{} + } + select { + case <-time.After(time.Second * 2): + t.Fatalf("no query recvd") + + case <-chanSeries2.filterReq: + chanSeries2.filterResp <- []lnwire.ShortChannelID{} + } + + // As both parties are already synced, the next message they send to + // each other should be the GossipTimestampRange message. + select { + case <-time.After(time.Second * 2): + t.Fatalf("didn't get msg from syncer1") + + case msgs := <-msgChan1: + for _, msg := range msgs { + // The message MUST be a GossipTimestampRange message. + _, ok := msg.(*lnwire.GossipTimestampRange) + if !ok { + t.Fatalf("wrong message: expected "+ + "QueryChannelRange for %T", msg) + } + + select { + case <-time.After(time.Second * 2): + t.Fatalf("node 2 didn't read msg") + + case syncer2.gossipMsgs <- msg: + + } + } + } + select { + case <-time.After(time.Second * 2): + t.Fatalf("didn't get msg from syncer1") + + case msgs := <-msgChan2: + for _, msg := range msgs { + // The message MUST be a GossipTimestampRange message. + _, ok := msg.(*lnwire.GossipTimestampRange) + if !ok { + t.Fatalf("wrong message: expected "+ + "QueryChannelRange for %T", msg) + } + + select { + case <-time.After(time.Second * 2): + t.Fatalf("node 2 didn't read msg") + + case syncer1.gossipMsgs <- msg: + + } + } + } +} diff --git a/discovery/utils.go b/discovery/utils.go index c9f85a8e..9f38e860 100644 --- a/discovery/utils.go +++ b/discovery/utils.go @@ -8,12 +8,12 @@ import ( "github.com/roasbeef/btcd/btcec" ) -// createChanAnnouncement is a helper function which creates all channel +// CreateChanAnnouncement is a helper function which creates all channel // announcements given the necessary channel related database items. This // function is used to transform out database structs into the corresponding wire // structs for announcing new channels to other peers, or simply syncing up a // peer's initial routing table upon connect. -func createChanAnnouncement(chanProof *channeldb.ChannelAuthProof, +func CreateChanAnnouncement(chanProof *channeldb.ChannelAuthProof, chanInfo *channeldb.ChannelEdgeInfo, e1, e2 *channeldb.ChannelEdgePolicy) (*lnwire.ChannelAnnouncement, *lnwire.ChannelUpdate, *lnwire.ChannelUpdate, error) { diff --git a/lnwire/features.go b/lnwire/features.go index 0449bf53..ee663f90 100644 --- a/lnwire/features.go +++ b/lnwire/features.go @@ -16,11 +16,36 @@ import ( type FeatureBit uint16 const ( + // DataLossProtectRequired is a feature bit that indicates that a peer + // *requires* the other party know about the data-loss-protect optional + // feature. If the remote peer does not know of such a feature, then + // the sending peer SHOLUD disconnect them. The data-loss-protect + // feature allows a peer that's lost partial data to recover their + // settled funds of the latest commitment state. + DataLossProtectRequired FeatureBit = 0 + + // DataLossProtectOptional is an optional feature bit that indicates + // that the sending peer knows of this new feature and can activate it + // it. The data-loss-protect feature allows a peer that's lost partial + // data to recover their settled funds of the latest commitment state. + DataLossProtectOptional FeatureBit = 1 + // InitialRoutingSync is a local feature bit meaning that the receiving // node should send a complete dump of routing information when a new // connection is established. InitialRoutingSync FeatureBit = 3 + // GossipQueriesRequired is a feature bit that indicates that the + // receiving peer MUST know of the set of features that allows nodes to + // more efficiently query the network view of peers on the network for + // reconciliation purposes. + GossipQueriesRequired FeatureBit = 6 + + // GossipQueriesOptional is an optional feature bit that signals that + // the setting peer knows of the set of features that allows more + // efficient network view reconciliation. + GossipQueriesOptional FeatureBit = 7 + // maxAllowedSize is a maximum allowed size of feature vector. // // NOTE: Within the protocol, the maximum allowed message size is 65535 @@ -42,7 +67,9 @@ const ( // not advertised to the entire network. A full description of these feature // bits is provided in the BOLT-09 specification. var LocalFeatures = map[FeatureBit]string{ - InitialRoutingSync: "initial-routing-sync", + DataLossProtectOptional: "data-loss-protect-optional", + InitialRoutingSync: "initial-routing-sync", + GossipQueriesOptional: "gossip-queries-optional", } // GlobalFeatures is a mapping of known global feature bits to a descriptive diff --git a/lnwire/gossip_timestamp_range.go b/lnwire/gossip_timestamp_range.go new file mode 100644 index 00000000..a2180fbc --- /dev/null +++ b/lnwire/gossip_timestamp_range.go @@ -0,0 +1,80 @@ +package lnwire + +import ( + "io" + + "github.com/roasbeef/btcd/chaincfg/chainhash" +) + +// GossipTimestampRange is a message that allows the sender to restrict the set +// of future gossip announcements sent by the receiver. Nodes should send this +// if they have the gossip-queries feature bit active. Nodes are able to send +// new GossipTimestampRange messages to replace the prior window. +type GossipTimestampRange struct { + // ChainHash denotes the chain that the sender wishes to restrict the + // set of received announcements of. + ChainHash chainhash.Hash + + // FirstTimestamp is the timestamp of the earliest announcement message + // that should be sent by the receiver. + FirstTimestamp uint32 + + // TimestampRange is the horizon beyond the FirstTimestamp that any + // announcement messages should be sent for. The receiving node MUST + // NOT send any announcements that have a timestamp greater than + // FirstTimestamp + TimestampRange. + TimestampRange uint32 +} + +// NewGossipTimestampRange creates a new empty GossipTimestampRange message. +func NewGossipTimestampRange() *GossipTimestampRange { + return &GossipTimestampRange{} +} + +// A compile time check to ensure GossipTimestampRange implements the +// lnwire.Message interface. +var _ Message = (*GossipTimestampRange)(nil) + +// Decode deserializes a serialized GossipTimestampRange message stored in the +// passed io.Reader observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (g *GossipTimestampRange) Decode(r io.Reader, pver uint32) error { + return readElements(r, + g.ChainHash[:], + &g.FirstTimestamp, + &g.TimestampRange, + ) +} + +// Encode serializes the target GossipTimestampRange into the passed io.Writer +// observing the protocol version specified. +// +// This is part of the lnwire.Message interface. +func (g *GossipTimestampRange) Encode(w io.Writer, pver uint32) error { + return writeElements(w, + g.ChainHash[:], + g.FirstTimestamp, + g.TimestampRange, + ) +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (g *GossipTimestampRange) MsgType() MessageType { + return MsgGossipTimestampRange +} + +// MaxPayloadLength returns the maximum allowed payload size for a +// GossipTimestampRange complete message observing the specified protocol +// version. +// +// This is part of the lnwire.Message interface. +func (c *GossipTimestampRange) MaxPayloadLength(uint32) uint32 { + // 32 + 4 + 4 + // + // TODO(roasbeef): update to 8 byte timestmaps? + return 40 +} diff --git a/lnwire/lnwire.go b/lnwire/lnwire.go index 212ead8d..39209a5c 100644 --- a/lnwire/lnwire.go +++ b/lnwire/lnwire.go @@ -78,9 +78,14 @@ func (a addressType) AddrLen() uint16 { // // TODO(roasbeef): this should eventually draw from a buffer pool for // serialization. -// TODO(roasbeef): switch to var-ints for all? func writeElement(w io.Writer, element interface{}) error { switch e := element.(type) { + case ShortChanIDEncoding: + var b [1]byte + b[0] = uint8(e) + if _, err := w.Write(b[:]); err != nil { + return err + } case uint8: var b [1]byte b[0] = e @@ -390,6 +395,12 @@ func writeElements(w io.Writer, elements ...interface{}) error { func readElement(r io.Reader, element interface{}) error { var err error switch e := element.(type) { + case *ShortChanIDEncoding: + var b [1]uint8 + if _, err := r.Read(b[:]); err != nil { + return err + } + *e = ShortChanIDEncoding(b[0]) case *uint8: var b [1]uint8 if _, err := r.Read(b[:]); err != nil { diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index e8cd720c..fd214308 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -11,6 +11,7 @@ import ( "reflect" "testing" "testing/quick" + "time" "github.com/davecgh/go-spew/spew" "github.com/roasbeef/btcd/btcec" @@ -553,6 +554,57 @@ func TestLightningWireProtocol(t *testing.T) { } } + v[0] = reflect.ValueOf(req) + }, + MsgQueryShortChanIDs: func(v []reflect.Value, r *rand.Rand) { + req := QueryShortChanIDs{ + // TODO(roasbeef): later alternate encoding types + EncodingType: EncodingSortedPlain, + } + + if _, err := rand.Read(req.ChainHash[:]); err != nil { + t.Fatalf("unable to read chain hash: %v", err) + return + } + + numChanIDs := rand.Int31n(5000) + + req.ShortChanIDs = make([]ShortChannelID, numChanIDs) + for i := int32(0); i < numChanIDs; i++ { + req.ShortChanIDs[i] = NewShortChanIDFromInt( + uint64(r.Int63()), + ) + } + + v[0] = reflect.ValueOf(req) + }, + MsgReplyChannelRange: func(v []reflect.Value, r *rand.Rand) { + req := ReplyChannelRange{ + QueryChannelRange: QueryChannelRange{ + FirstBlockHeight: uint32(r.Int31()), + NumBlocks: uint32(r.Int31()), + }, + } + + if _, err := rand.Read(req.ChainHash[:]); err != nil { + t.Fatalf("unable to read chain hash: %v", err) + return + } + + req.Complete = uint8(r.Int31n(2)) + + // TODO(roasbeef): later alternate encoding types + req.EncodingType = EncodingSortedPlain + + numChanIDs := rand.Int31n(5000) + + req.ShortChanIDs = make([]ShortChannelID, numChanIDs) + for i := int32(0); i < numChanIDs; i++ { + req.ShortChanIDs[i] = NewShortChanIDFromInt( + uint64(r.Int63()), + ) + } + v[0] = reflect.ValueOf(req) }, } @@ -705,6 +757,36 @@ func TestLightningWireProtocol(t *testing.T) { return mainScenario(&m) }, }, + { + msgType: MsgGossipTimestampRange, + scenario: func(m GossipTimestampRange) bool { + return mainScenario(&m) + }, + }, + { + msgType: MsgQueryShortChanIDs, + scenario: func(m QueryShortChanIDs) bool { + return mainScenario(&m) + }, + }, + { + msgType: MsgReplyShortChanIDsEnd, + scenario: func(m ReplyShortChanIDsEnd) bool { + return mainScenario(&m) + }, + }, + { + msgType: MsgQueryChannelRange, + scenario: func(m QueryChannelRange) bool { + return mainScenario(&m) + }, + }, + { + msgType: MsgReplyChannelRange, + scenario: func(m ReplyChannelRange) bool { + return mainScenario(&m) + }, + }, } for _, test := range tests { var config *quick.Config @@ -726,3 +808,7 @@ func TestLightningWireProtocol(t *testing.T) { } } + +func init() { + rand.Seed(time.Now().Unix()) +} diff --git a/lnwire/message.go b/lnwire/message.go index a10bba29..b5c27339 100644 --- a/lnwire/message.go +++ b/lnwire/message.go @@ -49,6 +49,11 @@ const ( MsgNodeAnnouncement = 257 MsgChannelUpdate = 258 MsgAnnounceSignatures = 259 + MsgQueryShortChanIDs = 261 + MsgReplyShortChanIDsEnd = 262 + MsgQueryChannelRange = 263 + MsgReplyChannelRange = 264 + MsgGossipTimestampRange = 265 ) // String return the string representation of message type. @@ -100,6 +105,16 @@ func (t MessageType) String() string { return "Pong" case MsgUpdateFee: return "UpdateFee" + case MsgQueryShortChanIDs: + return "QueryShortChanIDs" + case MsgReplyShortChanIDsEnd: + return "ReplyShortChanIDsEnd" + case MsgQueryChannelRange: + return "QueryChannelRange" + case MsgReplyChannelRange: + return "ReplyChannelRange" + case MsgGossipTimestampRange: + return "GossipTimestampRange" default: return "" } @@ -191,6 +206,16 @@ func makeEmptyMessage(msgType MessageType) (Message, error) { msg = &AnnounceSignatures{} case MsgPong: msg = &Pong{} + case MsgQueryShortChanIDs: + msg = &QueryShortChanIDs{} + case MsgReplyShortChanIDsEnd: + msg = &ReplyShortChanIDsEnd{} + case MsgQueryChannelRange: + msg = &QueryChannelRange{} + case MsgReplyChannelRange: + msg = &ReplyChannelRange{} + case MsgGossipTimestampRange: + msg = &GossipTimestampRange{} default: return nil, &UnknownMessage{msgType} } diff --git a/lnwire/query_channel_range.go b/lnwire/query_channel_range.go new file mode 100644 index 00000000..49b1b4f0 --- /dev/null +++ b/lnwire/query_channel_range.go @@ -0,0 +1,77 @@ +package lnwire + +import ( + "io" + + "github.com/roasbeef/btcd/chaincfg/chainhash" +) + +// QueryChannelRange is a message sent by a node in order to query the +// receiving node of the set of open channel they know of with short channel +// ID's after the specified block height, capped at the number of blocks beyond +// that block height. This will be used by nodes upon initial connect to +// synchronize their views of the network. +type QueryChannelRange struct { + // ChainHash denotes the target chain that we're trying to synchronize + // channel graph state for. + ChainHash chainhash.Hash + + // FirstBlockHeight is the first block in the query range. The + // responder should send all new short channel IDs from this block + // until this block plus the specified number of blocks. + FirstBlockHeight uint32 + + // NumBlocks is the number of blocks beyond the first block that short + // channel ID's should be sent for. + NumBlocks uint32 +} + +// NewQueryChannelRange creates a new empty QueryChannelRange message. +func NewQueryChannelRange() *QueryChannelRange { + return &QueryChannelRange{} +} + +// A compile time check to ensure QueryChannelRange implements the +// lnwire.Message interface. +var _ Message = (*QueryChannelRange)(nil) + +// Decode deserializes a serialized QueryChannelRange message stored in the +// passed io.Reader observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (q *QueryChannelRange) Decode(r io.Reader, pver uint32) error { + return readElements(r, + q.ChainHash[:], + &q.FirstBlockHeight, + &q.NumBlocks, + ) +} + +// Encode serializes the target QueryChannelRange into the passed io.Writer +// observing the protocol version specified. +// +// This is part of the lnwire.Message interface. +func (q *QueryChannelRange) Encode(w io.Writer, pver uint32) error { + return writeElements(w, + q.ChainHash[:], + q.FirstBlockHeight, + q.NumBlocks, + ) +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (q *QueryChannelRange) MsgType() MessageType { + return MsgQueryChannelRange +} + +// MaxPayloadLength returns the maximum allowed payload size for a +// QueryChannelRange complete message observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (q *QueryChannelRange) MaxPayloadLength(uint32) uint32 { + // 32 + 4 + 4 + return 40 +} diff --git a/lnwire/query_short_chan_ids.go b/lnwire/query_short_chan_ids.go new file mode 100644 index 00000000..4dab6f4c --- /dev/null +++ b/lnwire/query_short_chan_ids.go @@ -0,0 +1,233 @@ +package lnwire + +import ( + "bytes" + "fmt" + "io" + "sort" + + "github.com/roasbeef/btcd/chaincfg/chainhash" +) + +// ShortChanIDEncoding is an enum-like type that represents exactly how a set +// of short channel ID's is encoded on the wire. The set of encodings allows us +// to take advantage of the structure of a list of short channel ID's to +// achieving a high degree of compression. +type ShortChanIDEncoding uint8 + +const ( + // EncodingSortedPlain signals that the set of short channel ID's is + // encoded using the regular encoding, in a sorted order. + EncodingSortedPlain ShortChanIDEncoding = 0 + + // TODO(roasbeef): list max number of short chan id's that are able to + // use +) + +// ErrUnknownShortChanIDEncoding is a parametrized error that indicates that we +// came across an unknown short channel ID encoding, and therefore were unable +// to continue parsing. +func ErrUnknownShortChanIDEncoding(encoding ShortChanIDEncoding) error { + return fmt.Errorf("unknown short chan id encoding: %v", encoding) +} + +// QueryShortChanIDs is a message that allows the sender to query a set of +// channel announcement and channel update messages that correspond to the set +// of encoded short channel ID's. The encoding of the short channel ID's is +// detailed in the query message ensuring that the receiver knows how to +// properly decode each encode short channel ID which may be encoded using a +// compression format. The receiver should respond with a series of channel +// announcement and channel updates, finally sending a ReplyShortChanIDsEnd +// message. +type QueryShortChanIDs struct { + // ChainHash denotes the target chain that we're querying for the + // channel channel ID's of. + ChainHash chainhash.Hash + + // EncodingType is a signal to the receiver of the message that + // indicates exactly how the set of short channel ID's that follow have + // been encoded. + EncodingType ShortChanIDEncoding + + // ShortChanIDs is a slice of decoded short channel ID's. + ShortChanIDs []ShortChannelID +} + +// NewQueryShortChanIDs creates a new QueryShortChanIDs message. +func NewQueryShortChanIDs(h chainhash.Hash, e ShortChanIDEncoding, + s []ShortChannelID) *QueryShortChanIDs { + + return &QueryShortChanIDs{ + ChainHash: h, + EncodingType: e, + ShortChanIDs: s, + } +} + +// A compile time check to ensure QueryShortChanIDs implements the +// lnwire.Message interface. +var _ Message = (*QueryShortChanIDs)(nil) + +// Decode deserializes a serialized QueryShortChanIDs message stored in the +// passed io.Reader observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (q *QueryShortChanIDs) Decode(r io.Reader, pver uint32) error { + err := readElements(r, q.ChainHash[:]) + if err != nil { + return err + } + + q.EncodingType, q.ShortChanIDs, err = decodeShortChanIDs(r) + + return err +} + +// decodeShortChanIDs decodes a set of short channel ID's that have been +// encoded. The first byte of the body details how the short chan ID's were +// encoded. We'll use this type to govern exactly how we go about encoding the +// set of short channel ID's. +func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, error) { + // First, we'll attempt to read the number of bytes in the body of the + // set of encoded short channel ID's. + var numBytesResp uint16 + err := readElements(r, &numBytesResp) + if err != nil { + return 0, nil, err + } + + queryBody := make([]byte, numBytesResp) + if _, err := io.ReadFull(r, queryBody); err != nil { + return 0, nil, err + } + + // The first byte is the encoding type, so we'll extract that so we can + // continue our parsing. + encodingType := ShortChanIDEncoding(queryBody[0]) + + // Before continuing, we'll snip off the first byte of the query body + // as that was just the encoding type. + queryBody = queryBody[1:] + + // If after extracting the encoding type, then number of remaining + // bytes instead a whole multiple of the size of an encoded short + // channel ID (8 bytes), then we'll return a parsing error. + if len(queryBody)%8 != 0 { + return 0, nil, fmt.Errorf("whole number of short chan ID's "+ + "cannot be encoded in len=%v", len(queryBody)) + } + + // Otherwise, depending on the encoding type, we'll decode the encode + // short channel ID's in a different manner. + switch encodingType { + + // In this encoding, we'll simply read a sort array of encoded short + // channel ID's from the buffer. + case EncodingSortedPlain: + // As each short channel ID is encoded as 8 bytes, we can + // compute the number of bytes encoded based on the size of the + // query body. + numShortChanIDs := len(queryBody) / 8 + shortChanIDs := make([]ShortChannelID, numShortChanIDs) + + // Finally, we'll read out the exact number of short channel + // ID's to conclude our parsing. + bodyReader := bytes.NewReader(queryBody) + for i := 0; i < numShortChanIDs; i++ { + if err := readElements(bodyReader, &shortChanIDs[i]); err != nil { + return 0, nil, fmt.Errorf("unable to parse "+ + "short chan ID: %v", err) + } + } + + return encodingType, shortChanIDs, nil + + default: + // If we've been sent an encoding type that we don't know of, + // then we'll return a parsing error as we can't continue if + // we're unable to encode them. + return 0, nil, ErrUnknownShortChanIDEncoding(encodingType) + } +} + +// Encode serializes the target QueryShortChanIDs into the passed io.Writer +// observing the protocol version specified. +// +// This is part of the lnwire.Message interface. +func (q *QueryShortChanIDs) Encode(w io.Writer, pver uint32) error { + // First, we'll write out the chain hash. + err := writeElements(w, q.ChainHash[:]) + if err != nil { + return err + } + + // Base on our encoding type, we'll write out the set of short channel + // ID's. + return encodeShortChanIDs(w, q.EncodingType, q.ShortChanIDs) +} + +// encodeShortChanIDs encodes the passed short channel ID's into the passed +// io.Writer, respecting the specified encoding type. +func encodeShortChanIDs(w io.Writer, encodingType ShortChanIDEncoding, + shortChanIDs []ShortChannelID) error { + + switch encodingType { + + // In this encoding, we'll simply write a sorted array of encoded short + // channel ID's from the buffer. + case EncodingSortedPlain: + // First, we'll write out the number of bytes of the query + // body. We add 1 as the response will have the encoding type + // prepended to it. + numBytesBody := uint16(len(shortChanIDs)*8) + 1 + if err := writeElements(w, numBytesBody); err != nil { + return err + } + + // We'll then write out the encoding that that follows the + // actual encoded short channel ID's. + if err := writeElements(w, encodingType); err != nil { + return err + } + + // Next, we'll ensure that the set of short channel ID's is + // properly sorted in place. + sort.Slice(shortChanIDs, func(i, j int) bool { + return shortChanIDs[i].ToUint64() < + shortChanIDs[j].ToUint64() + }) + + // Now that we know they're sorted, we can write out each short + // channel ID to the buffer. + for _, chanID := range shortChanIDs { + if err := writeElements(w, chanID); err != nil { + return fmt.Errorf("unable to write short chan "+ + "ID: %v", err) + } + } + + return nil + + default: + // If we're trying to encode with an encoding type that we + // don't know of, then we'll return a parsing error as we can't + // continue if we're unable to encode them. + return ErrUnknownShortChanIDEncoding(encodingType) + } +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (q *QueryShortChanIDs) MsgType() MessageType { + return MsgQueryShortChanIDs +} + +// MaxPayloadLength returns the maximum allowed payload size for a +// QueryShortChanIDs complete message observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (q *QueryShortChanIDs) MaxPayloadLength(uint32) uint32 { + return MaxMessagePayload +} diff --git a/lnwire/reply_channel_range.go b/lnwire/reply_channel_range.go new file mode 100644 index 00000000..ac040c9c --- /dev/null +++ b/lnwire/reply_channel_range.go @@ -0,0 +1,84 @@ +package lnwire + +import "io" + +// ReplyChannelRange is the response to the QueryChannelRange message. It +// includes the original query, and the next streaming chunk of encoded short +// channel ID's as the response. We'll also include a byte that indicates if +// this is the last query in the message. +type ReplyChannelRange struct { + // QueryChannelRange is the corresponding query to this response. + QueryChannelRange + + // Complete denotes if this is the conclusion of the set of streaming + // responses to the original query. + Complete uint8 + + // EncodingType is a signal to the receiver of the message that + // indicates exactly how the set of short channel ID's that follow have + // been encoded. + EncodingType ShortChanIDEncoding + + // ShortChanIDs is a slice of decoded short channel ID's. + ShortChanIDs []ShortChannelID +} + +// NewReplyChannelRange creates a new empty ReplyChannelRange message. +func NewReplyChannelRange() *ReplyChannelRange { + return &ReplyChannelRange{} +} + +// A compile time check to ensure ReplyChannelRange implements the +// lnwire.Message interface. +var _ Message = (*ReplyChannelRange)(nil) + +// Decode deserializes a serialized ReplyChannelRange message stored in the +// passed io.Reader observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (c *ReplyChannelRange) Decode(r io.Reader, pver uint32) error { + err := c.QueryChannelRange.Decode(r, pver) + if err != nil { + return err + } + + if err := readElements(r, &c.Complete); err != nil { + return err + } + + c.EncodingType, c.ShortChanIDs, err = decodeShortChanIDs(r) + + return err +} + +// Encode serializes the target ReplyChannelRange into the passed io.Writer +// observing the protocol version specified. +// +// This is part of the lnwire.Message interface. +func (c *ReplyChannelRange) Encode(w io.Writer, pver uint32) error { + if err := c.QueryChannelRange.Encode(w, pver); err != nil { + return err + } + + if err := writeElements(w, c.Complete); err != nil { + return err + } + + return encodeShortChanIDs(w, c.EncodingType, c.ShortChanIDs) +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (c *ReplyChannelRange) MsgType() MessageType { + return MsgReplyChannelRange +} + +// MaxPayloadLength returns the maximum allowed payload size for a +// ReplyChannelRange complete message observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (c *ReplyChannelRange) MaxPayloadLength(uint32) uint32 { + return MaxMessagePayload +} diff --git a/lnwire/reply_short_chan_ids_end.go b/lnwire/reply_short_chan_ids_end.go new file mode 100644 index 00000000..1fd2f849 --- /dev/null +++ b/lnwire/reply_short_chan_ids_end.go @@ -0,0 +1,74 @@ +package lnwire + +import ( + "io" + + "github.com/roasbeef/btcd/chaincfg/chainhash" +) + +// ReplyShortChanIDsEnd is a message that marks the end of a streaming message +// response to an initial QueryShortChanIDs message. This marks that the +// receiver of the original QueryShortChanIDs for the target chain has either +// sent all adequate responses it knows of, or doesn't now of any short chan +// ID's for the target chain. +type ReplyShortChanIDsEnd struct { + // ChainHash denotes the target chain that we're respond to a short + // chan ID query for. + ChainHash chainhash.Hash + + // Complete will be set to 0 if we don't know of the chain that the + // remote peer sent their query for. Otherwise, we'll set this to 1 in + // order to indicate that we've sent all known responses for the prior + // set of short chan ID's in the corresponding QueryShortChanIDs + // message. + Complete uint8 +} + +// NewReplyShortChanIDsEnd creates a new empty ReplyShortChanIDsEnd message. +func NewReplyShortChanIDsEnd() *ReplyShortChanIDsEnd { + return &ReplyShortChanIDsEnd{} +} + +// A compile time check to ensure ReplyShortChanIDsEnd implements the +// lnwire.Message interface. +var _ Message = (*ReplyShortChanIDsEnd)(nil) + +// Decode deserializes a serialized ReplyShortChanIDsEnd message stored in the +// passed io.Reader observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (c *ReplyShortChanIDsEnd) Decode(r io.Reader, pver uint32) error { + return readElements(r, + c.ChainHash[:], + &c.Complete, + ) +} + +// Encode serializes the target ReplyShortChanIDsEnd into the passed io.Writer +// observing the protocol version specified. +// +// This is part of the lnwire.Message interface. +func (c *ReplyShortChanIDsEnd) Encode(w io.Writer, pver uint32) error { + return writeElements(w, + c.ChainHash[:], + c.Complete, + ) +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (c *ReplyShortChanIDsEnd) MsgType() MessageType { + return MsgReplyShortChanIDsEnd +} + +// MaxPayloadLength returns the maximum allowed payload size for a +// ReplyShortChanIDsEnd complete message observing the specified protocol +// version. +// +// This is part of the lnwire.Message interface. +func (c *ReplyShortChanIDsEnd) MaxPayloadLength(uint32) uint32 { + // 32 (chain hash) + 1 (complete) + return 33 +} diff --git a/peer.go b/peer.go index 3606f3de..3ed9ae63 100644 --- a/peer.go +++ b/peer.go @@ -613,7 +613,6 @@ func (p *peer) readNextMessage() (lnwire.Message, error) { return nil, err } - // TODO(roasbeef): add message summaries p.logWireMessage(nextMsg, true) return nextMsg, nil @@ -995,7 +994,12 @@ out: case *lnwire.ChannelUpdate, *lnwire.ChannelAnnouncement, *lnwire.NodeAnnouncement, - *lnwire.AnnounceSignatures: + *lnwire.AnnounceSignatures, + *lnwire.GossipTimestampRange, + *lnwire.QueryShortChanIDs, + *lnwire.QueryChannelRange, + *lnwire.ReplyChannelRange, + *lnwire.ReplyShortChanIDsEnd: discStream.AddMsg(msg) @@ -1139,6 +1143,30 @@ func messageSummary(msg lnwire.Message) string { case *lnwire.ChannelReestablish: return fmt.Sprintf("next_local_height=%v, remote_tail_height=%v", msg.NextLocalCommitHeight, msg.RemoteCommitTailHeight) + + case *lnwire.ReplyShortChanIDsEnd: + return fmt.Sprintf("chain_hash=%v, complete=%v", msg.ChainHash, + msg.Complete) + + case *lnwire.ReplyChannelRange: + return fmt.Sprintf("complete=%v, encoding=%v, num_chans=%v", + msg.Complete, msg.EncodingType, len(msg.ShortChanIDs)) + + case *lnwire.QueryShortChanIDs: + return fmt.Sprintf("chain_hash=%v, encoding=%v, num_chans=%v", + msg.ChainHash, msg.EncodingType, len(msg.ShortChanIDs)) + + case *lnwire.QueryChannelRange: + return fmt.Sprintf("chain_hash=%v, start_height=%v, "+ + "num_blocks=%v", msg.ChainHash, msg.FirstBlockHeight, + msg.NumBlocks) + + case *lnwire.GossipTimestampRange: + return fmt.Sprintf("chain_hash=%v, first_stamp=%v, "+ + "stamp_range=%v", msg.ChainHash, + time.Unix(int64(msg.FirstTimestamp), 0), + msg.TimestampRange) + } return "" @@ -1213,7 +1241,6 @@ func (p *peer) writeMessage(msg lnwire.Message) error { return ErrPeerExiting } - // TODO(roasbeef): add message summaries p.logWireMessage(msg, false) // We'll re-slice of static write buffer to allow this new message to diff --git a/server.go b/server.go index 09f7b90a..49027332 100644 --- a/server.go +++ b/server.go @@ -390,6 +390,7 @@ func newServer(listenAddrs []string, chanDB *channeldb.DB, cc *chainControl, Notifier: s.cc.chainNotifier, ChainHash: *activeNetParams.GenesisHash, Broadcast: s.BroadcastMessage, + ChanSeries: &chanSeries{s.chanDB.ChannelGraph()}, SendToPeer: s.SendToPeer, NotifyWhenOnline: s.NotifyWhenOnline, ProofMatureDelta: 0, @@ -1304,6 +1305,10 @@ func (s *server) peerTerminationWatcher(p *peer) { // available for use. s.fundingMgr.CancelPeerReservations(p.PubKey()) + // We'll also inform the gossiper that this peer is no longer active, + // so we don't need to maintain sync state for it any longer. + s.authGossiper.PruneSyncState(p.addr.IdentityKey) + // Tell the switch to remove all links associated with this peer. // Passing nil as the target link indicates that all links associated // with this interface should be closed. @@ -1465,9 +1470,16 @@ func (s *server) peerConnected(conn net.Conn, connReq *connmgr.ConnReq, // feature vector to advertise to the remote node. localFeatures := lnwire.NewRawFeatureVector() - // We'll only request a full channel graph sync if we detect that + // We'll signal that we understand the data loss protection feature, + // and also that we support the new gossip query features. + localFeatures.Set(lnwire.DataLossProtectOptional) + localFeatures.Set(lnwire.GossipQueriesOptional) + + // We'll only request a full channel graph sync if we detect that that // we aren't fully synced yet. if s.shouldRequestGraphSync() { + // TODO(roasbeef): only do so if gossiper doesn't have active + // peers? localFeatures.Set(lnwire.InitialRoutingSync) } @@ -1779,10 +1791,30 @@ func (s *server) addPeer(p *peer) { s.wg.Add(1) go s.peerTerminationWatcher(p) + switch { + // If the remote peer knows of the new gossip queries feature, then + // we'll create a new gossipSyncer in the AuthenticatedGossiper for it. + case p.remoteLocalFeatures.HasFeature(lnwire.GossipQueriesOptional): + srvrLog.Infof("Negotiated chan series queries with %x", + p.pubKeyBytes[:]) + + // We'll only request channel updates from the remote peer if + // its enabled in the config, or we're already getting updates + // from enough peers. + // + // TODO(roasbeef): craft s.t. we only get updates from a few + // peers + recvUpdates := !cfg.NoChanUpdates + go s.authGossiper.InitSyncState(p.addr.IdentityKey, recvUpdates) + // If the remote peer has the initial sync feature bit set, then we'll // being the synchronization protocol to exchange authenticated channel - // graph edges/vertexes - if p.remoteLocalFeatures.HasFeature(lnwire.InitialRoutingSync) { + // graph edges/vertexes, but only if they don't know of the new gossip + // queries. + case p.remoteLocalFeatures.HasFeature(lnwire.InitialRoutingSync): + srvrLog.Infof("Requesting full table sync with %x", + p.pubKeyBytes[:]) + go s.authGossiper.SynchronizeNode(p.addr.IdentityKey) }