Merge pull request #1749 from cfromknecht/htlcswitch-code-health

[htlcswitch]: improve code health of switch
This commit is contained in:
Olaoluwa Osuntokun 2018-08-19 17:40:26 -07:00 committed by GitHub
commit 21841c9f6b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -3,6 +3,7 @@ package htlcswitch
import ( import (
"bytes" "bytes"
"crypto/sha256" "crypto/sha256"
"errors"
"fmt" "fmt"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -14,7 +15,6 @@ import (
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcutil" "github.com/btcsuite/btcutil"
"github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/contractcourt"
@ -45,7 +45,11 @@ var (
// ErrIncompleteForward is used when an htlc was already forwarded // ErrIncompleteForward is used when an htlc was already forwarded
// through the switch, but did not get locked into another commitment // through the switch, but did not get locked into another commitment
// txn. // txn.
ErrIncompleteForward = errors.Errorf("incomplete forward detected") ErrIncompleteForward = errors.New("incomplete forward detected")
// ErrSwitchExiting signaled when the switch has received a shutdown
// request.
ErrSwitchExiting = errors.New("htlcswitch shutting down")
// zeroPreimage is the empty preimage which is returned when we have // zeroPreimage is the empty preimage which is returned when we have
// some errors. // some errors.
@ -322,13 +326,13 @@ func (s *Switch) ProcessContractResolution(msg contractcourt.ResolutionMsg) erro
doneChan: done, doneChan: done,
}: }:
case <-s.quit: case <-s.quit:
return fmt.Errorf("switch shutting down") return ErrSwitchExiting
} }
select { select {
case <-done: case <-done:
case <-s.quit: case <-s.quit:
return fmt.Errorf("switch shutting down") return ErrSwitchExiting
} }
return nil return nil
@ -383,24 +387,21 @@ func (s *Switch) SendHTLC(nextNode [33]byte, htlc *lnwire.UpdateAddHTLC,
case e := <-payment.err: case e := <-payment.err:
err = e err = e
case <-s.quit: case <-s.quit:
return zeroPreimage, errors.New("htlc switch have been stopped " + return zeroPreimage, ErrSwitchExiting
"while waiting for payment result")
} }
select { select {
case pkt := <-payment.response: case pkt := <-payment.response:
response = pkt response = pkt
case <-s.quit: case <-s.quit:
return zeroPreimage, errors.New("htlc switch have been stopped " + return zeroPreimage, ErrSwitchExiting
"while waiting for payment result")
} }
select { select {
case p := <-payment.preimage: case p := <-payment.preimage:
preimage = p preimage = p
case <-s.quit: case <-s.quit:
return zeroPreimage, errors.New("htlc switch have been stopped " + return zeroPreimage, ErrSwitchExiting
"while waiting for payment result")
} }
// Remove circuit since we are about to complete an add/fail of this // Remove circuit since we are about to complete an add/fail of this
@ -666,7 +667,7 @@ func (s *Switch) ForwardPackets(linkQuit chan struct{},
} }
for _, packet := range failedPackets { for _, packet := range failedPackets {
addErr := errors.Errorf("failing packet after " + addErr := errors.New("failing packet after " +
"detecting incomplete forward") "detecting incomplete forward")
// We don't handle the error here since this method // We don't handle the error here since this method
@ -689,9 +690,7 @@ func (s *Switch) ForwardPackets(linkQuit chan struct{},
func (s *Switch) proxyFwdErrs(num *int, wg *sync.WaitGroup, func (s *Switch) proxyFwdErrs(num *int, wg *sync.WaitGroup,
fwdChan, errChan chan error) { fwdChan, errChan chan error) {
defer s.wg.Done() defer s.wg.Done()
defer func() { defer close(errChan)
close(errChan)
}()
// Wait here until the outer function has finished persisting // Wait here until the outer function has finished persisting
// and routing the packets. This guarantees we don't read from num until // and routing the packets. This guarantees we don't read from num until
@ -722,14 +721,14 @@ func (s *Switch) route(packet *htlcPacket) error {
select { select {
case s.htlcPlex <- command: case s.htlcPlex <- command:
case <-s.quit: case <-s.quit:
return errors.New("Htlc Switch was stopped") return ErrSwitchExiting
} }
select { select {
case err := <-command.err: case err := <-command.err:
return err return err
case <-s.quit: case <-s.quit:
return errors.New("Htlc Switch was stopped") return ErrSwitchExiting
} }
} }
@ -969,7 +968,7 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error {
// than we should notify this link that some error // than we should notify this link that some error
// occurred. // occurred.
failure := &lnwire.FailUnknownNextPeer{} failure := &lnwire.FailUnknownNextPeer{}
addErr := errors.Errorf("unable to find link with "+ addErr := fmt.Errorf("unable to find link with "+
"destination %v", packet.outgoingChanID) "destination %v", packet.outgoingChanID)
return s.failAddPacket(packet, failure, addErr) return s.failAddPacket(packet, failure, addErr)
@ -1039,7 +1038,7 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error {
failure = lnwire.NewTemporaryChannelFailure(update) failure = lnwire.NewTemporaryChannelFailure(update)
} }
addErr := errors.Errorf("unable to find appropriate "+ addErr := fmt.Errorf("unable to find appropriate "+
"channel link insufficient capacity, need "+ "channel link insufficient capacity, need "+
"%v", htlc.Amount) "%v", htlc.Amount)
@ -1104,7 +1103,7 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error {
failure, failure,
) )
if err != nil { if err != nil {
err = errors.Errorf("unable to obfuscate "+ err = fmt.Errorf("unable to obfuscate "+
"error: %v", err) "error: %v", err)
log.Error(err) log.Error(err)
} }
@ -1166,7 +1165,7 @@ func (s *Switch) failAddPacket(packet *htlcPacket,
// obfuscate the failure for their eyes only. // obfuscate the failure for their eyes only.
reason, err := packet.obfuscator.EncryptFirstHop(failure) reason, err := packet.obfuscator.EncryptFirstHop(failure)
if err != nil { if err != nil {
err := errors.Errorf("unable to obfuscate "+ err := fmt.Errorf("unable to obfuscate "+
"error: %v", err) "error: %v", err)
log.Error(err) log.Error(err)
return err return err
@ -1186,7 +1185,7 @@ func (s *Switch) failAddPacket(packet *htlcPacket,
// Route a fail packet back to the source link. // Route a fail packet back to the source link.
err = s.mailOrchestrator.Deliver(failPkt.incomingChanID, failPkt) err = s.mailOrchestrator.Deliver(failPkt.incomingChanID, failPkt)
if err != nil { if err != nil {
err = errors.Errorf("source chanid=%v unable to "+ err = fmt.Errorf("source chanid=%v unable to "+
"handle switch packet: %v", "handle switch packet: %v",
packet.incomingChanID, err) packet.incomingChanID, err)
log.Error(err) log.Error(err)
@ -1263,7 +1262,7 @@ func (s *Switch) closeCircuit(pkt *htlcPacket) (*PaymentCircuit, error) {
// Failed to close circuit because it does not exist. This is likely // Failed to close circuit because it does not exist. This is likely
// because the circuit was already successfully closed. // because the circuit was already successfully closed.
case ErrUnknownCircuit: case ErrUnknownCircuit:
err := errors.Errorf("Unable to find target channel "+ err := fmt.Errorf("Unable to find target channel "+
"for HTLC settle/fail: channel ID = %s, "+ "for HTLC settle/fail: channel ID = %s, "+
"HTLC ID = %d", pkt.outgoingChanID, "HTLC ID = %d", pkt.outgoingChanID,
pkt.outgoingHTLCID) pkt.outgoingHTLCID)
@ -1374,8 +1373,7 @@ func (s *Switch) CloseLink(chanPoint *wire.OutPoint, closeType ChannelCloseType,
return updateChan, errChan return updateChan, errChan
case <-s.quit: case <-s.quit:
errChan <- errors.New("unable close channel link, htlc " + errChan <- ErrSwitchExiting
"switch already stopped")
close(updateChan) close(updateChan)
return updateChan, errChan return updateChan, errChan
} }
@ -1469,7 +1467,7 @@ out:
if !ok { if !ok {
s.indexMtx.RUnlock() s.indexMtx.RUnlock()
req.Err <- errors.Errorf("no peer for channel with "+ req.Err <- fmt.Errorf("no peer for channel with "+
"chan_id=%x", chanID[:]) "chan_id=%x", chanID[:])
continue continue
} }
@ -2021,7 +2019,7 @@ func (s *Switch) GetLinksByInterface(hop [33]byte) ([]ChannelLink, error) {
func (s *Switch) getLinks(destination [33]byte) ([]ChannelLink, error) { func (s *Switch) getLinks(destination [33]byte) ([]ChannelLink, error) {
links, ok := s.interfaceIndex[destination] links, ok := s.interfaceIndex[destination]
if !ok { if !ok {
return nil, errors.Errorf("unable to locate channel link by "+ return nil, fmt.Errorf("unable to locate channel link by "+
"destination hop id %x", destination) "destination hop id %x", destination)
} }
@ -2040,7 +2038,7 @@ func (s *Switch) removePendingPayment(paymentID uint64) error {
defer s.pendingMutex.Unlock() defer s.pendingMutex.Unlock()
if _, ok := s.pendingPayments[paymentID]; !ok { if _, ok := s.pendingPayments[paymentID]; !ok {
return errors.Errorf("Cannot find pending payment with ID %d", return fmt.Errorf("Cannot find pending payment with ID %d",
paymentID) paymentID)
} }
@ -2055,7 +2053,7 @@ func (s *Switch) findPayment(paymentID uint64) (*pendingPayment, error) {
payment, ok := s.pendingPayments[paymentID] payment, ok := s.pendingPayments[paymentID]
if !ok { if !ok {
return nil, errors.Errorf("Cannot find pending payment with ID %d", return nil, fmt.Errorf("Cannot find pending payment with ID %d",
paymentID) paymentID)
} }
return payment, nil return payment, nil