Merge pull request #3915 from carlaKC/3771-loopattackprotection

htlcswitch: Disallow circular routes on same channel
This commit is contained in:
Olaoluwa Osuntokun 2020-02-03 15:59:53 -08:00 committed by GitHub
commit e25cca11f7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 301 additions and 24 deletions

View File

@ -341,6 +341,8 @@ type config struct {
Watchtower *lncfg.Watchtower `group:"watchtower" namespace:"watchtower"`
LegacyProtocol *lncfg.LegacyProtocol `group:"legacyprotocol" namespace:"legacyprotocol"`
AllowCircularRoute bool `long:"allow-circular-route" description:"If true, our node will allow htlc forwards that arrive and depart on the same channel."`
}
// loadConfig initializes and parses the config using a config file and command

View File

@ -29,6 +29,11 @@ const (
// FailureDetailInsufficientBalance is returned when we cannot route a
// htlc due to insufficient outgoing capacity.
FailureDetailInsufficientBalance
// FailureDetailCircularRoute is returned when an attempt is made
// to forward a htlc through our node which arrives and leaves on the
// same channel.
FailureDetailCircularRoute
)
// String returns the string representation of a failure detail.
@ -52,6 +57,9 @@ func (fd FailureDetail) String() string {
case FailureDetailInsufficientBalance:
return "insufficient bandwidth to route htlc"
case FailureDetailCircularRoute:
return "same incoming and outgoing channel"
default:
return "unknown failure detail"
}

View File

