discovery+routing: validate msg flags and max htlc in ChannelUpdates
In this commit, we alter the ValidateChannelUpdateAnn function in ann_validation to validate a remote ChannelUpdate's message flags and max HTLC field. If the message flag is set but the max HTLC field is not set or vice versa, the ChannelUpdate fails validation. Co-authored-by: Johan T. Halseth <johanth@gmail.com>
This commit is contained in:
parent
f316cc6c7e
commit
15168c391e
@ -1992,10 +1992,11 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(
|
|||||||
pubKey, _ = chanInfo.NodeKey2()
|
pubKey, _ = chanInfo.NodeKey2()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate the channel announcement with the expected public
|
// Validate the channel announcement with the expected public key and
|
||||||
// key, In the case of an invalid channel , we'll return an
|
// channel capacity. In the case of an invalid channel update, we'll
|
||||||
// error to the caller and exit early.
|
// return an error to the caller and exit early.
|
||||||
if err := routing.ValidateChannelUpdateAnn(pubKey, msg); err != nil {
|
err = routing.ValidateChannelUpdateAnn(pubKey, chanInfo.Capacity, msg)
|
||||||
|
if err != nil {
|
||||||
rErr := fmt.Errorf("unable to validate channel "+
|
rErr := fmt.Errorf("unable to validate channel "+
|
||||||
"update announcement for short_chan_id=%v: %v",
|
"update announcement for short_chan_id=%v: %v",
|
||||||
spew.Sdump(msg.ShortChannelID), err)
|
spew.Sdump(msg.ShortChannelID), err)
|
||||||
@ -2548,7 +2549,7 @@ func (d *AuthenticatedGossiper) updateChannel(info *channeldb.ChannelEdgeInfo,
|
|||||||
|
|
||||||
// To ensure that our signature is valid, we'll verify it ourself
|
// To ensure that our signature is valid, we'll verify it ourself
|
||||||
// before committing it to the slice returned.
|
// before committing it to the slice returned.
|
||||||
err = routing.ValidateChannelUpdateAnn(d.selfKey, chanUpdate)
|
err = routing.ValidateChannelUpdateAnn(d.selfKey, info.Capacity, chanUpdate)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("generated invalid channel "+
|
return nil, nil, fmt.Errorf("generated invalid channel "+
|
||||||
"update sig: %v", err)
|
"update sig: %v", err)
|
||||||
|
@ -10,6 +10,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@ -474,14 +475,7 @@ func createUpdateAnnouncement(blockHeight uint32,
|
|||||||
a.ExtraOpaqueData = extraBytes[0]
|
a.ExtraOpaqueData = extraBytes[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
pub := nodeKey.PubKey()
|
err = signUpdate(nodeKey, a)
|
||||||
signer := mockSigner{nodeKey}
|
|
||||||
sig, err := SignAnnouncement(&signer, pub, a)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
a.Signature, err = lnwire.NewSigFromSignature(sig)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -489,6 +483,22 @@ func createUpdateAnnouncement(blockHeight uint32,
|
|||||||
return a, nil
|
return a, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func signUpdate(nodeKey *btcec.PrivateKey, a *lnwire.ChannelUpdate) error {
|
||||||
|
pub := nodeKey.PubKey()
|
||||||
|
signer := mockSigner{nodeKey}
|
||||||
|
sig, err := SignAnnouncement(&signer, pub, a)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
a.Signature, err = lnwire.NewSigFromSignature(sig)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func createAnnouncementWithoutProof(blockHeight uint32,
|
func createAnnouncementWithoutProof(blockHeight uint32,
|
||||||
extraBytes ...[]byte) *lnwire.ChannelAnnouncement {
|
extraBytes ...[]byte) *lnwire.ChannelAnnouncement {
|
||||||
|
|
||||||
@ -2765,6 +2775,93 @@ func TestNodeAnnouncementNoChannels(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestOptionalFieldsChannelUpdateValidation tests that we're able to properly
|
||||||
|
// validate the msg flags and optional max HTLC field of a ChannelUpdate.
|
||||||
|
func TestOptionalFieldsChannelUpdateValidation(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx, cleanup, err := createTestCtx(0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("can't create context: %v", err)
|
||||||
|
}
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
chanUpdateHeight := uint32(0)
|
||||||
|
timestamp := uint32(123456)
|
||||||
|
nodePeer := &mockPeer{nodeKeyPriv1.PubKey(), nil, nil}
|
||||||
|
|
||||||
|
// In this scenario, we'll test whether the message flags field in a channel
|
||||||
|
// update is properly handled.
|
||||||
|
chanAnn, err := createRemoteChannelAnnouncement(chanUpdateHeight)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("can't create channel announcement: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err = <-ctx.gossiper.ProcessRemoteAnnouncement(chanAnn, nodePeer):
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("did not process remote announcement")
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to process announcement: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The first update should fail from an invalid max HTLC field, which is
|
||||||
|
// less than the min HTLC.
|
||||||
|
chanUpdAnn, err := createUpdateAnnouncement(0, 0, nodeKeyPriv1, timestamp)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to create channel update: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
chanUpdAnn.HtlcMinimumMsat = 5000
|
||||||
|
chanUpdAnn.HtlcMaximumMsat = 4000
|
||||||
|
if err := signUpdate(nodeKeyPriv1, chanUpdAnn); err != nil {
|
||||||
|
t.Fatalf("unable to sign channel update: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err = <-ctx.gossiper.ProcessRemoteAnnouncement(chanUpdAnn, nodePeer):
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("did not process remote announcement")
|
||||||
|
}
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "invalid max htlc") {
|
||||||
|
t.Fatalf("expected chan update to error, instead got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The second update should fail because the message flag is set but
|
||||||
|
// the max HTLC field is 0.
|
||||||
|
chanUpdAnn.HtlcMinimumMsat = 0
|
||||||
|
chanUpdAnn.HtlcMaximumMsat = 0
|
||||||
|
if err := signUpdate(nodeKeyPriv1, chanUpdAnn); err != nil {
|
||||||
|
t.Fatalf("unable to sign channel update: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err = <-ctx.gossiper.ProcessRemoteAnnouncement(chanUpdAnn, nodePeer):
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("did not process remote announcement")
|
||||||
|
}
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "invalid max htlc") {
|
||||||
|
t.Fatalf("expected chan update to error, instead got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The final update should succeed, since setting the flag 0 means the
|
||||||
|
// nonsense max_htlc field will just be ignored.
|
||||||
|
chanUpdAnn.MessageFlags = 0
|
||||||
|
if err := signUpdate(nodeKeyPriv1, chanUpdAnn); err != nil {
|
||||||
|
t.Fatalf("unable to sign channel update: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err = <-ctx.gossiper.ProcessRemoteAnnouncement(chanUpdAnn, nodePeer):
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("did not process remote announcement")
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to process announcement: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// mockPeer implements the lnpeer.Peer interface and is used to test the
|
// mockPeer implements the lnpeer.Peer interface and is used to test the
|
||||||
// gossiper's interaction with peers.
|
// gossiper's interaction with peers.
|
||||||
type mockPeer struct {
|
type mockPeer struct {
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
|
|
||||||
"github.com/btcsuite/btcd/btcec"
|
"github.com/btcsuite/btcd/btcec"
|
||||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||||
|
"github.com/btcsuite/btcutil"
|
||||||
"github.com/davecgh/go-spew/spew"
|
"github.com/davecgh/go-spew/spew"
|
||||||
"github.com/go-errors/errors"
|
"github.com/go-errors/errors"
|
||||||
"github.com/lightningnetwork/lnd/lnwire"
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
@ -121,11 +122,16 @@ func ValidateNodeAnn(a *lnwire.NodeAnnouncement) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ValidateChannelUpdateAnn validates the channel update announcement by
|
// ValidateChannelUpdateAnn validates the channel update announcement by
|
||||||
// checking that the included signature covers he announcement and has been
|
// checking (1) that the included signature covers the announcement and has been
|
||||||
// signed by the node's private key.
|
// signed by the node's private key, and (2) that the announcement's message
|
||||||
func ValidateChannelUpdateAnn(pubKey *btcec.PublicKey,
|
// flags and optional fields are sane.
|
||||||
|
func ValidateChannelUpdateAnn(pubKey *btcec.PublicKey, capacity btcutil.Amount,
|
||||||
a *lnwire.ChannelUpdate) error {
|
a *lnwire.ChannelUpdate) error {
|
||||||
|
|
||||||
|
if err := validateOptionalFields(capacity, a); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
data, err := a.DataToSign()
|
data, err := a.DataToSign()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Errorf("unable to reconstruct message: %v", err)
|
return errors.Errorf("unable to reconstruct message: %v", err)
|
||||||
@ -144,3 +150,25 @@ func ValidateChannelUpdateAnn(pubKey *btcec.PublicKey,
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validateOptionalFields validates a channel update's message flags and
|
||||||
|
// corresponding update fields.
|
||||||
|
func validateOptionalFields(capacity btcutil.Amount,
|
||||||
|
msg *lnwire.ChannelUpdate) error {
|
||||||
|
|
||||||
|
if msg.MessageFlags&lnwire.ChanUpdateOptionMaxHtlc != 0 {
|
||||||
|
maxHtlc := msg.HtlcMaximumMsat
|
||||||
|
if maxHtlc == 0 || maxHtlc < msg.HtlcMinimumMsat {
|
||||||
|
return errors.Errorf("invalid max htlc for channel "+
|
||||||
|
"update %v", spew.Sdump(msg))
|
||||||
|
}
|
||||||
|
cap := lnwire.NewMSatFromSatoshis(capacity)
|
||||||
|
if maxHtlc > cap {
|
||||||
|
return errors.Errorf("max_htlc(%v) for channel "+
|
||||||
|
"update greater than capacity(%v)", maxHtlc,
|
||||||
|
cap)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@ -2060,12 +2060,18 @@ func (r *ChannelRouter) applyChannelUpdate(msg *lnwire.ChannelUpdate,
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := ValidateChannelUpdateAnn(pubKey, msg); err != nil {
|
ch, _, _, err := r.GetChannelByID(msg.ShortChannelID)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Unable to retrieve channel by id: %v", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ValidateChannelUpdateAnn(pubKey, ch.Capacity, msg); err != nil {
|
||||||
log.Errorf("Unable to validate channel update: %v", err)
|
log.Errorf("Unable to validate channel update: %v", err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
err := r.UpdateEdge(&channeldb.ChannelEdgePolicy{
|
err = r.UpdateEdge(&channeldb.ChannelEdgePolicy{
|
||||||
SigBytes: msg.Signature.ToSignatureBytes(),
|
SigBytes: msg.Signature.ToSignatureBytes(),
|
||||||
ChannelID: msg.ShortChannelID.ToUint64(),
|
ChannelID: msg.ShortChannelID.ToUint64(),
|
||||||
LastUpdate: time.Unix(int64(msg.Timestamp), 0),
|
LastUpdate: time.Unix(int64(msg.Timestamp), 0),
|
||||||
|
Loading…
Reference in New Issue
Block a user