discovery+routing: validate msg flags and max htlc in ChannelUpdates

In this commit, we alter the ValidateChannelUpdateAnn function in
ann_validation to validate a remote ChannelUpdate's message flags
and max HTLC field. If the message flag is set but the max HTLC
field is not set or vice versa, the ChannelUpdate fails validation.

Co-authored-by: Johan T. Halseth <johanth@gmail.com>
This commit is contained in:
Valentine Wallace 2019-01-12 18:59:43 +01:00 committed by Johan T. Halseth
parent f316cc6c7e
commit 15168c391e
No known key found for this signature in database
GPG Key ID: 15BAADA29DA20D26
4 changed files with 150 additions and 18 deletions

@ -1992,10 +1992,11 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
pubKey, _ = chanInfo.NodeKey2() pubKey, _ = chanInfo.NodeKey2()
} }
// Validate the channel announcement with the expected public // Validate the channel announcement with the expected public key and
// key, In the case of an invalid channel , we'll return an // channel capacity. In the case of an invalid channel update, we'll
// error to the caller and exit early. // return an error to the caller and exit early.
if err := routing.ValidateChannelUpdateAnn(pubKey, msg); err != nil { err = routing.ValidateChannelUpdateAnn(pubKey, chanInfo.Capacity, msg)
if err != nil {
rErr := fmt.Errorf("unable to validate channel "+ rErr := fmt.Errorf("unable to validate channel "+
"update announcement for short_chan_id=%v: %v", "update announcement for short_chan_id=%v: %v",
spew.Sdump(msg.ShortChannelID), err) spew.Sdump(msg.ShortChannelID), err)
@ -2548,7 +2549,7 @@ func (d *AuthenticatedGossiper) updateChannel(info *channeldb.ChannelEdgeInfo,
// To ensure that our signature is valid, we'll verify it ourself // To ensure that our signature is valid, we'll verify it ourself
// before committing it to the slice returned. // before committing it to the slice returned.
err = routing.ValidateChannelUpdateAnn(d.selfKey, chanUpdate) err = routing.ValidateChannelUpdateAnn(d.selfKey, info.Capacity, chanUpdate)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("generated invalid channel "+ return nil, nil, fmt.Errorf("generated invalid channel "+
"update sig: %v", err) "update sig: %v", err)

@ -10,6 +10,7 @@ import (
"net" "net"
"os" "os"
"reflect" "reflect"
"strings"
"sync" "sync"
"testing" "testing"
"time" "time"
@ -474,14 +475,7 @@ func createUpdateAnnouncement(blockHeight uint32,
a.ExtraOpaqueData = extraBytes[0] a.ExtraOpaqueData = extraBytes[0]
} }
pub := nodeKey.PubKey() err = signUpdate(nodeKey, a)
signer := mockSigner{nodeKey}
sig, err := SignAnnouncement(&signer, pub, a)
if err != nil {
return nil, err
}
a.Signature, err = lnwire.NewSigFromSignature(sig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -489,6 +483,22 @@ func createUpdateAnnouncement(blockHeight uint32,
return a, nil return a, nil
} }
func signUpdate(nodeKey *btcec.PrivateKey, a *lnwire.ChannelUpdate) error {
pub := nodeKey.PubKey()
signer := mockSigner{nodeKey}
sig, err := SignAnnouncement(&signer, pub, a)
if err != nil {
return err
}
a.Signature, err = lnwire.NewSigFromSignature(sig)
if err != nil {
return err
}
return nil
}
func createAnnouncementWithoutProof(blockHeight uint32, func createAnnouncementWithoutProof(blockHeight uint32,
extraBytes ...[]byte) *lnwire.ChannelAnnouncement { extraBytes ...[]byte) *lnwire.ChannelAnnouncement {
@ -2765,6 +2775,93 @@ func TestNodeAnnouncementNoChannels(t *testing.T) {
} }
} }
// TestOptionalFieldsChannelUpdateValidation tests that we're able to properly
// validate the msg flags and optional max HTLC field of a ChannelUpdate.
func TestOptionalFieldsChannelUpdateValidation(t *testing.T) {
t.Parallel()
ctx, cleanup, err := createTestCtx(0)
if err != nil {
t.Fatalf("can't create context: %v", err)
}
defer cleanup()
chanUpdateHeight := uint32(0)
timestamp := uint32(123456)
nodePeer := &mockPeer{nodeKeyPriv1.PubKey(), nil, nil}
// In this scenario, we'll test whether the message flags field in a channel
// update is properly handled.
chanAnn, err := createRemoteChannelAnnouncement(chanUpdateHeight)
if err != nil {
t.Fatalf("can't create channel announcement: %v", err)
}
select {
case err = <-ctx.gossiper.ProcessRemoteAnnouncement(chanAnn, nodePeer):
case <-time.After(2 * time.Second):
t.Fatal("did not process remote announcement")
}
if err != nil {
t.Fatalf("unable to process announcement: %v", err)
}
// The first update should fail from an invalid max HTLC field, which is
// less than the min HTLC.
chanUpdAnn, err := createUpdateAnnouncement(0, 0, nodeKeyPriv1, timestamp)
if err != nil {
t.Fatalf("unable to create channel update: %v", err)
}
chanUpdAnn.HtlcMinimumMsat = 5000
chanUpdAnn.HtlcMaximumMsat = 4000
if err := signUpdate(nodeKeyPriv1, chanUpdAnn); err != nil {
t.Fatalf("unable to sign channel update: %v", err)
}
select {
case err = <-ctx.gossiper.ProcessRemoteAnnouncement(chanUpdAnn, nodePeer):
case <-time.After(2 * time.Second):
t.Fatal("did not process remote announcement")
}
if err == nil || !strings.Contains(err.Error(), "invalid max htlc") {
t.Fatalf("expected chan update to error, instead got %v", err)
}
// The second update should fail because the message flag is set but
// the max HTLC field is 0.
chanUpdAnn.HtlcMinimumMsat = 0
chanUpdAnn.HtlcMaximumMsat = 0
if err := signUpdate(nodeKeyPriv1, chanUpdAnn); err != nil {
t.Fatalf("unable to sign channel update: %v", err)
}
select {
case err = <-ctx.gossiper.ProcessRemoteAnnouncement(chanUpdAnn, nodePeer):
case <-time.After(2 * time.Second):
t.Fatal("did not process remote announcement")
}
if err == nil || !strings.Contains(err.Error(), "invalid max htlc") {
t.Fatalf("expected chan update to error, instead got %v", err)
}
// The final update should succeed, since setting the flag 0 means the
// nonsense max_htlc field will just be ignored.
chanUpdAnn.MessageFlags = 0
if err := signUpdate(nodeKeyPriv1, chanUpdAnn); err != nil {
t.Fatalf("unable to sign channel update: %v", err)
}
select {
case err = <-ctx.gossiper.ProcessRemoteAnnouncement(chanUpdAnn, nodePeer):
case <-time.After(2 * time.Second):
t.Fatal("did not process remote announcement")
}
if err != nil {
t.Fatalf("unable to process announcement: %v", err)
}
}
// mockPeer implements the lnpeer.Peer interface and is used to test the // mockPeer implements the lnpeer.Peer interface and is used to test the
// gossiper's interaction with peers. // gossiper's interaction with peers.
type mockPeer struct { type mockPeer struct {

@ -5,6 +5,7 @@ import (
"github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/btcec"
"github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcutil"
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
"github.com/go-errors/errors" "github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
@ -121,11 +122,16 @@ func ValidateNodeAnn(a *lnwire.NodeAnnouncement) error {
} }
// ValidateChannelUpdateAnn validates the channel update announcement by // ValidateChannelUpdateAnn validates the channel update announcement by
// checking that the included signature covers he announcement and has been // checking (1) that the included signature covers the announcement and has been
// signed by the node's private key. // signed by the node's private key, and (2) that the announcement's message
func ValidateChannelUpdateAnn(pubKey *btcec.PublicKey, // flags and optional fields are sane.
func ValidateChannelUpdateAnn(pubKey *btcec.PublicKey, capacity btcutil.Amount,
a *lnwire.ChannelUpdate) error { a *lnwire.ChannelUpdate) error {
if err := validateOptionalFields(capacity, a); err != nil {
return err
}
data, err := a.DataToSign() data, err := a.DataToSign()
if err != nil { if err != nil {
return errors.Errorf("unable to reconstruct message: %v", err) return errors.Errorf("unable to reconstruct message: %v", err)
@ -144,3 +150,25 @@ func ValidateChannelUpdateAnn(pubKey *btcec.PublicKey,
return nil return nil
} }
// validateOptionalFields validates a channel update's message flags and
// corresponding update fields.
func validateOptionalFields(capacity btcutil.Amount,
msg *lnwire.ChannelUpdate) error {
if msg.MessageFlags&lnwire.ChanUpdateOptionMaxHtlc != 0 {
maxHtlc := msg.HtlcMaximumMsat
if maxHtlc == 0 || maxHtlc < msg.HtlcMinimumMsat {
return errors.Errorf("invalid max htlc for channel "+
"update %v", spew.Sdump(msg))
}
cap := lnwire.NewMSatFromSatoshis(capacity)
if maxHtlc > cap {
return errors.Errorf("max_htlc(%v) for channel "+
"update greater than capacity(%v)", maxHtlc,
cap)
}
}
return nil
}

@ -2060,12 +2060,18 @@ func (r *ChannelRouter) applyChannelUpdate(msg *lnwire.ChannelUpdate,
return true return true
} }
if err := ValidateChannelUpdateAnn(pubKey, msg); err != nil { ch, _, _, err := r.GetChannelByID(msg.ShortChannelID)
if err != nil {
log.Errorf("Unable to retrieve channel by id: %v", err)
return false
}
if err := ValidateChannelUpdateAnn(pubKey, ch.Capacity, msg); err != nil {
log.Errorf("Unable to validate channel update: %v", err) log.Errorf("Unable to validate channel update: %v", err)
return false return false
} }
err := r.UpdateEdge(&channeldb.ChannelEdgePolicy{ err = r.UpdateEdge(&channeldb.ChannelEdgePolicy{
SigBytes: msg.Signature.ToSignatureBytes(), SigBytes: msg.Signature.ToSignatureBytes(),
ChannelID: msg.ShortChannelID.ToUint64(), ChannelID: msg.ShortChannelID.ToUint64(),
LastUpdate: time.Unix(int64(msg.Timestamp), 0), LastUpdate: time.Unix(int64(msg.Timestamp), 0),