diff --git a/routing/missioncontrol.go b/routing/missioncontrol.go index 44948924..418fa0a9 100644 --- a/routing/missioncontrol.go +++ b/routing/missioncontrol.go @@ -47,7 +47,7 @@ type missionControl struct { // it was added to the prune view. Edges are added to this map if a // caller reports to missionControl a failure localized to that edge // when sending a payment. - failedEdges map[uint64]time.Time + failedEdges map[edgeLocator]time.Time // failedVertexes maps a node's public key that should be pruned, to // the time that it was added to the prune view. Vertexes are added to @@ -76,7 +76,7 @@ func newMissionControl(g *channeldb.ChannelGraph, selfNode *channeldb.LightningN qb func(*channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi) *missionControl { return &missionControl{ - failedEdges: make(map[uint64]time.Time), + failedEdges: make(map[edgeLocator]time.Time), failedVertexes: make(map[Vertex]time.Time), selfNode: selfNode, queryBandwidth: qb, @@ -90,7 +90,7 @@ func newMissionControl(g *channeldb.ChannelGraph, selfNode *channeldb.LightningN // state of the wider network from the PoV of mission control compiled via HTLC // routing attempts in the past. type graphPruneView struct { - edges map[uint64]struct{} + edges map[edgeLocator]struct{} vertexes map[Vertex]struct{} } @@ -125,7 +125,7 @@ func (m *missionControl) GraphPruneView() graphPruneView { // We'll also do the same for edges, but use the edgeDecay this time // rather than the decay for vertexes. - edges := make(map[uint64]struct{}) + edges := make(map[edgeLocator]struct{}) for edge, pruneTime := range m.failedEdges { if now.Sub(pruneTime) >= edgeDecay { log.Tracef("Pruning decayed failure report for edge %v "+ @@ -164,11 +164,11 @@ type paymentSession struct { bandwidthHints map[uint64]lnwire.MilliSatoshi - // errFailedFeeChans is a map of the short channel ID's that were the + // errFailedFeeChans is a map of the short channel IDs that were the // source of policy related routing failures during this payment attempt. // We'll use this map to prune out channels when the first error may not // require pruning, but any subsequent ones do. - errFailedPolicyChans map[uint64]struct{} + errFailedPolicyChans map[edgeLocator]struct{} mc *missionControl @@ -245,7 +245,7 @@ func (m *missionControl) NewPaymentSession(routeHints [][]HopHint, pruneViewSnapshot: viewSnapshot, additionalEdges: edges, bandwidthHints: bandwidthHints, - errFailedPolicyChans: make(map[uint64]struct{}), + errFailedPolicyChans: make(map[edgeLocator]struct{}), mc: m, }, nil } @@ -259,7 +259,7 @@ func (m *missionControl) NewPaymentSessionFromRoutes(routes []*Route) *paymentSe pruneViewSnapshot: m.GraphPruneView(), haveRoutes: true, preBuiltRoutes: routes, - errFailedPolicyChans: make(map[uint64]struct{}), + errFailedPolicyChans: make(map[edgeLocator]struct{}), mc: m, } } @@ -325,17 +325,17 @@ func (p *paymentSession) ReportVertexFailure(v Vertex) { // retrying an edge after its pruning has expired. // // TODO(roasbeef): also add value attempted to send and capacity of channel -func (p *paymentSession) ReportChannelFailure(e uint64) { +func (p *paymentSession) ReportEdgeFailure(e *edgeLocator) { log.Debugf("Reporting edge %v failure to Mission Control", e) // First, we'll add the failed edge to our local prune view snapshot. - p.pruneViewSnapshot.edges[e] = struct{}{} + p.pruneViewSnapshot.edges[*e] = struct{}{} // With the edge added, we'll now report back to the global prune view, // with this new piece of information so it can be utilized for new // payment sessions. p.mc.Lock() - p.mc.failedEdges[e] = time.Now() + p.mc.failedEdges[*e] = time.Now() p.mc.Unlock() } @@ -345,12 +345,12 @@ func (p *paymentSession) ReportChannelFailure(e uint64) { // edge as 'policy failed once'. The next time it fails, the whole node will be // pruned. This is to prevent nodes from keeping us busy by continuously sending // new channel updates. -func (p *paymentSession) ReportChannelPolicyFailure( - errSource Vertex, failedChanID uint64) { +func (p *paymentSession) ReportEdgePolicyFailure( + errSource Vertex, failedEdge *edgeLocator) { // Check to see if we've already reported a policy related failure for // this channel. If so, then we'll prune out the vertex. - _, ok := p.errFailedPolicyChans[failedChanID] + _, ok := p.errFailedPolicyChans[*failedEdge] if ok { // TODO(joostjager): is this aggresive pruning still necessary? // Just pruning edges may also work unless there is a huge @@ -361,7 +361,7 @@ func (p *paymentSession) ReportChannelPolicyFailure( } // Finally, we'll record a policy failure from this node and move on. - p.errFailedPolicyChans[failedChanID] = struct{}{} + p.errFailedPolicyChans[*failedEdge] = struct{}{} } // RequestRoute returns a route which is likely to be capable for successfully @@ -442,7 +442,7 @@ func (p *paymentSession) RequestRoute(payment *LightningPayment, // if no payment attempts have been made. func (m *missionControl) ResetHistory() { m.Lock() - m.failedEdges = make(map[uint64]time.Time) + m.failedEdges = make(map[edgeLocator]time.Time) m.failedVertexes = make(map[Vertex]time.Time) m.Unlock() } diff --git a/routing/pathfind.go b/routing/pathfind.go index 8bedfe3c..f2e0b134 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -440,7 +440,7 @@ type restrictParams struct { // ignoredEdges is an optional set of edges that should be ignored if // encountered during path finding. - ignoredEdges map[uint64]struct{} + ignoredEdges map[edgeLocator]struct{} // feeLimit is a maximum fee amount allowed to be used on the path from // the source to the target. @@ -567,7 +567,9 @@ func findPath(g *graphParams, r *restrictParams, if _, ok := r.ignoredNodes[fromVertex]; ok { return } - if _, ok := r.ignoredEdges[edge.ChannelID]; ok { + + locator := newEdgeLocator(edge) + if _, ok := r.ignoredEdges[*locator]; ok { return } @@ -795,7 +797,7 @@ func findPaths(tx *bbolt.Tx, graph *channeldb.ChannelGraph, amt lnwire.MilliSatoshi, feeLimit lnwire.MilliSatoshi, numPaths uint32, bandwidthHints map[uint64]lnwire.MilliSatoshi) ([][]*channeldb.ChannelEdgePolicy, error) { - ignoredEdges := make(map[uint64]struct{}) + ignoredEdges := make(map[edgeLocator]struct{}) ignoredVertexes := make(map[Vertex]struct{}) // TODO(roasbeef): modifying ordering within heap to eliminate final @@ -850,7 +852,7 @@ func findPaths(tx *bbolt.Tx, graph *channeldb.ChannelGraph, // we'll exclude from the next path finding attempt. // These are required to ensure the paths are unique // and loopless. - ignoredEdges = make(map[uint64]struct{}) + ignoredEdges = make(map[edgeLocator]struct{}) ignoredVertexes = make(map[Vertex]struct{}) // Our spur node is the i-th node in the prior shortest @@ -868,8 +870,11 @@ func findPaths(tx *bbolt.Tx, graph *channeldb.ChannelGraph, // shortest path, then we'll remove the edge // directly _after_ our spur node from the // graph so we don't repeat paths. - if len(path) > i+1 && isSamePath(rootPath, path[:i+1]) { - ignoredEdges[path[i+1].ChannelID] = struct{}{} + if len(path) > i+1 && + isSamePath(rootPath, path[:i+1]) { + + locator := newEdgeLocator(path[i+1]) + ignoredEdges[*locator] = struct{}{} } } diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index df41825e..383e0c2f 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -575,7 +575,7 @@ func TestFindLowestFeePath(t *testing.T) { } sourceVertex := Vertex(sourceNode.PubKeyBytes) - ignoredEdges := make(map[uint64]struct{}) + ignoredEdges := make(map[edgeLocator]struct{}) ignoredVertexes := make(map[Vertex]struct{}) const ( @@ -721,7 +721,7 @@ func testBasicGraphPathFindingCase(t *testing.T, graphInstance *testGraphInstanc } sourceVertex := Vertex(sourceNode.PubKeyBytes) - ignoredEdges := make(map[uint64]struct{}) + ignoredEdges := make(map[edgeLocator]struct{}) ignoredVertexes := make(map[Vertex]struct{}) const ( @@ -1255,7 +1255,7 @@ func TestNewRoutePathTooLong(t *testing.T) { t.Fatalf("unable to fetch source node: %v", err) } - ignoredEdges := make(map[uint64]struct{}) + ignoredEdges := make(map[edgeLocator]struct{}) ignoredVertexes := make(map[Vertex]struct{}) paymentAmt := lnwire.NewMSatFromSatoshis(100) @@ -1314,7 +1314,7 @@ func TestPathNotAvailable(t *testing.T) { t.Fatalf("unable to fetch source node: %v", err) } - ignoredEdges := make(map[uint64]struct{}) + ignoredEdges := make(map[edgeLocator]struct{}) ignoredVertexes := make(map[Vertex]struct{}) // With the test graph loaded, we'll test that queries for target that @@ -1359,7 +1359,7 @@ func TestPathInsufficientCapacity(t *testing.T) { if err != nil { t.Fatalf("unable to fetch source node: %v", err) } - ignoredEdges := make(map[uint64]struct{}) + ignoredEdges := make(map[edgeLocator]struct{}) ignoredVertexes := make(map[Vertex]struct{}) // Next, test that attempting to find a path in which the current @@ -1404,7 +1404,7 @@ func TestRouteFailMinHTLC(t *testing.T) { if err != nil { t.Fatalf("unable to fetch source node: %v", err) } - ignoredEdges := make(map[uint64]struct{}) + ignoredEdges := make(map[edgeLocator]struct{}) ignoredVertexes := make(map[Vertex]struct{}) // We'll not attempt to route an HTLC of 10 SAT from roasbeef to Son @@ -1446,7 +1446,7 @@ func TestRouteFailDisabledEdge(t *testing.T) { if err != nil { t.Fatalf("unable to fetch source node: %v", err) } - ignoredEdges := make(map[uint64]struct{}) + ignoredEdges := make(map[edgeLocator]struct{}) ignoredVertexes := make(map[Vertex]struct{}) // First, we'll try to route from roasbeef -> sophon. This should @@ -1546,7 +1546,7 @@ func TestPathSourceEdgesBandwidth(t *testing.T) { if err != nil { t.Fatalf("unable to fetch source node: %v", err) } - ignoredEdges := make(map[uint64]struct{}) + ignoredEdges := make(map[edgeLocator]struct{}) ignoredVertexes := make(map[Vertex]struct{}) // First, we'll try to route from roasbeef -> sophon. This should diff --git a/routing/router.go b/routing/router.go index eb0ec879..0e1a9b26 100644 --- a/routing/router.go +++ b/routing/router.go @@ -213,6 +213,45 @@ func newRouteTuple(amt lnwire.MilliSatoshi, dest []byte) routeTuple { return r } +// edgeLocator is a struct used to identify a specific edge. The direction +// fields takes the value of 0 or 1 and is identical in definition to the +// channel direction flag. A value of 0 means the direction from the lower node +// pubkey to the higher. +type edgeLocator struct { + channelID uint64 + direction uint8 +} + +// newEdgeLocatorByPubkeys returns an edgeLocator based on its end point +// pubkeys. +func newEdgeLocatorByPubkeys(channelID uint64, fromNode, toNode *Vertex) *edgeLocator { + // Determine direction based on lexicographical ordering of both + // pubkeys. + var direction uint8 + if bytes.Compare(fromNode[:], toNode[:]) == 1 { + direction = 1 + } + + return &edgeLocator{ + channelID: channelID, + direction: direction, + } +} + +// newEdgeLocator extracts an edgeLocator based for a full edge policy +// structure. +func newEdgeLocator(edge *channeldb.ChannelEdgePolicy) *edgeLocator { + return &edgeLocator{ + channelID: edge.ChannelID, + direction: uint8(edge.Flags & lnwire.ChanUpdateDirection), + } +} + +// String returns a human readable version of the edgeLocator values. +func (e *edgeLocator) String() string { + return fmt.Sprintf("%v:%v", e.channelID, e.direction) +} + // ChannelRouter is the layer 3 router within the Lightning stack. Below the // ChannelRouter is the HtlcSwitch, and below that is the Bitcoin blockchain // itself. The primary role of the ChannelRouter is to respond to queries for @@ -1755,14 +1794,11 @@ func (r *ChannelRouter) sendPayment(payment *LightningPayment, errVertex := NewVertex(errSource) log.Tracef("node=%x reported failure when sending "+ - "htlc=%x", errSource.SerializeCompressed(), - payment.PaymentHash[:]) + "htlc=%x", errVertex, payment.PaymentHash[:]) // Always determine chan id ourselves, because a channel // update with id may not be available. - failedChanID, err := getFailedChannelID( - route, errVertex, - ) + failedEdge, err := getFailedEdge(route, errVertex) if err != nil { return preImage, nil, err } @@ -1793,13 +1829,13 @@ func (r *ChannelRouter) sendPayment(payment *LightningPayment, // Or is there a valid reason for the channel // update to fail? if !updateOk { - paySession.ReportChannelFailure( - failedChanID, + paySession.ReportEdgeFailure( + failedEdge, ) } - paySession.ReportChannelPolicyFailure( - NewVertex(errSource), failedChanID, + paySession.ReportEdgePolicyFailure( + NewVertex(errSource), failedEdge, ) } @@ -1889,7 +1925,7 @@ func (r *ChannelRouter) sendPayment(payment *LightningPayment, // the update and continue. case *lnwire.FailChannelDisabled: r.applyChannelUpdate(&onionErr.Update, errSource) - paySession.ReportChannelFailure(failedChanID) + paySession.ReportEdgeFailure(failedEdge) continue // It's likely that the outgoing channel didn't have @@ -1897,7 +1933,7 @@ func (r *ChannelRouter) sendPayment(payment *LightningPayment, // now, and continue onwards with our path finding. case *lnwire.FailTemporaryChannelFailure: r.applyChannelUpdate(onionErr.Update, errSource) - paySession.ReportChannelFailure(failedChanID) + paySession.ReportEdgeFailure(failedEdge) continue // If the send fail due to a node not having the @@ -1922,7 +1958,7 @@ func (r *ChannelRouter) sendPayment(payment *LightningPayment, // returning errors in order to attempt to black list // another node. case *lnwire.FailUnknownNextPeer: - paySession.ReportChannelFailure(failedChanID) + paySession.ReportEdgeFailure(failedEdge) continue // If the node wasn't able to forward for which ever @@ -1950,10 +1986,17 @@ func (r *ChannelRouter) sendPayment(payment *LightningPayment, continue // If we get a permanent channel or node failure, then - // we'll note this (exclude the vertex/edge), and + // we'll prune the channel in both directions and // continue with the rest of the routes. case *lnwire.FailPermanentChannelFailure: - paySession.ReportChannelFailure(failedChanID) + paySession.ReportEdgeFailure(&edgeLocator{ + channelID: failedEdge.channelID, + direction: 0, + }) + paySession.ReportEdgeFailure(&edgeLocator{ + channelID: failedEdge.channelID, + direction: 1, + }) continue default: @@ -1965,43 +2008,43 @@ func (r *ChannelRouter) sendPayment(payment *LightningPayment, } } -// getFailedChannelID tries to locate the failing channel given a route and the +// getFailedEdge tries to locate the failing channel given a route and the // pubkey of the node that sent the error. It will assume that the error is // associated with the outgoing channel of the error node. -func getFailedChannelID(route *Route, errSource Vertex) ( - uint64, error) { - - // If the error originates from ourselves, report our outgoing channel - // as failing. - if errSource == route.SourcePubKey { - return route.Hops[0].ChannelID, nil - } +func getFailedEdge(route *Route, errSource Vertex) ( + *edgeLocator, error) { hopCount := len(route.Hops) + fromNode := route.SourcePubKey for i, hop := range route.Hops { - if errSource != hop.PubKeyBytes { - continue - } + toNode := hop.PubKeyBytes - // If the errSource is the final hop, we assume that the - // failing channel is the incoming channel. + // Determine if we have a failure from the final hop. // - // TODO(joostjager): In this case, certain types of - // errors are not expected. For example - // FailUnknownNextPeer. This could be a reason to prune - // the node? - if i == hopCount-1 { - return route.Hops[i].ChannelID, nil + // TODO(joostjager): In this case, certain types of errors are + // not expected. For example FailUnknownNextPeer. This could be + // a reason to prune the node? + finalHopFailing := i == hopCount-1 && errSource == toNode + + // As this error indicates that the target channel was unable to + // carry this HTLC (for w/e reason), we'll return the _outgoing_ + // channel that the source of the error was meant to pass the + // HTLC along to. + // + // If the errSource is the final hop, we assume that the failing + // channel is the incoming channel. + if errSource == fromNode || finalHopFailing { + return newEdgeLocatorByPubkeys( + hop.ChannelID, + &fromNode, + &toNode, + ), nil } - // As this error indicates that the target channel was - // unable to carry this HTLC (for w/e reason), we'll - // query return the _outgoing_ channel that the source - // of the error was meant to pass the HTLC along to. - return route.Hops[i+1].ChannelID, nil + fromNode = toNode } - return 0, fmt.Errorf("cannot find error source node in route") + return nil, fmt.Errorf("cannot find error source node in route") } // applyChannelUpdate validates a channel update and if valid, applies it to the diff --git a/routing/router_test.go b/routing/router_test.go index c4b8a0a4..c184c3ce 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -1859,7 +1859,7 @@ func TestFindPathFeeWeighting(t *testing.T) { } ignoreVertex := make(map[Vertex]struct{}) - ignoreEdge := make(map[uint64]struct{}) + ignoreEdge := make(map[edgeLocator]struct{}) amt := lnwire.MilliSatoshi(100)