diff --git a/discovery/gossiper.go b/discovery/gossiper.go index 76afb895..b5f355dc 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -1992,10 +1992,11 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement( pubKey, _ = chanInfo.NodeKey2() } - // Validate the channel announcement with the expected public - // key, In the case of an invalid channel , we'll return an - // error to the caller and exit early. - if err := routing.ValidateChannelUpdateAnn(pubKey, msg); err != nil { + // Validate the channel announcement with the expected public key and + // channel capacity. In the case of an invalid channel update, we'll + // return an error to the caller and exit early. + err = routing.ValidateChannelUpdateAnn(pubKey, chanInfo.Capacity, msg) + if err != nil { rErr := fmt.Errorf("unable to validate channel "+ "update announcement for short_chan_id=%v: %v", spew.Sdump(msg.ShortChannelID), err) @@ -2548,7 +2549,7 @@ func (d *AuthenticatedGossiper) updateChannel(info *channeldb.ChannelEdgeInfo, // To ensure that our signature is valid, we'll verify it ourself // before committing it to the slice returned. - err = routing.ValidateChannelUpdateAnn(d.selfKey, chanUpdate) + err = routing.ValidateChannelUpdateAnn(d.selfKey, info.Capacity, chanUpdate) if err != nil { return nil, nil, fmt.Errorf("generated invalid channel "+ "update sig: %v", err) diff --git a/discovery/gossiper_test.go b/discovery/gossiper_test.go index fca3c254..be029878 100644 --- a/discovery/gossiper_test.go +++ b/discovery/gossiper_test.go @@ -10,6 +10,7 @@ import ( "net" "os" "reflect" + "strings" "sync" "testing" "time" @@ -474,14 +475,7 @@ func createUpdateAnnouncement(blockHeight uint32, a.ExtraOpaqueData = extraBytes[0] } - pub := nodeKey.PubKey() - signer := mockSigner{nodeKey} - sig, err := SignAnnouncement(&signer, pub, a) - if err != nil { - return nil, err - } - - a.Signature, err = lnwire.NewSigFromSignature(sig) + err = signUpdate(nodeKey, a) if err != nil { return nil, err } @@ -489,6 +483,22 @@ func createUpdateAnnouncement(blockHeight uint32, return a, nil } +func signUpdate(nodeKey *btcec.PrivateKey, a *lnwire.ChannelUpdate) error { + pub := nodeKey.PubKey() + signer := mockSigner{nodeKey} + sig, err := SignAnnouncement(&signer, pub, a) + if err != nil { + return err + } + + a.Signature, err = lnwire.NewSigFromSignature(sig) + if err != nil { + return err + } + + return nil +} + func createAnnouncementWithoutProof(blockHeight uint32, extraBytes ...[]byte) *lnwire.ChannelAnnouncement { @@ -2765,6 +2775,93 @@ func TestNodeAnnouncementNoChannels(t *testing.T) { } } +// TestOptionalFieldsChannelUpdateValidation tests that we're able to properly +// validate the msg flags and optional max HTLC field of a ChannelUpdate. +func TestOptionalFieldsChannelUpdateValidation(t *testing.T) { + t.Parallel() + + ctx, cleanup, err := createTestCtx(0) + if err != nil { + t.Fatalf("can't create context: %v", err) + } + defer cleanup() + + chanUpdateHeight := uint32(0) + timestamp := uint32(123456) + nodePeer := &mockPeer{nodeKeyPriv1.PubKey(), nil, nil} + + // In this scenario, we'll test whether the message flags field in a channel + // update is properly handled. + chanAnn, err := createRemoteChannelAnnouncement(chanUpdateHeight) + if err != nil { + t.Fatalf("can't create channel announcement: %v", err) + } + + select { + case err = <-ctx.gossiper.ProcessRemoteAnnouncement(chanAnn, nodePeer): + case <-time.After(2 * time.Second): + t.Fatal("did not process remote announcement") + } + if err != nil { + t.Fatalf("unable to process announcement: %v", err) + } + + // The first update should fail from an invalid max HTLC field, which is + // less than the min HTLC. + chanUpdAnn, err := createUpdateAnnouncement(0, 0, nodeKeyPriv1, timestamp) + if err != nil { + t.Fatalf("unable to create channel update: %v", err) + } + + chanUpdAnn.HtlcMinimumMsat = 5000 + chanUpdAnn.HtlcMaximumMsat = 4000 + if err := signUpdate(nodeKeyPriv1, chanUpdAnn); err != nil { + t.Fatalf("unable to sign channel update: %v", err) + } + + select { + case err = <-ctx.gossiper.ProcessRemoteAnnouncement(chanUpdAnn, nodePeer): + case <-time.After(2 * time.Second): + t.Fatal("did not process remote announcement") + } + if err == nil || !strings.Contains(err.Error(), "invalid max htlc") { + t.Fatalf("expected chan update to error, instead got %v", err) + } + + // The second update should fail because the message flag is set but + // the max HTLC field is 0. + chanUpdAnn.HtlcMinimumMsat = 0 + chanUpdAnn.HtlcMaximumMsat = 0 + if err := signUpdate(nodeKeyPriv1, chanUpdAnn); err != nil { + t.Fatalf("unable to sign channel update: %v", err) + } + + select { + case err = <-ctx.gossiper.ProcessRemoteAnnouncement(chanUpdAnn, nodePeer): + case <-time.After(2 * time.Second): + t.Fatal("did not process remote announcement") + } + if err == nil || !strings.Contains(err.Error(), "invalid max htlc") { + t.Fatalf("expected chan update to error, instead got %v", err) + } + + // The final update should succeed, since setting the flag 0 means the + // nonsense max_htlc field will just be ignored. + chanUpdAnn.MessageFlags = 0 + if err := signUpdate(nodeKeyPriv1, chanUpdAnn); err != nil { + t.Fatalf("unable to sign channel update: %v", err) + } + + select { + case err = <-ctx.gossiper.ProcessRemoteAnnouncement(chanUpdAnn, nodePeer): + case <-time.After(2 * time.Second): + t.Fatal("did not process remote announcement") + } + if err != nil { + t.Fatalf("unable to process announcement: %v", err) + } +} + // mockPeer implements the lnpeer.Peer interface and is used to test the // gossiper's interaction with peers. type mockPeer struct { diff --git a/routing/ann_validation.go b/routing/ann_validation.go index 257ac3fd..5c70dd86 100644 --- a/routing/ann_validation.go +++ b/routing/ann_validation.go @@ -5,6 +5,7 @@ import ( "github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcutil" "github.com/davecgh/go-spew/spew" "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/lnwire" @@ -121,11 +122,16 @@ func ValidateNodeAnn(a *lnwire.NodeAnnouncement) error { } // ValidateChannelUpdateAnn validates the channel update announcement by -// checking that the included signature covers he announcement and has been -// signed by the node's private key. -func ValidateChannelUpdateAnn(pubKey *btcec.PublicKey, +// checking (1) that the included signature covers the announcement and has been +// signed by the node's private key, and (2) that the announcement's message +// flags and optional fields are sane. +func ValidateChannelUpdateAnn(pubKey *btcec.PublicKey, capacity btcutil.Amount, a *lnwire.ChannelUpdate) error { + if err := validateOptionalFields(capacity, a); err != nil { + return err + } + data, err := a.DataToSign() if err != nil { return errors.Errorf("unable to reconstruct message: %v", err) @@ -144,3 +150,25 @@ func ValidateChannelUpdateAnn(pubKey *btcec.PublicKey, return nil } + +// validateOptionalFields validates a channel update's message flags and +// corresponding update fields. +func validateOptionalFields(capacity btcutil.Amount, + msg *lnwire.ChannelUpdate) error { + + if msg.MessageFlags&lnwire.ChanUpdateOptionMaxHtlc != 0 { + maxHtlc := msg.HtlcMaximumMsat + if maxHtlc == 0 || maxHtlc < msg.HtlcMinimumMsat { + return errors.Errorf("invalid max htlc for channel "+ + "update %v", spew.Sdump(msg)) + } + cap := lnwire.NewMSatFromSatoshis(capacity) + if maxHtlc > cap { + return errors.Errorf("max_htlc(%v) for channel "+ + "update greater than capacity(%v)", maxHtlc, + cap) + } + } + + return nil +} diff --git a/routing/router.go b/routing/router.go index 35efb8d6..49038842 100644 --- a/routing/router.go +++ b/routing/router.go @@ -2060,12 +2060,18 @@ func (r *ChannelRouter) applyChannelUpdate(msg *lnwire.ChannelUpdate, return true } - if err := ValidateChannelUpdateAnn(pubKey, msg); err != nil { + ch, _, _, err := r.GetChannelByID(msg.ShortChannelID) + if err != nil { + log.Errorf("Unable to retrieve channel by id: %v", err) + return false + } + + if err := ValidateChannelUpdateAnn(pubKey, ch.Capacity, msg); err != nil { log.Errorf("Unable to validate channel update: %v", err) return false } - err := r.UpdateEdge(&channeldb.ChannelEdgePolicy{ + err = r.UpdateEdge(&channeldb.ChannelEdgePolicy{ SigBytes: msg.Signature.ToSignatureBytes(), ChannelID: msg.ShortChannelID.ToUint64(), LastUpdate: time.Unix(int64(msg.Timestamp), 0),