diff --git a/netann/channel_update_test.go b/netann/channel_update_test.go new file mode 100644 index 00000000..4ed019a8 --- /dev/null +++ b/netann/channel_update_test.go @@ -0,0 +1,189 @@ +package netann_test + +import ( + "errors" + "testing" + "time" + + "github.com/btcsuite/btcd/btcec" + "github.com/lightningnetwork/lnd/lnwallet" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/netann" + "github.com/lightningnetwork/lnd/routing" +) + +type mockSigner struct { + err error +} + +func (m *mockSigner) SignMessage(pk *btcec.PublicKey, + msg []byte) (*btcec.Signature, error) { + + if m.err != nil { + return nil, m.err + } + + return nil, nil +} + +var _ lnwallet.MessageSigner = (*mockSigner)(nil) + +var ( + privKey, _ = btcec.NewPrivateKey(btcec.S256()) + + pubKey = privKey.PubKey() + + errFailedToSign = errors.New("unable to sign message") +) + +type updateDisableTest struct { + name string + startEnabled bool + disable bool + startTime time.Time + signer lnwallet.MessageSigner + expErr error +} + +var updateDisableTests = []updateDisableTest{ + { + name: "working signer enabled to disabled", + startEnabled: true, + disable: true, + startTime: time.Now(), + signer: netann.NewNodeSigner(privKey), + }, + { + name: "working signer enabled to enabled", + startEnabled: true, + disable: false, + startTime: time.Now(), + signer: netann.NewNodeSigner(privKey), + }, + { + name: "working signer disabled to enabled", + startEnabled: false, + disable: false, + startTime: time.Now(), + signer: netann.NewNodeSigner(privKey), + }, + { + name: "working signer disabled to disabled", + startEnabled: false, + disable: true, + startTime: time.Now(), + signer: netann.NewNodeSigner(privKey), + }, + { + name: "working signer future monotonicity", + startEnabled: true, + disable: true, + startTime: time.Now().Add(time.Hour), // must increment + signer: netann.NewNodeSigner(privKey), + }, + { + name: "failing signer", + startTime: time.Now(), + signer: &mockSigner{err: errFailedToSign}, + expErr: errFailedToSign, + }, + { + name: "invalid sig from signer", + startTime: time.Now(), + signer: &mockSigner{}, // returns a nil signature + expErr: errors.New("cannot decode empty signature"), + }, +} + +// TestUpdateDisableFlag checks the behavior of UpdateDisableFlag, asserting +// that the proper channel flags are set, the timestamp always increases +// monotonically, and that the correct errors are returned in the event that the +// signer is unable to produce a signature. +func TestUpdateDisableFlag(t *testing.T) { + t.Parallel() + + for _, tc := range updateDisableTests { + t.Run(tc.name, func(t *testing.T) { + // Create the initial update, the only fields we are + // concerned with in this test are the timestamp and the + // channel flags. + ogUpdate := &lnwire.ChannelUpdate{ + Timestamp: uint32(tc.startTime.Unix()), + } + if !tc.startEnabled { + ogUpdate.ChannelFlags |= lnwire.ChanUpdateDisabled + } + + // Create new update to sign using the same fields as + // the original. UpdateDisableFlag will mutate the + // passed channel update, so we keep the old one to test + // against. + newUpdate := &lnwire.ChannelUpdate{ + Timestamp: ogUpdate.Timestamp, + ChannelFlags: ogUpdate.ChannelFlags, + } + + // Attempt to update and sign the new update, specifying + // disabled or enabled as prescribed in the test case. + err := netann.SignChannelUpdate( + tc.signer, pubKey, newUpdate, + netann.ChannelUpdateSetDisable(tc.disable), + ) + + var fail bool + switch { + + // Both nil, pass. + case tc.expErr == nil && err == nil: + + // Both non-nil, compare error strings since some + // methods don't return concrete error types. + case tc.expErr != nil && err != nil: + if err.Error() != tc.expErr.Error() { + fail = true + } + + // Otherwise, one is nil and one is non-nil. + default: + fail = true + } + + if fail { + t.Fatalf("expected error: %v, got %v", + tc.expErr, err) + } + + // Exit early if the test expected a failure. + if tc.expErr != nil { + return + } + + // Verify that the timestamp has increased from the + // original update. + if newUpdate.Timestamp <= ogUpdate.Timestamp { + t.Fatalf("update timestamp should be "+ + "monotonically increasing, "+ + "original: %d, new %d", + ogUpdate.Timestamp, newUpdate.Timestamp) + } + + // Verify that the disabled flag is properly set. + disabled := newUpdate.ChannelFlags& + lnwire.ChanUpdateDisabled != 0 + if disabled != tc.disable { + t.Fatalf("expected disable:%v, found:%v", + tc.disable, disabled) + } + + // Finally, validate the signature using the router's + // verification logic. + err = routing.ValidateChannelUpdateAnn( + pubKey, 0, newUpdate, + ) + if err != nil { + t.Fatalf("channel update failed to "+ + "validate: %v", err) + } + }) + } +}