diff --git a/netann/chan_status_manager_test.go b/netann/chan_status_manager_test.go
new file mode 100644
index 00000000..60c86636
--- /dev/null
+++ b/netann/chan_status_manager_test.go
@@ -0,0 +1,816 @@
+package netann_test
+
+import (
+	"bytes"
+	"crypto/rand"
+	"encoding/binary"
+	"fmt"
+	"io"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/btcsuite/btcd/btcec"
+	"github.com/btcsuite/btcd/wire"
+	"github.com/lightningnetwork/lnd/channeldb"
+	"github.com/lightningnetwork/lnd/lnwire"
+	"github.com/lightningnetwork/lnd/netann"
+)
+
+// randOutpoint creates a random wire.OutPoint.
+func randOutpoint(t *testing.T) wire.OutPoint {
+	t.Helper()
+
+	var buf [36]byte
+	_, err := io.ReadFull(rand.Reader, buf[:])
+	if err != nil {
+		t.Fatalf("unable to generate random outpoint: %v", err)
+	}
+
+	op := wire.OutPoint{}
+	copy(op.Hash[:], buf[:32])
+	op.Index = binary.BigEndian.Uint32(buf[32:])
+
+	return op
+}
+
+var shortChanIDs uint64
+
+// createChannel generates a channeldb.OpenChannel with a random channel point
+// and short channel id.
+func createChannel(t *testing.T) *channeldb.OpenChannel {
+	t.Helper()
+
+	sid := atomic.AddUint64(&shortChanIDs, 1)
+
+	return &channeldb.OpenChannel{
+		ShortChannelID:  lnwire.NewShortChanIDFromInt(sid),
+		ChannelFlags:    lnwire.FFAnnounceChannel,
+		FundingOutpoint: randOutpoint(t),
+	}
+}
+
+// createEdgePolicies generates an edge info and two directional edge policies.
+// The remote party's public key is generated randomly, and then sorted against
+// our `pubkey` with the direction bit set appropriately in the policies. Our
+// update will be created with the disabled bit set if startEnabled is false.
+func createEdgePolicies(t *testing.T, channel *channeldb.OpenChannel,
+	pubkey *btcec.PublicKey, startEnabled bool) (*channeldb.ChannelEdgeInfo,
+	*channeldb.ChannelEdgePolicy, *channeldb.ChannelEdgePolicy) {
+
+	var (
+		pubkey1 [33]byte
+		pubkey2 [33]byte
+		dir1    lnwire.ChanUpdateChanFlags
+		dir2    lnwire.ChanUpdateChanFlags
+	)
+
+	// Set pubkey1 to OUR pubkey.
+	copy(pubkey1[:], pubkey.SerializeCompressed())
+
+	// Set the disabled bit appropriately on our update.
+	if !startEnabled {
+		dir1 |= lnwire.ChanUpdateDisabled
+	}
+
+	// Generate and set pubkey2 for THEIR pubkey.
+	privKey2, err := btcec.NewPrivateKey(btcec.S256())
+	if err != nil {
+		t.Fatalf("unable to generate key pair: %v", err)
+	}
+	copy(pubkey2[:], privKey2.PubKey().SerializeCompressed())
+
+	// Set pubkey1 to the lower of the two pubkeys.
+	if bytes.Compare(pubkey2[:], pubkey1[:]) < 0 {
+		pubkey1, pubkey2 = pubkey2, pubkey1
+		dir1, dir2 = dir2, dir1
+	}
+
+	// Now that the ordering has been established, set pubkey2's direction
+	// bit.
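+	// Note: per BOLT #7, node_id_1 is the node with the lexicographically
+	// lesser pubkey, and the direction bit of channel_flags is set when an
+	// update originates from node_id_2, which is why only dir2 receives
+	// ChanUpdateDirection here.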
+	dir2 |= lnwire.ChanUpdateDirection
+
+	return &channeldb.ChannelEdgeInfo{
+			ChannelPoint:  channel.FundingOutpoint,
+			NodeKey1Bytes: pubkey1,
+			NodeKey2Bytes: pubkey2,
+		},
+		&channeldb.ChannelEdgePolicy{
+			ChannelID:    channel.ShortChanID().ToUint64(),
+			ChannelFlags: dir1,
+			LastUpdate:   time.Now(),
+			SigBytes:     make([]byte, 64),
+		},
+		&channeldb.ChannelEdgePolicy{
+			ChannelID:    channel.ShortChanID().ToUint64(),
+			ChannelFlags: dir2,
+			LastUpdate:   time.Now(),
+			SigBytes:     make([]byte, 64),
+		}
+}
+
+type mockGraph struct {
+	pubKey    *btcec.PublicKey
+	mu        sync.Mutex
+	channels  []*channeldb.OpenChannel
+	chanInfos map[wire.OutPoint]*channeldb.ChannelEdgeInfo
+	chanPols1 map[wire.OutPoint]*channeldb.ChannelEdgePolicy
+	chanPols2 map[wire.OutPoint]*channeldb.ChannelEdgePolicy
+	sidToCid  map[lnwire.ShortChannelID]wire.OutPoint
+
+	updates chan *lnwire.ChannelUpdate
+}
+
+func newMockGraph(t *testing.T, numChannels int,
+	startActive, startEnabled bool, pubKey *btcec.PublicKey) *mockGraph {
+
+	g := &mockGraph{
+		channels:  make([]*channeldb.OpenChannel, 0, numChannels),
+		chanInfos: make(map[wire.OutPoint]*channeldb.ChannelEdgeInfo),
+		chanPols1: make(map[wire.OutPoint]*channeldb.ChannelEdgePolicy),
+		chanPols2: make(map[wire.OutPoint]*channeldb.ChannelEdgePolicy),
+		sidToCid:  make(map[lnwire.ShortChannelID]wire.OutPoint),
+		updates:   make(chan *lnwire.ChannelUpdate, 2*numChannels),
+	}
+
+	for i := 0; i < numChannels; i++ {
+		c := createChannel(t)
+
+		info, pol1, pol2 := createEdgePolicies(
+			t, c, pubKey, startEnabled,
+		)
+
+		g.addChannel(c)
+		g.addEdgePolicy(c, info, pol1, pol2)
+	}
+
+	return g
+}
+
+func (g *mockGraph) FetchAllOpenChannels() ([]*channeldb.OpenChannel, error) {
+	return g.chans(), nil
+}
+
+func (g *mockGraph) FetchChannelEdgesByOutpoint(
+	op *wire.OutPoint) (*channeldb.ChannelEdgeInfo,
+	*channeldb.ChannelEdgePolicy, *channeldb.ChannelEdgePolicy, error) {
+
+	g.mu.Lock()
+	defer g.mu.Unlock()
+
+	info, ok := g.chanInfos[*op]
+	if !ok {
+		return nil, nil, nil, channeldb.ErrEdgeNotFound
+	}
+
+	pol1 := g.chanPols1[*op]
+	pol2 := g.chanPols2[*op]
+
+	return info, pol1, pol2, nil
+}
+
+func (g *mockGraph) ApplyChannelUpdate(update *lnwire.ChannelUpdate) error {
+	g.mu.Lock()
+	defer g.mu.Unlock()
+
+	outpoint, ok := g.sidToCid[update.ShortChannelID]
+	if !ok {
+		return fmt.Errorf("unknown short channel id: %v",
+			update.ShortChannelID)
+	}
+
+	pol1 := g.chanPols1[outpoint]
+	pol2 := g.chanPols2[outpoint]
+
+	// Determine which policy we should update by comparing the direction
+	// bits on the policies and the update, and seeing which match up.
+	var update1 bool
+	switch {
+	case update.ChannelFlags&lnwire.ChanUpdateDirection ==
+		pol1.ChannelFlags&lnwire.ChanUpdateDirection:
+		update1 = true
+
+	case update.ChannelFlags&lnwire.ChanUpdateDirection ==
+		pol2.ChannelFlags&lnwire.ChanUpdateDirection:
+		update1 = false
+
+	default:
+		return fmt.Errorf("unable to find policy to update")
+	}
+
+	timestamp := time.Unix(int64(update.Timestamp), 0)
+
+	policy := &channeldb.ChannelEdgePolicy{
+		ChannelID:    update.ShortChannelID.ToUint64(),
+		ChannelFlags: update.ChannelFlags,
+		LastUpdate:   timestamp,
+		SigBytes:     make([]byte, 64),
+	}
+
+	if update1 {
+		g.chanPols1[outpoint] = policy
+	} else {
+		g.chanPols2[outpoint] = policy
+	}
+
+	// Send the update to the network. This channel should be sufficiently
+	// buffered to avoid deadlocking.
+	g.updates <- update
+
+	return nil
+}
+
+func (g *mockGraph) chans() []*channeldb.OpenChannel {
+	g.mu.Lock()
+	defer g.mu.Unlock()
+
+	channels := make([]*channeldb.OpenChannel, 0, len(g.channels))
+	for _, channel := range g.channels {
+		channels = append(channels, channel)
+	}
+
+	return channels
+}
+
+func (g *mockGraph) addChannel(channel *channeldb.OpenChannel) {
+	g.mu.Lock()
+	defer g.mu.Unlock()
+
+	g.channels = append(g.channels, channel)
+}
+
+func (g *mockGraph) addEdgePolicy(c *channeldb.OpenChannel,
+	info *channeldb.ChannelEdgeInfo,
+	pol1, pol2 *channeldb.ChannelEdgePolicy) {
+
+	g.mu.Lock()
+	defer g.mu.Unlock()
+
+	g.chanInfos[c.FundingOutpoint] = info
+	g.chanPols1[c.FundingOutpoint] = pol1
+	g.chanPols2[c.FundingOutpoint] = pol2
+	g.sidToCid[c.ShortChanID()] = c.FundingOutpoint
+}
+
+func (g *mockGraph) removeChannel(channel *channeldb.OpenChannel) {
+	g.mu.Lock()
+	defer g.mu.Unlock()
+
+	for i, c := range g.channels {
+		if c.FundingOutpoint != channel.FundingOutpoint {
+			continue
+		}
+
+		g.channels = append(g.channels[:i], g.channels[i+1:]...)
+		delete(g.chanInfos, c.FundingOutpoint)
+		delete(g.chanPols1, c.FundingOutpoint)
+		delete(g.chanPols2, c.FundingOutpoint)
+		delete(g.sidToCid, c.ShortChanID())
+		return
+	}
+}
+
+type mockSwitch struct {
+	mu       sync.Mutex
+	isActive map[lnwire.ChannelID]bool
+}
+
+func newMockSwitch() *mockSwitch {
+	return &mockSwitch{
+		isActive: make(map[lnwire.ChannelID]bool),
+	}
+}
+
+func (s *mockSwitch) HasActiveLink(chanID lnwire.ChannelID) bool {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	// If the link is found, we return its active status. In the real
+	// switch, this returns EligibleToForward().
+	active, ok := s.isActive[chanID]
+	if ok {
+		return active
+	}
+
+	return false
+}
+
+func (s *mockSwitch) SetStatus(chanID lnwire.ChannelID, active bool) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	s.isActive[chanID] = active
+}
+
+func newManagerCfg(t *testing.T, numChannels int,
+	startEnabled bool) (*netann.ChanStatusConfig, *mockGraph, *mockSwitch) {
+
+	t.Helper()
+
+	privKey, err := btcec.NewPrivateKey(btcec.S256())
+	if err != nil {
+		t.Fatalf("unable to generate key pair: %v", err)
+	}
+
+	graph := newMockGraph(
+		t, numChannels, startEnabled, startEnabled, privKey.PubKey(),
+	)
+	htlcSwitch := newMockSwitch()
+
+	cfg := &netann.ChanStatusConfig{
+		ChanStatusSampleInterval: 50 * time.Millisecond,
+		ChanEnableTimeout:        500 * time.Millisecond,
+		ChanDisableTimeout:       time.Second,
+		OurPubKey:                privKey.PubKey(),
+		MessageSigner:            netann.NewNodeSigner(privKey),
+		IsChannelActive:          htlcSwitch.HasActiveLink,
+		ApplyChannelUpdate:       graph.ApplyChannelUpdate,
+		DB:                       graph,
+		Graph:                    graph,
+	}
+
+	return cfg, graph, htlcSwitch
+}
+
+type testHarness struct {
+	t                  *testing.T
+	numChannels        int
+	graph              *mockGraph
+	htlcSwitch         *mockSwitch
+	mgr                *netann.ChanStatusManager
+	ourPubKey          *btcec.PublicKey
+	safeDisableTimeout time.Duration
+}
+
+// newHarness returns a new testHarness for testing a ChanStatusManager. The
+// mockGraph will be populated with numChannels channels. The startActive and
+// startEnabled parameters govern the initial state of the channels with
+// respect to the htlcswitch and the network, respectively.
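+//
+// The returned harness has already started the underlying ChanStatusManager;
+// callers are responsible for stopping it via h.mgr.Stop() once the test
+// completes.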
+func newHarness(t *testing.T, numChannels int,
+	startActive, startEnabled bool) testHarness {
+
+	cfg, graph, htlcSwitch := newManagerCfg(t, numChannels, startEnabled)
+
+	mgr, err := netann.NewChanStatusManager(cfg)
+	if err != nil {
+		t.Fatalf("unable to create chan status manager: %v", err)
+	}
+
+	err = mgr.Start()
+	if err != nil {
+		t.Fatalf("unable to start chan status manager: %v", err)
+	}
+
+	h := testHarness{
+		t:                  t,
+		numChannels:        numChannels,
+		graph:              graph,
+		htlcSwitch:         htlcSwitch,
+		mgr:                mgr,
+		ourPubKey:          cfg.OurPubKey,
+		safeDisableTimeout: (3 * cfg.ChanDisableTimeout) / 2, // 1.5x
+	}
+
+	// Initialize link status as requested.
+	if startActive {
+		h.markActive(h.graph.channels)
+	} else {
+		h.markInactive(h.graph.channels)
+	}
+
+	return h
+}
+
+// markActive updates the active status of the passed channels within the mock
+// switch to active.
+func (h *testHarness) markActive(channels []*channeldb.OpenChannel) {
+	h.t.Helper()
+
+	for _, channel := range channels {
+		chanID := lnwire.NewChanIDFromOutPoint(&channel.FundingOutpoint)
+		h.htlcSwitch.SetStatus(chanID, true)
+	}
+}
+
+// markInactive updates the active status of the passed channels within the mock
+// switch to inactive.
+func (h *testHarness) markInactive(channels []*channeldb.OpenChannel) {
+	h.t.Helper()
+
+	for _, channel := range channels {
+		chanID := lnwire.NewChanIDFromOutPoint(&channel.FundingOutpoint)
+		h.htlcSwitch.SetStatus(chanID, false)
+	}
+}
+
+// assertEnables requests enables for all of the passed channels, and asserts
+// that the error returned from each RequestEnable call matches expErr.
+func (h *testHarness) assertEnables(channels []*channeldb.OpenChannel, expErr error) {
+	h.t.Helper()
+
+	for _, channel := range channels {
+		h.assertEnable(channel.FundingOutpoint, expErr)
+	}
+}
+
+// assertDisables requests disables for all of the passed channels, and asserts
+// that the error returned from each RequestDisable call matches expErr.
+func (h *testHarness) assertDisables(channels []*channeldb.OpenChannel, expErr error) {
+	h.t.Helper()
+
+	for _, channel := range channels {
+		h.assertDisable(channel.FundingOutpoint, expErr)
+	}
+}
+
+// assertEnable requests an enable for the given outpoint, and asserts that the
+// returned error matches expErr.
+func (h *testHarness) assertEnable(outpoint wire.OutPoint, expErr error) {
+	h.t.Helper()
+
+	err := h.mgr.RequestEnable(outpoint)
+	if err != expErr {
+		h.t.Fatalf("expected enable error: %v, got %v", expErr, err)
+	}
+}
+
+// assertDisable requests a disable for the given outpoint, and asserts that the
+// returned error matches expErr.
+func (h *testHarness) assertDisable(outpoint wire.OutPoint, expErr error) {
+	h.t.Helper()
+
+	err := h.mgr.RequestDisable(outpoint)
+	if err != expErr {
+		h.t.Fatalf("expected disable error: %v, got %v", expErr, err)
+	}
+}
+
+// assertNoUpdates waits for the specified duration, and asserts that no updates
+// are announced on the network.
+func (h *testHarness) assertNoUpdates(duration time.Duration) {
+	h.t.Helper()
+
+	h.assertUpdates(nil, false, duration)
+}
+
+// assertUpdates waits for the specified duration, asserting that an update is
+// received on the network for each of the passed OpenChannels, and that all
+// of their disable bits are set to match expEnabled. The expEnabled parameter
+// is ignored if channels is nil.
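+//
+// Passing a nil channels slice asserts that no updates at all are received
+// within the duration, which is how assertNoUpdates is implemented.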
+func (h *testHarness) assertUpdates(channels []*channeldb.OpenChannel,
+	expEnabled bool, duration time.Duration) {
+
+	h.t.Helper()
+
+	// Compute an index of the expected short channel ids for which we want
+	// to receive updates.
+	expSids := sidsFromChans(channels)
+
+	timeout := time.After(duration)
+	recvdSids := make(map[lnwire.ShortChannelID]struct{})
+	for {
+		select {
+		case upd := <-h.graph.updates:
+			// Assert that the received short channel id is one that
+			// we expect. If no updates were expected, this will
+			// always fail on the first update received.
+			if _, ok := expSids[upd.ShortChannelID]; !ok {
+				h.t.Fatalf("received update for unexpected "+
+					"short chan id: %v", upd.ShortChannelID)
+			}
+
+			// Assert that the disabled bit is set properly.
+			enabled := upd.ChannelFlags&lnwire.ChanUpdateDisabled !=
+				lnwire.ChanUpdateDisabled
+			if expEnabled != enabled {
+				h.t.Fatalf("expected enabled: %v, actual: %v",
+					expEnabled, enabled)
+			}
+
+			recvdSids[upd.ShortChannelID] = struct{}{}
+
+		case <-timeout:
+			// Time is up, assert that the correct number of unique
+			// updates was received.
+			if len(recvdSids) == len(channels) {
+				return
+			}
+
+			h.t.Fatalf("expected %d updates, got %d",
+				len(channels), len(recvdSids))
+		}
+	}
+}
+
+// sidsFromChans returns an index containing the short channel ids of each
+// channel provided in the list of OpenChannels.
+func sidsFromChans(
+	channels []*channeldb.OpenChannel) map[lnwire.ShortChannelID]struct{} {
+
+	sids := make(map[lnwire.ShortChannelID]struct{})
+	for _, channel := range channels {
+		sids[channel.ShortChanID()] = struct{}{}
+	}
+	return sids
+}
+
+type stateMachineTest struct {
+	name         string
+	startEnabled bool
+	startActive  bool
+	fn           func(testHarness)
+}
+
+var stateMachineTests = []stateMachineTest{
+	{
+		name:         "active and enabled is stable",
+		startActive:  true,
+		startEnabled: true,
+		fn: func(h testHarness) {
+			// No updates should be sent because being active and
+			// enabled should be a stable state.
+			h.assertNoUpdates(h.safeDisableTimeout)
+		},
+	},
+	{
+		name:         "inactive and disabled is stable",
+		startActive:  false,
+		startEnabled: false,
+		fn: func(h testHarness) {
+			// No updates should be sent because being inactive and
+			// disabled should be a stable state.
+			h.assertNoUpdates(h.safeDisableTimeout)
+		},
+	},
+	{
+		name:         "start disabled request enable",
+		startActive:  true, // can't request enable unless active
+		startEnabled: false,
+		fn: func(h testHarness) {
+			// Request enables for all channels.
+			h.assertEnables(h.graph.chans(), nil)
+			// Expect to see them all enabled on the network.
+			h.assertUpdates(
+				h.graph.chans(), true, h.safeDisableTimeout,
+			)
+		},
+	},
+	{
+		name:         "start enabled request disable",
+		startActive:  true,
+		startEnabled: true,
+		fn: func(h testHarness) {
+			// Request disables for all channels.
+			h.assertDisables(h.graph.chans(), nil)
+			// Expect to see them all disabled on the network.
+			h.assertUpdates(
+				h.graph.chans(), false, h.safeDisableTimeout,
+			)
+		},
+	},
+	{
+		name:         "request enable already enabled",
+		startActive:  true,
+		startEnabled: true,
+		fn: func(h testHarness) {
+			// Request enables for already enabled channels.
+			h.assertEnables(h.graph.chans(), nil)
+			// Manager shouldn't send out any updates.
+			h.assertNoUpdates(h.safeDisableTimeout)
+		},
+	},
+	{
+		name:         "request disabled already disabled",
+		startActive:  false,
+		startEnabled: false,
+		fn: func(h testHarness) {
+			// Request disables for already disabled channels.
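+			// Each request should succeed as a no-op rather than
+			// return an error.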
+			h.assertDisables(h.graph.chans(), nil)
+			// Manager shouldn't send out any updates.
+			h.assertNoUpdates(h.safeDisableTimeout)
+		},
+	},
+	{
+		name:         "detect and disable inactive",
+		startActive:  true,
+		startEnabled: true,
+		fn: func(h testHarness) {
+			// Simulate disconnection and have links go inactive.
+			h.markInactive(h.graph.chans())
+			// Should see all channels passively disabled.
+			h.assertUpdates(
+				h.graph.chans(), false, h.safeDisableTimeout,
+			)
+		},
+	},
+	{
+		name:         "quick flap stays active",
+		startActive:  true,
+		startEnabled: true,
+		fn: func(h testHarness) {
+			// Simulate disconnection and have links go inactive.
+			h.markInactive(h.graph.chans())
+			// Allow 2 sample intervals to pass, but not long
+			// enough for a disable to occur.
+			time.Sleep(100 * time.Millisecond)
+			// Simulate reconnect by making channels active.
+			h.markActive(h.graph.chans())
+			// Request that all channels be reenabled.
+			h.assertEnables(h.graph.chans(), nil)
+			// Pending disable should have been canceled, and
+			// no updates sent. Channels remain enabled on the
+			// network.
+			h.assertNoUpdates(h.safeDisableTimeout)
+		},
+	},
+	{
+		name:         "no passive enable from becoming active",
+		startActive:  false,
+		startEnabled: false,
+		fn: func(h testHarness) {
+			// Simulate reconnect by making channels active.
+			h.markActive(h.graph.chans())
+			// No updates should be sent without explicit enable.
+			h.assertNoUpdates(h.safeDisableTimeout)
+		},
+	},
+	{
+		name:         "enable inactive channel fails",
+		startActive:  false,
+		startEnabled: false,
+		fn: func(h testHarness) {
+			// Request enable of inactive channels, expect error
+			// indicating that channel was not active.
+			h.assertEnables(
+				h.graph.chans(), netann.ErrEnableInactiveChan,
+			)
+			// No updates should be sent as a result of the failure.
+			h.assertNoUpdates(h.safeDisableTimeout)
+		},
+	},
+	{
+		name:         "enable unknown channel fails",
+		startActive:  false,
+		startEnabled: false,
+		fn: func(h testHarness) {
+			// Create channels unknown to the graph.
+			unknownChans := []*channeldb.OpenChannel{
+				createChannel(h.t),
+				createChannel(h.t),
+				createChannel(h.t),
+			}
+			// Request that they be enabled, which should return an
+			// error as the graph doesn't have an edge for them.
+			h.assertEnables(
+				unknownChans, channeldb.ErrEdgeNotFound,
+			)
+			// No updates should be sent as a result of the failure.
+			h.assertNoUpdates(h.safeDisableTimeout)
+		},
+	},
+	{
+		name:         "disable unknown channel fails",
+		startActive:  false,
+		startEnabled: false,
+		fn: func(h testHarness) {
+			// Create channels unknown to the graph.
+			unknownChans := []*channeldb.OpenChannel{
+				createChannel(h.t),
+				createChannel(h.t),
+				createChannel(h.t),
+			}
+			// Request that they be disabled, which should return an
+			// error as the graph doesn't have an edge for them.
+			h.assertDisables(
+				unknownChans, channeldb.ErrEdgeNotFound,
+			)
+			// No updates should be sent as a result of the failure.
+			h.assertNoUpdates(h.safeDisableTimeout)
+		},
+	},
+	{
+		name:         "add new channels",
+		startActive:  false,
+		startEnabled: false,
+		fn: func(h testHarness) {
+			// Allow the manager to enter a steady state for the
+			// initial channel set.
+			h.assertNoUpdates(h.safeDisableTimeout)
+
+			// Add new channels to the graph, but don't yet add
+			// the edge policies. We should see no updates sent
+			// since the manager can't access the policies.
+			newChans := []*channeldb.OpenChannel{
+				createChannel(h.t),
+				createChannel(h.t),
+				createChannel(h.t),
+			}
+			for _, c := range newChans {
+				h.graph.addChannel(c)
+			}
+			h.assertNoUpdates(h.safeDisableTimeout)
+
+			// Check that trying to enable the channels with
+			// unknown edges results in a failure.
+			h.assertEnables(newChans, channeldb.ErrEdgeNotFound)
+
+			// Now, insert edge policies for the channels into the
+			// graph, starting with the channels enabled, and mark
+			// the links active.
+			for _, c := range newChans {
+				info, pol1, pol2 := createEdgePolicies(
+					h.t, c, h.ourPubKey, true,
+				)
+				h.graph.addEdgePolicy(c, info, pol1, pol2)
+			}
+			h.markActive(newChans)
+
+			// We expect no updates to be sent since the channels
+			// are enabled and active.
+			h.assertNoUpdates(h.safeDisableTimeout)
+
+			// Finally, assert that enabling the channels doesn't
+			// return an error now that everything is in place.
+			h.assertEnables(newChans, nil)
+		},
+	},
+	{
+		name:         "remove channels then disable",
+		startActive:  true,
+		startEnabled: true,
+		fn: func(h testHarness) {
+			// Allow the manager to enter a steady state for the
+			// initial channel set.
+			h.assertNoUpdates(h.safeDisableTimeout)
+
+			// Select half of the current channels to remove.
+			channels := h.graph.chans()
+			rmChans := channels[:len(channels)/2]
+
+			// Mark the channels inactive and remove them from the
+			// graph. This should cause the manager to attempt to
+			// mark the channels disabled, but it will be unable to
+			// do so because it can't find the edge policies.
+			h.markInactive(rmChans)
+			for _, c := range rmChans {
+				h.graph.removeChannel(c)
+			}
+			h.assertNoUpdates(h.safeDisableTimeout)
+
+			// Check that trying to disable the channels with
+			// unknown edges results in a failure.
+			h.assertDisables(rmChans, channeldb.ErrEdgeNotFound)
+		},
+	},
+	{
+		name:         "disable channels then remove",
+		startActive:  true,
+		startEnabled: true,
+		fn: func(h testHarness) {
+			// Allow the manager to enter a steady state for the
+			// initial channel set.
+			h.assertNoUpdates(h.safeDisableTimeout)
+
+			// Select half of the current channels to remove.
+			channels := h.graph.chans()
+			rmChans := channels[:len(channels)/2]
+
+			// Request that the channels be disabled, which should
+			// succeed since their edges are still in the graph.
+			h.assertDisables(rmChans, nil)
+
+			// Since the channels are still in the graph, we expect
+			// these channels to be disabled on the network.
+			h.assertUpdates(rmChans, false, h.safeDisableTimeout)
+
+			// Finally, remove the channels from the graph and
+			// assert no more updates are sent.
+			for _, c := range rmChans {
+				h.graph.removeChannel(c)
+			}
+			h.assertNoUpdates(h.safeDisableTimeout)
+		},
+	},
+}
+
+// TestChanStatusManagerStateMachine tests the possible state transitions that
+// can be taken by the ChanStatusManager.
+func TestChanStatusManagerStateMachine(t *testing.T) {
+	t.Parallel()
+
+	for _, test := range stateMachineTests {
+		tc := test
+		t.Run(test.name, func(t *testing.T) {
+			t.Parallel()
+
+			const numChannels = 10
+			h := newHarness(
+				t, numChannels, tc.startActive, tc.startEnabled,
+			)
+			defer h.mgr.Stop()
+
+			tc.fn(h)
+		})
+	}
+}