@ -39,6 +39,7 @@ import (
const (
testStartingHeight = 100
testDefaultDelta = 6
)
// concurrentTester is a thread-safe wrapper around the Fatalf method of a

View File

@ -167,6 +167,10 @@ type Config struct {
// fails in forwarding packages.
AckEventTicker ticker.Ticker
// AllowCircularRoute is true if the user has configured their node to
// allow forwards that arrive and depart our node over the same channel.
AllowCircularRoute bool
// RejectHTLC is a flag that instructs the htlcswitch to reject any
// HTLCs that are not from the source hop.
RejectHTLC bool
@ -986,6 +990,22 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error {
return s.handleLocalDispatch(packet)
}
// Before we attempt to find a non-strict forwarding path for
// this htlc, check whether the htlc is being routed over the
// same incoming and outgoing channel. If our node does not
// allow forwards of this nature, we fail the htlc early. This
// check is in place to disallow inefficiently routed htlcs from
// locking up our balance.
linkErr := checkCircularForward(
packet.incomingChanID, packet.outgoingChanID,
s.cfg.AllowCircularRoute, htlc.PaymentHash,
)
if linkErr != nil {
return s.failAddPacket(
packet, linkErr.WireMessage(), linkErr,
)
}
s.indexMtx.RLock()
targetLink, err := s.getLinkByShortID(packet.outgoingChanID)
if err != nil {
@ -1170,6 +1190,37 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error {
}
}
// checkCircularForward checks whether a forward is circular (arrives and
// departs on the same link) and returns a link error if the switch is
// configured to disallow this behaviour.
func checkCircularForward(incoming, outgoing lnwire.ShortChannelID,
allowCircular bool, paymentHash lntypes.Hash) *LinkError {
// If the route is not circular we do not need to perform any further
// checks.
if incoming != outgoing {
return nil
}
// If the incoming and outgoing link are equal, the htlc is part of a
// circular route which may be used to lock up our liquidity. If the
// switch is configured to allow circular routes, log that we are
// allowing the route then return nil.
if allowCircular {
log.Debugf("allowing circular route over link: %v "+
"(payment hash: %x)", incoming, paymentHash)
return nil
}
// If our node disallows circular routes, return a temporary channel
// failure. There is nothing wrong with the policy used by the remote
// node, so we do not include a channel update.
return NewDetailedLinkError(
lnwire.NewTemporaryChannelFailure(nil),
FailureDetailCircularRoute,
)
}
// failAddPacket encrypts a fail packet back to an add packet's source.
// The ciphertext will be derived from the failure message proivded by context.
// This method returns the failErr if all other steps complete successfully.

View File

@ -6,6 +6,7 @@ import (
"fmt"
"io"
"io/ioutil"
"reflect"
"testing"
"time"
@ -32,7 +33,9 @@ func genPreimage() ([32]byte, error) {
func TestSwitchAddDuplicateLink(t *testing.T) {
t.Parallel()
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6)
alicePeer, err := newMockServer(
t, "alice", testStartingHeight, nil, testDefaultDelta,
)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
@ -90,7 +93,9 @@ func TestSwitchAddDuplicateLink(t *testing.T) {
func TestSwitchHasActiveLink(t *testing.T) {
t.Parallel()
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6)
alicePeer, err := newMockServer(
t, "alice", testStartingHeight, nil, testDefaultDelta,
)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
@ -158,7 +163,9 @@ func TestSwitchHasActiveLink(t *testing.T) {
func TestSwitchSendPending(t *testing.T) {
t.Parallel()
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6)
alicePeer, err := newMockServer(
t, "alice", testStartingHeight, nil, testDefaultDelta,
)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
@ -253,11 +260,15 @@ func TestSwitchSendPending(t *testing.T) {
func TestSwitchForward(t *testing.T) {
t.Parallel()
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6)
alicePeer, err := newMockServer(
t, "alice", testStartingHeight, nil, testDefaultDelta,
)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil, 6)
bobPeer, err := newMockServer(
t, "bob", testStartingHeight, nil, testDefaultDelta,
)
if err != nil {
t.Fatalf("unable to create bob server: %v", err)
}
@ -358,11 +369,15 @@ func TestSwitchForwardFailAfterFullAdd(t *testing.T) {
chanID1, chanID2, aliceChanID, bobChanID := genIDs()
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6)
alicePeer, err := newMockServer(
t, "alice", testStartingHeight, nil, testDefaultDelta,
)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil, 6)
bobPeer, err := newMockServer(
t, "bob", testStartingHeight, nil, testDefaultDelta,
)
if err != nil {
t.Fatalf("unable to create bob server: %v", err)
}
@ -549,11 +564,15 @@ func TestSwitchForwardSettleAfterFullAdd(t *testing.T) {
chanID1, chanID2, aliceChanID, bobChanID := genIDs()
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6)
alicePeer, err := newMockServer(
t, "alice", testStartingHeight, nil, testDefaultDelta,
)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil, 6)
bobPeer, err := newMockServer(
t, "bob", testStartingHeight, nil, testDefaultDelta,
)
if err != nil {
t.Fatalf("unable to create bob server: %v", err)
}
@ -743,11 +762,15 @@ func TestSwitchForwardDropAfterFullAdd(t *testing.T) {
chanID1, chanID2, aliceChanID, bobChanID := genIDs()
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6)
alicePeer, err := newMockServer(
t, "alice", testStartingHeight, nil, testDefaultDelta,
)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil, 6)
bobPeer, err := newMockServer(
t, "bob", testStartingHeight, nil, testDefaultDelta,
)
if err != nil {
t.Fatalf("unable to create bob server: %v", err)
}
@ -906,11 +929,15 @@ func TestSwitchForwardFailAfterHalfAdd(t *testing.T) {
chanID1, chanID2, aliceChanID, bobChanID := genIDs()
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6)
alicePeer, err := newMockServer(
t, "alice", testStartingHeight, nil, testDefaultDelta,
)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil, 6)
bobPeer, err := newMockServer(
t, "bob", testStartingHeight, nil, testDefaultDelta,
)
if err != nil {
t.Fatalf("unable to create bob server: %v", err)
}
@ -1064,11 +1091,15 @@ func TestSwitchForwardCircuitPersistence(t *testing.T) {
chanID1, chanID2, aliceChanID, bobChanID := genIDs()
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6)
alicePeer, err := newMockServer(
t, "alice", testStartingHeight, nil, testDefaultDelta,
)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil, 6)
bobPeer, err := newMockServer(
t, "bob", testStartingHeight, nil, testDefaultDelta,
)
if err != nil {
t.Fatalf("unable to create bob server: %v", err)
}
@ -1294,6 +1325,171 @@ type multiHopFwdTest struct {
expectedReply lnwire.FailCode
}
// TestCircularForwards tests the allowing/disallowing of circular payments
// through the same channel in the case where the switch is configured to allow
// and disallow same channel circular forwards.
func TestCircularForwards(t *testing.T) {
chanID1, aliceChanID := genID()
preimage := [sha256.Size]byte{1}
hash := fastsha256.Sum256(preimage[:])
tests := []struct {
name string
allowCircularPayment bool
expectedErr error
}{
{
name: "circular payment allowed",
allowCircularPayment: true,
expectedErr: nil,
},
{
name: "circular payment disallowed",
allowCircularPayment: false,
expectedErr: NewDetailedLinkError(
lnwire.NewTemporaryChannelFailure(nil),
FailureDetailCircularRoute,
),
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
alicePeer, err := newMockServer(
t, "alice", testStartingHeight, nil,
testDefaultDelta,
)
if err != nil {
t.Fatalf("unable to create alice server: %v",
err)
}
s, err := initSwitchWithDB(testStartingHeight, nil)
if err != nil {
t.Fatalf("unable to init switch: %v", err)
}
if err := s.Start(); err != nil {
t.Fatalf("unable to start switch: %v", err)
}
defer func() { _ = s.Stop() }()
// Set the switch to allow or disallow circular routes
// according to the test's requirements.
s.cfg.AllowCircularRoute = test.allowCircularPayment
aliceChannelLink := newMockChannelLink(
s, chanID1, aliceChanID, alicePeer, true,
)
if err := s.AddLink(aliceChannelLink); err != nil {
t.Fatalf("unable to add alice link: %v", err)
}
// Create a new packet that loops through alice's link
// in a circle.
obfuscator := NewMockObfuscator()
packet := &htlcPacket{
incomingChanID: aliceChannelLink.ShortChanID(),
outgoingChanID: aliceChannelLink.ShortChanID(),
htlc: &lnwire.UpdateAddHTLC{
PaymentHash: hash,
Amount: 1,
},
obfuscator: obfuscator,
}
// Attempt to forward the packet and check for the expected
// error.
err = s.forward(packet)
if !reflect.DeepEqual(err, test.expectedErr) {
t.Fatalf("expected: %v, got: %v",
test.expectedErr, err)
}
// Ensure that no circuits were opened.
if s.circuits.NumOpen() > 0 {
t.Fatal("do not expect any open circuits")
}
})
}
}
// TestCheckCircularForward tests the error returned by checkCircularForward
// in cases where we allow and disallow same channel circular forwards.
func TestCheckCircularForward(t *testing.T) {
tests := []struct {
name string
// allowCircular determines whether we should allow circular
// forwards.
allowCircular bool
// incomingLink is the link that the htlc arrived on.
incomingLink lnwire.ShortChannelID
// outgoingLink is the link that the htlc forward
// is destined to leave on.
outgoingLink lnwire.ShortChannelID
// expectedErr is the error we expect to be returned.
expectedErr *LinkError
}{
{
name: "not circular, allowed in config",
allowCircular: true,
incomingLink: lnwire.NewShortChanIDFromInt(123),
outgoingLink: lnwire.NewShortChanIDFromInt(321),
expectedErr: nil,
},
{
name: "not circular, not allowed in config",
allowCircular: false,
incomingLink: lnwire.NewShortChanIDFromInt(123),
outgoingLink: lnwire.NewShortChanIDFromInt(321),
expectedErr: nil,
},
{
name: "circular, allowed in config",
allowCircular: true,
incomingLink: lnwire.NewShortChanIDFromInt(123),
outgoingLink: lnwire.NewShortChanIDFromInt(123),
expectedErr: nil,
},
{
name: "circular, not allowed in config",
allowCircular: false,
incomingLink: lnwire.NewShortChanIDFromInt(123),
outgoingLink: lnwire.NewShortChanIDFromInt(123),
expectedErr: NewDetailedLinkError(
lnwire.NewTemporaryChannelFailure(nil),
FailureDetailCircularRoute,
),
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
// Check for a circular forward, the hash passed can
// be nil because it is only used for logging.
err := checkCircularForward(
test.incomingLink, test.outgoingLink,
test.allowCircular, lntypes.Hash{},
)
if !reflect.DeepEqual(err, test.expectedErr) {
t.Fatalf("expected: %v, got: %v",
test.expectedErr, err)
}
})
}
}
// TestSkipIneligibleLinksMultiHopForward tests that if a multi-hop HTLC comes
// along, then we won't attempt to froward it down al ink that isn't yet able
// to forward any HTLC's.
@ -1359,11 +1555,15 @@ func testSkipIneligibleLinksMultiHopForward(t *testing.T,
var packet *htlcPacket
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6)
alicePeer, err := newMockServer(
t, "alice", testStartingHeight, nil, testDefaultDelta,
)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil, 6)
bobPeer, err := newMockServer(
t, "bob", testStartingHeight, nil, testDefaultDelta,
)
if err != nil {
t.Fatalf("unable to create bob server: %v", err)
}
@ -1470,7 +1670,9 @@ func testSkipLinkLocalForward(t *testing.T, eligible bool,
// We'll create a single link for this test, marking it as being unable
// to forward form the get go.
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6)
alicePeer, err := newMockServer(
t, "alice", testStartingHeight, nil, testDefaultDelta,
)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
@ -1524,11 +1726,15 @@ func testSkipLinkLocalForward(t *testing.T, eligible bool,
func TestSwitchCancel(t *testing.T) {
t.Parallel()
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6)
alicePeer, err := newMockServer(
t, "alice", testStartingHeight, nil, testDefaultDelta,
)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil, 6)
bobPeer, err := newMockServer(
t, "bob", testStartingHeight, nil, testDefaultDelta,
)
if err != nil {
t.Fatalf("unable to create bob server: %v", err)
}
@ -1637,11 +1843,15 @@ func TestSwitchAddSamePayment(t *testing.T) {
chanID1, chanID2, aliceChanID, bobChanID := genIDs()
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6)
alicePeer, err := newMockServer(
t, "alice", testStartingHeight, nil, testDefaultDelta,
)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
bobPeer, err := newMockServer(t, "bob", testStartingHeight, nil, 6)
bobPeer, err := newMockServer(
t, "bob", testStartingHeight, nil, testDefaultDelta,
)
if err != nil {
t.Fatalf("unable to create bob server: %v", err)
}
@ -1796,7 +2006,9 @@ func TestSwitchAddSamePayment(t *testing.T) {
func TestSwitchSendPayment(t *testing.T) {
t.Parallel()
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6)
alicePeer, err := newMockServer(
t, "alice", testStartingHeight, nil, testDefaultDelta,
)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
@ -2334,7 +2546,9 @@ func TestSwitchGetPaymentResult(t *testing.T) {
func TestInvalidFailure(t *testing.T) {
t.Parallel()
alicePeer, err := newMockServer(t, "alice", testStartingHeight, nil, 6)
alicePeer, err := newMockServer(
t, "alice", testStartingHeight, nil, testDefaultDelta,
)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}

View File

@ -470,6 +470,7 @@ func newServer(listenAddrs []net.Addr, chanDB *channeldb.DB,
FwdEventTicker: ticker.New(htlcswitch.DefaultFwdEventInterval),
LogEventTicker: ticker.New(htlcswitch.DefaultLogInterval),
AckEventTicker: ticker.New(htlcswitch.DefaultAckInterval),
AllowCircularRoute: cfg.AllowCircularRoute,
RejectHTLC: cfg.RejectHTLC,
}, uint32(currentHeight))
if err != nil {