diff --git a/lnwire/msat.go b/lnwire/msat.go index 67ee4972..d3789dfa 100644 --- a/lnwire/msat.go +++ b/lnwire/msat.go @@ -6,9 +6,15 @@ import ( "github.com/btcsuite/btcutil" ) -// mSatScale is a value that's used to scale satoshis to milli-satoshis, and -// the other way around. -const mSatScale uint64 = 1000 +const ( + // mSatScale is a value that's used to scale satoshis to milli-satoshis, and + // the other way around. + mSatScale uint64 = 1000 + + // MaxMilliSatoshi is the maximum number of msats that can be expressed + // in this data type. + MaxMilliSatoshi = ^MilliSatoshi(0) +) // MilliSatoshi are the native unit of the Lightning Network. A milli-satoshi // is simply 1/1000th of a satoshi. There are 1000 milli-satoshis in a single diff --git a/routing/pathfind.go b/routing/pathfind.go index 743adad1..c3e335ce 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -392,32 +392,11 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // processEdge is a helper closure that will be used to make sure edges // satisfy our specific requirements. - processEdge := func(fromVertex route.Vertex, bandwidth lnwire.MilliSatoshi, + processEdge := func(fromVertex route.Vertex, edge *channeldb.ChannelEdgePolicy, toNodeDist *nodeWithDist) { edgesExpanded++ - // If this is not a local channel and it is disabled, we will - // skip it. - // TODO(halseth): also ignore disable flags for non-local - // channels if bandwidth hint is set? - isSourceChan := fromVertex == source - - edgeFlags := edge.ChannelFlags - isDisabled := edgeFlags&lnwire.ChanUpdateDisabled != 0 - - if !isSourceChan && isDisabled { - return - } - - // If we have an outgoing channel restriction and this is not - // the specified channel, skip it. - if isSourceChan && r.OutgoingChannelID != nil && - *r.OutgoingChannelID != edge.ChannelID { - - return - } - // Calculate amount that the candidate node would have to sent // out. amountToSend := toNodeDist.amountToReceive @@ -438,25 +417,6 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, return } - // If the estimated bandwidth of the channel edge is not able - // to carry the amount that needs to be send, return. - if bandwidth < amountToSend { - return - } - - // If the amountToSend is less than the minimum required - // amount, return. - if amountToSend < edge.MinHTLC { - return - } - - // If this edge was constructed from a hop hint, we won't have access to - // its max HTLC. Therefore, only consider discarding this edge here if - // the field is set. - if edge.MaxHTLC != 0 && edge.MaxHTLC < amountToSend { - return - } - // Compute fee that fromVertex is charging. It is based on the // amount that needs to be sent to the next node in the route. // @@ -585,67 +545,34 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, break } - cb := func(_ *bbolt.Tx, edgeInfo *channeldb.ChannelEdgeInfo, _, - inEdge *channeldb.ChannelEdgePolicy) error { + // Create unified policies for all incoming connections. + u := newUnifiedPolicies(source, pivot, r.OutgoingChannelID) - // If there is no edge policy for this candidate - // node, skip. Note that we are searching backwards - // so this node would have come prior to the pivot - // node in the route. - if inEdge == nil { - return nil - } - - // We'll query the lower layer to see if we can obtain - // any more up to date information concerning the - // bandwidth of this edge. - edgeBandwidth, ok := g.bandwidthHints[edgeInfo.ChannelID] - if !ok { - // If we don't have a hint for this edge, then - // we'll just use the known Capacity/MaxHTLC as - // the available bandwidth. It's possible for - // the capacity to be unknown when operating - // under a light client. - edgeBandwidth = inEdge.MaxHTLC - if edgeBandwidth == 0 { - edgeBandwidth = lnwire.NewMSatFromSatoshis( - edgeInfo.Capacity, - ) - } - } - - // Before we can process the edge, we'll need to fetch - // the node on the _other_ end of this channel as we - // may later need to iterate over the incoming edges of - // this node if we explore it further. - chanSource, err := edgeInfo.OtherNodeKeyBytes(pivot[:]) - if err != nil { - return err - } - - // Check if this candidate node is better than what we - // already have. - processEdge(chanSource, edgeBandwidth, inEdge, partialPath) - return nil - } - - // Now that we've found the next potential step to take we'll - // examine all the incoming edges (channels) from this node to - // further our graph traversal. - err := g.graph.ForEachNodeChannel(tx, pivot[:], cb) + err := u.addGraphPolicies(g.graph, tx) if err != nil { return nil, err } - // Then, we'll examine all the additional edges from the node - // we're currently visiting. Since we don't know the capacity - // of the private channel, we'll assume it was selected as a - // routing hint due to having enough capacity for the payment - // and use the payment amount as its capacity. - bandWidth := partialPath.amountToReceive for _, reverseEdge := range additionalEdgesWithSrc[pivot] { - processEdge(reverseEdge.sourceNode, bandWidth, - reverseEdge.edge, partialPath) + u.addPolicy(reverseEdge.sourceNode, reverseEdge.edge, 0) + } + + amtToSend := partialPath.amountToReceive + + // Expand all connections using the optimal policy for each + // connection. + for fromNode, unifiedPolicy := range u.policies { + policy := unifiedPolicy.getPolicy( + amtToSend, g.bandwidthHints, + ) + + if policy == nil { + continue + } + + // Check if this candidate node is better than what we + // already have. + processEdge(fromNode, policy, partialPath) } } diff --git a/routing/router.go b/routing/router.go index 3908a93d..ceacd351 100644 --- a/routing/router.go +++ b/routing/router.go @@ -2254,90 +2254,6 @@ func generateBandwidthHints(sourceNode *channeldb.LightningNode, return bandwidthHints, nil } -// runningAmounts keeps running amounts while the route is traversed. -type runningAmounts struct { - // amt is the intended amount to send via the route. - amt lnwire.MilliSatoshi - - // max is the running maximum that the route can carry. - max lnwire.MilliSatoshi -} - -// prependChannel returns a new set of running amounts that would result from -// prepending the given channel to the route. If canIncreaseAmt is set, the -// amount may be increased if it is too small to satisfy the channel's minimum -// htlc amount. -func (r *runningAmounts) prependChannel(policy *channeldb.ChannelEdgePolicy, - capacity btcutil.Amount, localChan bool, canIncreaseAmt bool) ( - runningAmounts, error) { - - // Determine max htlc value. - maxHtlc := lnwire.NewMSatFromSatoshis(capacity) - if policy.MessageFlags.HasMaxHtlc() { - maxHtlc = policy.MaxHTLC - } - - amt := r.amt - - // If we have a specific amount for which we are building the route, - // validate it against the channel constraints and return the new - // running amount. - if !canIncreaseAmt { - if amt < policy.MinHTLC || amt > maxHtlc { - return runningAmounts{}, fmt.Errorf("channel htlc "+ - "constraints [%v - %v] violated with amt %v", - policy.MinHTLC, maxHtlc, amt) - } - - // Update running amount by adding the fee for non-local - // channels. - if !localChan { - amt += policy.ComputeFee(amt) - } - - return runningAmounts{ - amt: amt, - }, nil - } - - // Adapt the minimum amount to what this channel allows. - if policy.MinHTLC > r.amt { - amt = policy.MinHTLC - } - - // Update the maximum amount too to be able to detect incompatible - // channels. - max := r.max - if maxHtlc < r.max { - max = maxHtlc - } - - // If we get in the situation that the minimum amount exceeds the - // maximum amount (enforced further down stream), we have incompatible - // channel policies. - // - // There is possibility with pubkey addressing that we should have - // selected a different channel downstream, but we don't backtrack to - // try to fix that. It would complicate path finding while we expect - // this situation to be rare. The spec recommends to keep all policies - // towards a peer identical. If that is the case, there isn't a better - // channel that we should have selected. - if amt > max { - return runningAmounts{}, - fmt.Errorf("incompatible channel policies: %v "+ - "exceeds %v", amt, max) - } - - // Add fees to the running amounts. Skip the source node fees as - // those do not need to be paid. - if !localChan { - amt += policy.ComputeFee(amt) - max += policy.ComputeFee(max) - } - - return runningAmounts{amt: amt, max: max}, nil -} - // ErrNoChannel is returned when a route cannot be built because there are no // channels that satisfy all requirements. type ErrNoChannel struct { @@ -2374,24 +2290,21 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi, return nil, err } - // Allocate a list that will contain the selected channels for this + // Allocate a list that will contain the unified policies for this // route. - edges := make([]*channeldb.ChannelEdgePolicy, len(hops)) + edges := make([]*unifiedPolicy, len(hops)) - // Keep a running amount and the maximum for this route. - amts := runningAmounts{ - max: lnwire.MilliSatoshi(^uint64(0)), - } + var runningAmt lnwire.MilliSatoshi if useMinAmt { // For minimum amount routes, aim to deliver at least 1 msat to // the destination. There are nodes in the wild that have a // min_htlc channel policy of zero, which could lead to a zero // amount payment being made. - amts.amt = 1 + runningAmt = 1 } else { // If an amount is specified, we need to build a route that // delivers exactly this amount to the final destination. - amts.amt = *amt + runningAmt = *amt } // Traverse hops backwards to accumulate fees in the running amounts. @@ -2408,142 +2321,85 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi, localChan := i == 0 - // Iterate over candidate channels to select the channel - // to use for the final route. - var ( - bestEdge *channeldb.ChannelEdgePolicy - bestAmts *runningAmounts - bestBandwidth lnwire.MilliSatoshi - ) + // Build unified policies for this hop based on the channels + // known in the graph. + u := newUnifiedPolicies(source, toNode, outgoingChan) - cb := func(tx *bbolt.Tx, - edgeInfo *channeldb.ChannelEdgeInfo, - _, inEdge *channeldb.ChannelEdgePolicy) error { - - chanID := edgeInfo.ChannelID - - // Apply outgoing channel restriction is active. - if localChan && outgoingChan != nil && - chanID != *outgoingChan { - - return nil - } - - // No unknown policy channels. - if inEdge == nil { - return nil - } - - // Before we can process the edge, we'll need to - // fetch the node on the _other_ end of this - // channel as we may later need to iterate over - // the incoming edges of this node if we explore - // it further. - chanFromNode, err := edgeInfo.FetchOtherNode( - tx, toNode[:], - ) - if err != nil { - return err - } - - // Continue searching if this channel doesn't - // connect with the previous hop. - if chanFromNode.PubKeyBytes != fromNode { - return nil - } - - // Validate whether this channel's policy is satisfied - // and obtain the new running amounts if this channel - // was to be selected. - newAmts, err := amts.prependChannel( - inEdge, edgeInfo.Capacity, localChan, - useMinAmt, - ) - if err != nil { - log.Tracef("Skipping chan %v: %v", - inEdge.ChannelID, err) - - return nil - } - - // If we already have a best edge, check whether this - // edge is better. - bandwidth := bandwidthHints[chanID] - if bestEdge != nil { - if localChan { - // For local channels, better is defined - // as having more bandwidth. We try to - // maximize the chance that the returned - // route succeeds. - if bandwidth < bestBandwidth { - return nil - } - } else { - // For other channels, better is defined - // as lower fees for the amount to send. - // Normally all channels between two - // nodes should have the same policy, - // but in case not we minimize our cost - // here. Regular path finding would do - // the same. - if newAmts.amt > bestAmts.amt { - return nil - } - } - } - - // If we get here, the current edge is better. Replace - // the best. - bestEdge = inEdge - bestAmts = &newAmts - bestBandwidth = bandwidth - - return nil - } - - err := r.cfg.Graph.ForEachNodeChannel(nil, toNode[:], cb) + err := u.addGraphPolicies(r.cfg.Graph, nil) if err != nil { return nil, err } - // There is no matching channel. Stop building the route here. - if bestEdge == nil { + // Exit if there are no channels. + unifiedPolicy, ok := u.policies[fromNode] + if !ok { return nil, ErrNoChannel{ fromNode: fromNode, position: i, } } - log.Tracef("Select channel %v at position %v", bestEdge.ChannelID, i) + // If using min amt, increase amt if needed. + if useMinAmt { + min := unifiedPolicy.minAmt() + if min > runningAmt { + runningAmt = min + } + } - edges[i] = bestEdge - amts = *bestAmts + // Get a forwarding policy for the specific amount that we want + // to forward. + policy := unifiedPolicy.getPolicy(runningAmt, bandwidthHints) + if policy == nil { + return nil, ErrNoChannel{ + fromNode: fromNode, + position: i, + } + } + + // Add fee for this hop. + if !localChan { + runningAmt += policy.ComputeFee(runningAmt) + } + + log.Tracef("Select channel %v at position %v", policy.ChannelID, i) + + edges[i] = unifiedPolicy } + // Now that we arrived at the start of the route and found out the route + // total amount, we make a forward pass. Because the amount may have + // been increased in the backward pass, fees need to be recalculated and + // amount ranges re-checked. + var pathEdges []*channeldb.ChannelEdgePolicy + receiverAmt := runningAmt + for i, edge := range edges { + policy := edge.getPolicy(receiverAmt, bandwidthHints) + if policy == nil { + return nil, ErrNoChannel{ + fromNode: hops[i-1], + position: i, + } + } + + if i > 0 { + // Decrease the amount to send while going forward. + receiverAmt -= policy.ComputeFeeFromIncoming( + receiverAmt, + ) + } + + pathEdges = append(pathEdges, policy) + } + + // Build and return the final route. _, height, err := r.cfg.Chain.GetBestBlock() if err != nil { return nil, err } - var receiverAmt lnwire.MilliSatoshi - if useMinAmt { - // We've calculated the minimum amount for the htlc that the - // source node hands out. The newRoute call below expects the - // amount that must reach the receiver after subtraction of fees - // along the way. Iterate over all edges to calculate the - // receiver amount. - receiverAmt = amts.amt - for _, edge := range edges[1:] { - receiverAmt -= edge.ComputeFeeFromIncoming(receiverAmt) - } - } else { - // Deliver the specified amount to the receiver. - receiverAmt = *amt - } - - // Build and return the final route. return newRoute( - receiverAmt, source, edges, uint32(height), + receiverAmt, source, pathEdges, uint32(height), uint16(finalCltvDelta), nil, ) } diff --git a/routing/router_test.go b/routing/router_test.go index b2b9e79c..9ea9737c 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -3410,6 +3410,8 @@ func TestBuildRoute(t *testing.T) { defer cleanUp() checkHops := func(rt *route.Route, expected []uint64) { + t.Helper() + if len(rt.Hops) != len(expected) { t.Fatal("hop count mismatch") } @@ -3437,10 +3439,10 @@ func TestBuildRoute(t *testing.T) { } // Check that we get the expected route back. The total amount should be - // the amount to deliver to hop c (100 sats) plus the fee for hop b (5 - // sats). - checkHops(rt, []uint64{1, 2}) - if rt.TotalAmount != 105000 { + // the amount to deliver to hop c (100 sats) plus the max fee for the + // connection b->c (6 sats). + checkHops(rt, []uint64{1, 7}) + if rt.TotalAmount != 106000 { t.Fatalf("unexpected total amount %v", rt.TotalAmount) } @@ -3453,11 +3455,11 @@ func TestBuildRoute(t *testing.T) { } // Check that we get the expected route back. The minimum that we can - // send from b to c is 20 sats. Hop b charges 1 sat for the forwarding. - // The channel between hop a and b can carry amounts in the range [5, - // 100], so 21 sats is the minimum amount for this route. - checkHops(rt, []uint64{1, 2}) - if rt.TotalAmount != 21000 { + // send from b to c is 20 sats. Hop b charges 1200 msat for the + // forwarding. The channel between hop a and b can carry amounts in the + // range [5, 100], so 21200 msats is the minimum amount for this route. + checkHops(rt, []uint64{1, 7}) + if rt.TotalAmount != 21200 { t.Fatalf("unexpected total amount %v", rt.TotalAmount) } diff --git a/routing/unified_policies.go b/routing/unified_policies.go new file mode 100644 index 00000000..81e646c2 --- /dev/null +++ b/routing/unified_policies.go @@ -0,0 +1,281 @@ +package routing + +import ( + "github.com/btcsuite/btcutil" + "github.com/coreos/bbolt" + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" +) + +// unifiedPolicies holds all unified policies for connections towards a node. +type unifiedPolicies struct { + // policies contains a unified policy for every from node. + policies map[route.Vertex]*unifiedPolicy + + // sourceNode is the sender of a payment. The rules to pick the final + // policy are different for local channels. + sourceNode route.Vertex + + // toNode is the node for which the unified policies are instantiated. + toNode route.Vertex + + // outChanRestr is an optional outgoing channel restriction for the + // local channel to use. + outChanRestr *uint64 +} + +// newUnifiedPolicies instantiates a new unifiedPolicies object. Channel +// policies can be added to this object. +func newUnifiedPolicies(sourceNode, toNode route.Vertex, + outChanRestr *uint64) *unifiedPolicies { + + return &unifiedPolicies{ + policies: make(map[route.Vertex]*unifiedPolicy), + toNode: toNode, + sourceNode: sourceNode, + outChanRestr: outChanRestr, + } +} + +// addPolicy adds a single channel policy. Capacity may be zero if unknown +// (light clients). +func (u *unifiedPolicies) addPolicy(fromNode route.Vertex, + edge *channeldb.ChannelEdgePolicy, capacity btcutil.Amount) { + + localChan := fromNode == u.sourceNode + + // Skip channels if there is an outgoing channel restriction. + if localChan && u.outChanRestr != nil && + *u.outChanRestr != edge.ChannelID { + + return + } + + // Update the policies map. + policy, ok := u.policies[fromNode] + if !ok { + policy = &unifiedPolicy{ + localChan: localChan, + } + u.policies[fromNode] = policy + } + + policy.edges = append(policy.edges, &unifiedPolicyEdge{ + policy: edge, + capacity: capacity, + }) +} + +// addGraphPolicies adds all policies that are known for the toNode in the +// graph. +func (u *unifiedPolicies) addGraphPolicies(g *channeldb.ChannelGraph, + tx *bbolt.Tx) error { + + cb := func(_ *bbolt.Tx, edgeInfo *channeldb.ChannelEdgeInfo, _, + inEdge *channeldb.ChannelEdgePolicy) error { + + // If there is no edge policy for this candidate node, skip. + // Note that we are searching backwards so this node would have + // come prior to the pivot node in the route. + if inEdge == nil { + return nil + } + + // The node on the other end of this channel is the from node. + fromNode, err := edgeInfo.OtherNodeKeyBytes(u.toNode[:]) + if err != nil { + return err + } + + // Add this policy to the unified policies map. + u.addPolicy(fromNode, inEdge, edgeInfo.Capacity) + + return nil + } + + // Iterate over all channels of the to node. + return g.ForEachNodeChannel(tx, u.toNode[:], cb) +} + +// unifiedPolicyEdge is the individual channel data that is kept inside an +// unifiedPolicy object. +type unifiedPolicyEdge struct { + policy *channeldb.ChannelEdgePolicy + capacity btcutil.Amount +} + +// amtInRange checks whether an amount falls within the valid range for a +// channel. +func (u *unifiedPolicyEdge) amtInRange(amt lnwire.MilliSatoshi) bool { + // If the capacity is available (non-light clients), skip channels that + // are too small. + if u.capacity > 0 && + amt > lnwire.NewMSatFromSatoshis(u.capacity) { + + return false + } + + // Skip channels for which this htlc is too large. + if u.policy.MessageFlags.HasMaxHtlc() && + amt > u.policy.MaxHTLC { + + return false + } + + // Skip channels for which this htlc is too small. + if amt < u.policy.MinHTLC { + return false + } + + return true +} + +// unifiedPolicy is the unified policy that covers all channels between a pair +// of nodes. +type unifiedPolicy struct { + edges []*unifiedPolicyEdge + localChan bool +} + +// getPolicy returns the optimal policy to use for this connection given a +// specific amount to send. It differentiates between local and network +// channels. +func (u *unifiedPolicy) getPolicy(amt lnwire.MilliSatoshi, + bandwidthHints map[uint64]lnwire.MilliSatoshi) *channeldb.ChannelEdgePolicy { + + if u.localChan { + return u.getPolicyLocal(amt, bandwidthHints) + } + + return u.getPolicyNetwork(amt) +} + +// getPolicyLocal returns the optimal policy to use for this local connection +// given a specific amount to send. +func (u *unifiedPolicy) getPolicyLocal(amt lnwire.MilliSatoshi, + bandwidthHints map[uint64]lnwire.MilliSatoshi) *channeldb.ChannelEdgePolicy { + + var ( + bestPolicy *channeldb.ChannelEdgePolicy + maxBandwidth lnwire.MilliSatoshi + ) + + for _, edge := range u.edges { + // Check valid amount range for the channel. + if !edge.amtInRange(amt) { + continue + } + + // For local channels, there is no fee to pay or an extra time + // lock. We only consider the currently available bandwidth for + // channel selection. The disabled flag is ignored for local + // channels. + + // Retrieve bandwidth for this local channel. If not + // available, assume this channel has enough bandwidth. + // + // TODO(joostjager): Possibly change to skipping this + // channel. The bandwidth hint is expected to be + // available. + bandwidth, ok := bandwidthHints[edge.policy.ChannelID] + if !ok { + bandwidth = lnwire.MaxMilliSatoshi + } + + // Skip channels that can't carry the payment. + if amt > bandwidth { + continue + } + + // We pick the local channel with the highest available + // bandwidth, to maximize the success probability. It + // can be that the channel state changes between + // querying the bandwidth hints and sending out the + // htlc. + if bandwidth < maxBandwidth { + continue + } + maxBandwidth = bandwidth + + // Update best policy. + bestPolicy = edge.policy + } + + return bestPolicy +} + +// getPolicyNetwork returns the optimal policy to use for this connection given +// a specific amount to send. The goal is to return a policy that maximizes the +// probability of a successful forward in a non-strict forwarding context. +func (u *unifiedPolicy) getPolicyNetwork( + amt lnwire.MilliSatoshi) *channeldb.ChannelEdgePolicy { + + var ( + bestPolicy *channeldb.ChannelEdgePolicy + maxFee lnwire.MilliSatoshi + maxTimelock uint16 + ) + + for _, edge := range u.edges { + // Check valid amount range for the channel. + if !edge.amtInRange(amt) { + continue + } + + // For network channels, skip the disabled ones. + edgeFlags := edge.policy.ChannelFlags + isDisabled := edgeFlags&lnwire.ChanUpdateDisabled != 0 + if isDisabled { + continue + } + + // Track the maximum time lock of all channels that are + // candidate for non-strict forwarding at the routing node. + if edge.policy.TimeLockDelta > maxTimelock { + maxTimelock = edge.policy.TimeLockDelta + } + + // Use the policy that results in the highest fee for this + // specific amount. + fee := edge.policy.ComputeFee(amt) + if fee < maxFee { + continue + } + maxFee = fee + + bestPolicy = edge.policy + } + + // Return early if no channel matches. + if bestPolicy == nil { + return nil + } + + // We have already picked the highest fee that could be required for + // non-strict forwarding. To also cover the case where a lower fee + // channel requires a longer time lock, we modify the policy by setting + // the maximum encountered time lock. Note that this results in a + // synthetic policy that is not actually present on the routing node. + // + // The reason we do this, is that we try to maximize the chance that we + // get forwarded. Because we penalize pair-wise, there won't be a second + // chance for this node pair. But this is all only needed for nodes that + // have distinct policies for channels to the same peer. + modifiedPolicy := *bestPolicy + modifiedPolicy.TimeLockDelta = maxTimelock + + return &modifiedPolicy +} + +// minAmt returns the minimum amount that can be forwarded on this connection. +func (u *unifiedPolicy) minAmt() lnwire.MilliSatoshi { + min := lnwire.MaxMilliSatoshi + for _, edge := range u.edges { + if edge.policy.MinHTLC < min { + min = edge.policy.MinHTLC + } + } + + return min +} diff --git a/routing/unified_policies_test.go b/routing/unified_policies_test.go new file mode 100644 index 00000000..e89a3cb1 --- /dev/null +++ b/routing/unified_policies_test.go @@ -0,0 +1,91 @@ +package routing + +import ( + "testing" + + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" +) + +// TestUnifiedPolicies tests the composition of unified policies for nodes that +// have multiple channels between them. +func TestUnifiedPolicies(t *testing.T) { + source := route.Vertex{1} + toNode := route.Vertex{2} + fromNode := route.Vertex{3} + + bandwidthHints := map[uint64]lnwire.MilliSatoshi{} + + u := newUnifiedPolicies(source, toNode, nil) + + // Add two channels between the pair of nodes. + p1 := channeldb.ChannelEdgePolicy{ + FeeProportionalMillionths: 100000, + FeeBaseMSat: 30, + TimeLockDelta: 60, + MessageFlags: lnwire.ChanUpdateOptionMaxHtlc, + MaxHTLC: 500, + MinHTLC: 100, + } + p2 := channeldb.ChannelEdgePolicy{ + FeeProportionalMillionths: 190000, + FeeBaseMSat: 10, + TimeLockDelta: 40, + MessageFlags: lnwire.ChanUpdateOptionMaxHtlc, + MaxHTLC: 400, + MinHTLC: 100, + } + u.addPolicy(fromNode, &p1, 7) + u.addPolicy(fromNode, &p2, 7) + + checkPolicy := func(policy *channeldb.ChannelEdgePolicy, + feeBase lnwire.MilliSatoshi, feeRate lnwire.MilliSatoshi, + timeLockDelta uint16) { + + t.Helper() + + if policy.FeeBaseMSat != feeBase { + t.Fatalf("expected fee base %v, got %v", + feeBase, policy.FeeBaseMSat) + } + + if policy.TimeLockDelta != timeLockDelta { + t.Fatalf("expected fee base %v, got %v", + timeLockDelta, policy.TimeLockDelta) + } + + if policy.FeeProportionalMillionths != feeRate { + t.Fatalf("expected fee rate %v, got %v", + feeRate, policy.FeeProportionalMillionths) + } + } + + policy := u.policies[fromNode].getPolicy(50, bandwidthHints) + if policy != nil { + t.Fatal("expected no policy for amt below min htlc") + } + + policy = u.policies[fromNode].getPolicy(550, bandwidthHints) + if policy != nil { + t.Fatal("expected no policy for amt above max htlc") + } + + // For 200 sat, p1 yields the highest fee. Use that policy to forward, + // because it will also match p2 in case p1 does not have enough + // balance. + policy = u.policies[fromNode].getPolicy(200, bandwidthHints) + checkPolicy( + policy, p1.FeeBaseMSat, p1.FeeProportionalMillionths, + p1.TimeLockDelta, + ) + + // For 400 sat, p2 yields the highest fee. Use that policy to forward, + // because it will also match p1 in case p2 does not have enough + // balance. In order to match p1, it needs to have p1's time lock delta. + policy = u.policies[fromNode].getPolicy(400, bandwidthHints) + checkPolicy( + policy, p2.FeeBaseMSat, p2.FeeProportionalMillionths, + p1.TimeLockDelta, + ) +}