diff --git a/watchtower/wtwire/init_test.go b/watchtower/wtwire/init_test.go new file mode 100644 index 00000000..337c1de2 --- /dev/null +++ b/watchtower/wtwire/init_test.go @@ -0,0 +1,106 @@ +package wtwire_test + +import ( + "testing" + + "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/watchtower/wtwire" +) + +var ( + testnetChainHash = *chaincfg.TestNet3Params.GenesisHash + mainnetChainHash = *chaincfg.MainNetParams.GenesisHash +) + +type checkRemoteInitTest struct { + name string + lFeatures *lnwire.RawFeatureVector + lHash chainhash.Hash + rFeatures *lnwire.RawFeatureVector + rHash chainhash.Hash + expErr error +} + +var checkRemoteInitTests = []checkRemoteInitTest{ + { + name: "same chain, local-optional remote-required", + lFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsOptional), + lHash: testnetChainHash, + rFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsRequired), + rHash: testnetChainHash, + }, + { + name: "same chain, local-required remote-optional", + lFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsRequired), + lHash: testnetChainHash, + rFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsOptional), + rHash: testnetChainHash, + }, + { + name: "different chain, local-optional remote-required", + lFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsOptional), + lHash: testnetChainHash, + rFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsRequired), + rHash: mainnetChainHash, + expErr: wtwire.NewErrUnknownChainHash(mainnetChainHash), + }, + { + name: "different chain, local-required remote-optional", + lFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsOptional), + lHash: testnetChainHash, + rFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsRequired), + rHash: mainnetChainHash, + expErr: wtwire.NewErrUnknownChainHash(mainnetChainHash), + }, + { + name: "same chain, remote-unknown-required", + lFeatures: lnwire.NewRawFeatureVector(wtwire.WtSessionsOptional), + lHash: testnetChainHash, + rFeatures: lnwire.NewRawFeatureVector(lnwire.GossipQueriesRequired), + rHash: testnetChainHash, + expErr: wtwire.NewErrUnknownRequiredFeatures( + lnwire.GossipQueriesRequired, + ), + }, +} + +// TestCheckRemoteInit asserts the behavior of CheckRemoteInit when called with +// the remote party's Init message and the default wtwire.Features. We assert +// the validity of advertised features from the perspective of both client and +// server, as well as failure cases such as differing chain hashes or unknown +// required features. +func TestCheckRemoteInit(t *testing.T) { + for _, test := range checkRemoteInitTests { + t.Run(test.name, func(t *testing.T) { + testCheckRemoteInit(t, test) + }) + } +} + +func testCheckRemoteInit(t *testing.T, test checkRemoteInitTest) { + localInit := wtwire.NewInitMessage(test.lFeatures, test.lHash) + remoteInit := wtwire.NewInitMessage(test.rFeatures, test.rHash) + + err := localInit.CheckRemoteInit(remoteInit, wtwire.FeatureNames) + switch { + + // Both non-nil, pass. + case err == nil && test.expErr == nil: + return + + // One is nil and one is non-nil, fail. + default: + t.Fatalf("error mismatch, want: %v, got: %v", test.expErr, err) + + // Both non-nil, assert same error type. + case err != nil && test.expErr != nil: + } + + // Compare error strings to assert same type. + if err.Error() != test.expErr.Error() { + t.Fatalf("error mismatch, want: %v, got: %v", + test.expErr.Error(), err.Error()) + } +}