diff --git a/watchtower/wtclient/candidate_iterator.go b/watchtower/wtclient/candidate_iterator.go index 5dffef55..695ec1e7 100644 --- a/watchtower/wtclient/candidate_iterator.go +++ b/watchtower/wtclient/candidate_iterator.go @@ -2,6 +2,7 @@ package wtclient import ( "container/list" + "net" "sync" "github.com/lightningnetwork/lnd/watchtower/wtdb" @@ -10,6 +11,20 @@ import ( // TowerCandidateIterator provides an abstraction for iterating through possible // watchtower addresses when attempting to create a new session. type TowerCandidateIterator interface { + // AddCandidate adds a new candidate tower to the iterator. If the + // candidate already exists, then any new addresses are added to it. + AddCandidate(*wtdb.Tower) + + // RemoveCandidate removes an existing candidate tower from the + // iterator. An optional address can be provided to indicate a stale + // tower address to remove it. If it isn't provided, then the tower is + // completely removed from the iterator. + RemoveCandidate(wtdb.TowerID, net.Addr) + + // IsActive determines whether a given tower is exists within the + // iterator. + IsActive(wtdb.TowerID) bool + // Reset clears any internal iterator state, making previously taken // candidates available as long as they remain in the set. Reset() error @@ -18,17 +33,14 @@ type TowerCandidateIterator interface { // to return results in any particular order. If no more candidates are // available, ErrTowerCandidatesExhausted is returned. Next() (*wtdb.Tower, error) - - // TowerIDs returns the set of tower IDs contained in the iterator, - // which can be used to filter candidate sessions for the active tower. - TowerIDs() map[wtdb.TowerID]struct{} } // towerListIterator is a linked-list backed TowerCandidateIterator. type towerListIterator struct { mu sync.Mutex - candidates *list.List + queue *list.List nextCandidate *list.Element + candidates map[wtdb.TowerID]*wtdb.Tower } // Compile-time constraint to ensure *towerListIterator implements the @@ -39,11 +51,13 @@ var _ TowerCandidateIterator = (*towerListIterator)(nil) // of lnwire.NetAddresses. func newTowerListIterator(candidates ...*wtdb.Tower) *towerListIterator { iter := &towerListIterator{ - candidates: list.New(), + queue: list.New(), + candidates: make(map[wtdb.TowerID]*wtdb.Tower), } for _, candidate := range candidates { - iter.candidates.PushBack(candidate) + iter.queue.PushBack(candidate.ID) + iter.candidates[candidate.ID] = candidate } iter.Reset() @@ -57,22 +71,11 @@ func (t *towerListIterator) Reset() error { defer t.mu.Unlock() // Reset the next candidate to the front of the linked-list. - t.nextCandidate = t.candidates.Front() + t.nextCandidate = t.queue.Front() return nil } -// TowerIDs returns the set of tower IDs contained in the iterator, which can be -// used to filter candidate sessions for the active tower. -func (t *towerListIterator) TowerIDs() map[wtdb.TowerID]struct{} { - ids := make(map[wtdb.TowerID]struct{}) - for e := t.candidates.Front(); e != nil; e = e.Next() { - tower := e.Value.(*wtdb.Tower) - ids[tower.ID] = struct{}{} - } - return ids -} - // Next returns the next candidate tower. This iterator will always return // candidates in the order given when the iterator was instantiated. If no more // candidates are available, ErrTowerCandidatesExhausted is returned. @@ -80,18 +83,76 @@ func (t *towerListIterator) Next() (*wtdb.Tower, error) { t.mu.Lock() defer t.mu.Unlock() - // If the next candidate is nil, we've exhausted the list. - if t.nextCandidate == nil { - return nil, ErrTowerCandidatesExhausted + for t.nextCandidate != nil { + // Propose the tower at the front of the list. + towerID := t.nextCandidate.Value.(wtdb.TowerID) + + // Check whether this tower is still considered a candidate. If + // it's not, we'll proceed to the next. + tower, ok := t.candidates[towerID] + if !ok { + nextCandidate := t.nextCandidate.Next() + t.queue.Remove(t.nextCandidate) + t.nextCandidate = nextCandidate + continue + } + + // Set the next candidate to the subsequent element. + t.nextCandidate = t.nextCandidate.Next() + return tower, nil } - // Propose the tower at the front of the list. - tower := t.nextCandidate.Value.(*wtdb.Tower) + return nil, ErrTowerCandidatesExhausted +} - // Set the next candidate to the subsequent element. - t.nextCandidate = t.nextCandidate.Next() +// AddCandidate adds a new candidate tower to the iterator. If the candidate +// already exists, then any new addresses are added to it. +func (t *towerListIterator) AddCandidate(candidate *wtdb.Tower) { + t.mu.Lock() + defer t.mu.Unlock() - return tower, nil + if tower, ok := t.candidates[candidate.ID]; !ok { + t.queue.PushBack(candidate.ID) + t.candidates[candidate.ID] = candidate + + // If we've reached the end of our queue, then this candidate + // will become the next. + if t.nextCandidate == nil { + t.nextCandidate = t.queue.Back() + } + } else { + for _, addr := range candidate.Addresses { + tower.AddAddress(addr) + } + } +} + +// RemoveCandidate removes an existing candidate tower from the iterator. An +// optional address can be provided to indicate a stale tower address to remove +// it. If it isn't provided, then the tower is completely removed from the +// iterator. +func (t *towerListIterator) RemoveCandidate(candidate wtdb.TowerID, addr net.Addr) { + t.mu.Lock() + defer t.mu.Unlock() + + tower, ok := t.candidates[candidate] + if !ok { + return + } + if addr != nil { + tower.RemoveAddress(addr) + } else { + delete(t.candidates, candidate) + } +} + +// IsActive determines whether a given tower is exists within the iterator. +func (t *towerListIterator) IsActive(tower wtdb.TowerID) bool { + t.mu.Lock() + defer t.mu.Unlock() + + _, ok := t.candidates[tower] + return ok } // TODO(conner): implement graph-backed candidate iterator for public towers. diff --git a/watchtower/wtclient/candidate_iterator_test.go b/watchtower/wtclient/candidate_iterator_test.go new file mode 100644 index 00000000..63e3a729 --- /dev/null +++ b/watchtower/wtclient/candidate_iterator_test.go @@ -0,0 +1,157 @@ +package wtclient + +import ( + "encoding/binary" + "math/rand" + "net" + "reflect" + "testing" + "time" + + "github.com/btcsuite/btcd/btcec" + "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/watchtower/wtdb" +) + +func init() { + rand.Seed(time.Now().Unix()) +} + +func randAddr(t *testing.T) net.Addr { + var ip [4]byte + if _, err := rand.Read(ip[:]); err != nil { + t.Fatal(err) + } + var port [2]byte + if _, err := rand.Read(port[:]); err != nil { + t.Fatal(err) + + } + return &net.TCPAddr{ + IP: net.IP(ip[:]), + Port: int(binary.BigEndian.Uint16(port[:])), + } +} + +func randTower(t *testing.T) *wtdb.Tower { + priv, err := btcec.NewPrivateKey(btcec.S256()) + if err != nil { + t.Fatalf("unable to create private key: %v", err) + } + pubKey := priv.PubKey() + pubKey.Curve = nil + return &wtdb.Tower{ + ID: wtdb.TowerID(rand.Uint64()), + IdentityKey: pubKey, + Addresses: []net.Addr{randAddr(t)}, + } +} + +func copyTower(tower *wtdb.Tower) *wtdb.Tower { + t := &wtdb.Tower{ + ID: tower.ID, + IdentityKey: tower.IdentityKey, + Addresses: make([]net.Addr, len(tower.Addresses)), + } + copy(t.Addresses, tower.Addresses) + return t +} + +func assertActiveCandidate(t *testing.T, i TowerCandidateIterator, + c *wtdb.Tower, active bool) { + + isCandidate := i.IsActive(c.ID) + if isCandidate && !active { + t.Fatalf("expected tower %v to no longer be an active candidate", + c.ID) + } + if !isCandidate && active { + t.Fatalf("expected tower %v to be an active candidate", c.ID) + } +} + +func assertNextCandidate(t *testing.T, i TowerCandidateIterator, c *wtdb.Tower) { + t.Helper() + + tower, err := i.Next() + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(tower, c) { + t.Fatalf("expected tower: %v\ngot: %v", spew.Sdump(c), + spew.Sdump(tower)) + } +} + +// TestTowerCandidateIterator asserts the internal state of a +// TowerCandidateIterator after a series of updates to its candidates. +func TestTowerCandidateIterator(t *testing.T) { + t.Parallel() + + // We'll start our test by creating an iterator of four candidate + // towers. We'll use copies of these towers within the iterator to + // ensure the iterator properly updates the state of its candidates. + const numTowers = 4 + towers := make([]*wtdb.Tower, 0, numTowers) + for i := 0; i < numTowers; i++ { + towers = append(towers, randTower(t)) + } + towerCopies := make([]*wtdb.Tower, 0, numTowers) + for _, tower := range towers { + towerCopies = append(towerCopies, copyTower(tower)) + } + towerIterator := newTowerListIterator(towerCopies...) + + // We should expect to see all of our candidates in the order that they + // were added. + for _, expTower := range towers { + tower, err := towerIterator.Next() + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(tower, expTower) { + t.Fatalf("expected tower: %v\ngot: %v", + spew.Sdump(expTower), spew.Sdump(tower)) + } + } + + if _, err := towerIterator.Next(); err != ErrTowerCandidatesExhausted { + t.Fatalf("expected ErrTowerCandidatesExhausted, got %v", err) + } + towerIterator.Reset() + + // We'll then attempt to test the RemoveCandidate behavior of the + // iterator. We'll remove the address of the first tower, which should + // result in it not having any addresses left, but still being an active + // candidate. + firstTower := towers[0] + firstTowerAddr := firstTower.Addresses[0] + firstTower.RemoveAddress(firstTowerAddr) + towerIterator.RemoveCandidate(firstTower.ID, firstTowerAddr) + assertActiveCandidate(t, towerIterator, firstTower, true) + assertNextCandidate(t, towerIterator, firstTower) + + // We'll then remove the second tower completely from the iterator by + // not providing the optional address. Since it's been removed, we + // should expect to see the third tower next. + secondTower, thirdTower := towers[1], towers[2] + towerIterator.RemoveCandidate(secondTower.ID, nil) + assertActiveCandidate(t, towerIterator, secondTower, false) + assertNextCandidate(t, towerIterator, thirdTower) + + // We'll then update the fourth candidate with a new address. A + // duplicate shouldn't be added since it already exists within the + // iterator, but the new address should be. + fourthTower := towers[3] + assertActiveCandidate(t, towerIterator, fourthTower, true) + fourthTower.AddAddress(randAddr(t)) + towerIterator.AddCandidate(fourthTower) + assertNextCandidate(t, towerIterator, fourthTower) + + // Finally, we'll attempt to add a new candidate to the end of the + // iterator. Since it didn't already exist and we've reached the end, it + // should be available as the next candidate. + towerIterator.AddCandidate(secondTower) + assertActiveCandidate(t, towerIterator, secondTower, true) + assertNextCandidate(t, towerIterator, secondTower) +} diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index b8cf03cb..f9aa085d 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -149,9 +149,9 @@ type TowerClient struct { pipeline *taskPipeline negotiator SessionNegotiator + candidateTowers TowerCandidateIterator candidateSessions map[wtdb.SessionID]*wtdb.ClientSession activeSessions sessionQueueSet - targetTowerIDs map[wtdb.TowerID]struct{} sessionQueue *sessionQueue prevTask *backupTask @@ -199,8 +199,7 @@ func New(config *Config) (*TowerClient, error) { log.Infof("Using private watchtower %s, offering policy %s", cfg.PrivateTower, cfg.Policy) - candidates := newTowerListIterator(tower) - targetTowerIDs := candidates.TowerIDs() + candidateTowers := newTowerListIterator(tower) // Next, load all active sessions from the db into the client. We will // use any of these session if their policies match the current policy @@ -243,9 +242,9 @@ func New(config *Config) (*TowerClient, error) { c := &TowerClient{ cfg: cfg, pipeline: newTaskPipeline(), + candidateTowers: candidateTowers, candidateSessions: sessions, activeSessions: make(sessionQueueSet), - targetTowerIDs: targetTowerIDs, summaries: chanSummaries, statTicker: time.NewTicker(DefaultStatInterval), forceQuit: make(chan struct{}), @@ -258,7 +257,7 @@ func New(config *Config) (*TowerClient, error) { SendMessage: c.sendMessage, ReadMessage: c.readMessage, Dial: c.dial, - Candidates: candidates, + Candidates: c.candidateTowers, MinBackoff: cfg.MinBackoff, MaxBackoff: cfg.MaxBackoff, }) @@ -535,7 +534,7 @@ func (c *TowerClient) nextSessionQueue() *sessionQueue { // Skip any sessions that are still active, but are not for the // users currently configured tower. - if _, ok := c.targetTowerIDs[sessionInfo.TowerID]; !ok { + if !c.candidateTowers.IsActive(sessionInfo.TowerID) { continue }