diff --git a/channeldb/graph.go b/channeldb/graph.go index 3a8f6c56..8a41e78b 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -2832,6 +2832,22 @@ func (c *ChannelEdgePolicy) ComputeFee( return c.FeeBaseMSat + (amt*c.FeeProportionalMillionths)/feeRateParts } +// divideCeil divides dividend by factor and rounds the result up. +func divideCeil(dividend, factor lnwire.MilliSatoshi) lnwire.MilliSatoshi { + return (dividend + factor - 1) / factor +} + +// ComputeFeeFromIncoming computes the fee to forward an HTLC given the incoming +// amount. +func (c *ChannelEdgePolicy) ComputeFeeFromIncoming( + incomingAmt lnwire.MilliSatoshi) lnwire.MilliSatoshi { + + return incomingAmt - divideCeil( + feeRateParts*(incomingAmt-c.FeeBaseMSat), + feeRateParts+c.FeeProportionalMillionths, + ) +} + // FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for // the channel identified by the funding outpoint. If the channel can't be // found, then ErrEdgeNotFound is returned. A struct which houses the general diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go index 5068228f..de8774a9 100644 --- a/channeldb/graph_test.go +++ b/channeldb/graph_test.go @@ -3173,3 +3173,25 @@ func TestLightningNodeSigVerification(t *testing.T) { t.Fatalf("unable to verify sig") } } + +// TestComputeFee tests fee calculation based on both in- and outgoing amt. +func TestComputeFee(t *testing.T) { + var ( + policy = ChannelEdgePolicy{ + FeeBaseMSat: 10000, + FeeProportionalMillionths: 30000, + } + outgoingAmt = lnwire.MilliSatoshi(1000000) + expectedFee = lnwire.MilliSatoshi(40000) + ) + + fee := policy.ComputeFee(outgoingAmt) + if fee != expectedFee { + t.Fatalf("expected fee %v, got %v", expectedFee, fee) + } + + fwdFee := policy.ComputeFeeFromIncoming(outgoingAmt + fee) + if fwdFee != expectedFee { + t.Fatalf("expected fee %v, but got %v", fee, fwdFee) + } +}