diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index 28a0bb04..1e7db9fe 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -4,6 +4,7 @@ import ( "bytes" "errors" "fmt" + "math/rand" "sync" "sync/atomic" "time" @@ -1010,9 +1011,9 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error { // trying all links to utilize our available bandwidth. linkErrs := make(map[lnwire.ShortChannelID]*LinkError) - // Try to find destination channel link with appropriate + // Find all destination channel links with appropriate // bandwidth. - var destination ChannelLink + var destinations []ChannelLink for _, link := range interfaceLinks { var failure *LinkError @@ -1035,10 +1036,11 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error { ) } - // Stop searching if this link can forward the htlc. + // If this link can forward the htlc, add it to the set + // of destinations. if failure == nil { - destination = link - break + destinations = append(destinations, link) + continue } linkErrs[link.ShortChanID()] = failure @@ -1048,7 +1050,7 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error { // satisfying the current policy, then we'll send back an // error, but ensure we send back the error sourced at the // *target* link. - if destination == nil { + if len(destinations) == 0 { // At this point, some or all of the links rejected the // HTLC so we couldn't forward it. So we'll try to look // up the error that came from the source. @@ -1075,6 +1077,12 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error { return s.failAddPacket(packet, linkErr) } + // Choose a random link out of the set of links that can forward + // this htlc. The reason for randomization is to evenly + // distribute the htlc load without making assumptions about + // what the best channel is. + destination := destinations[rand.Intn(len(destinations))] + // Send the packet to the destination channel link which // manages the channel. packet.outgoingChanID = destination.ShortChanID()