peer: ensure access to activeChannels and htlcManagers is thread safe

This commit is contained in:
Olaoluwa Osuntokun 2016-11-17 18:43:33 -08:00
parent 81e65e00e5
commit d98cac432b
No known key found for this signature in database
GPG Key ID: 9CC5B105D03521A2

40
peer.go

@ -105,9 +105,11 @@ type peer struct {
// activeChannels is a map which stores the state machines of all // activeChannels is a map which stores the state machines of all
// active channels. Channels are indexed into the map by the txid of // active channels. Channels are indexed into the map by the txid of
// the funding transaction which opened the channel. // the funding transaction which opened the channel.
activeChanMtx sync.RWMutex
activeChannels map[wire.OutPoint]*lnwallet.LightningChannel activeChannels map[wire.OutPoint]*lnwallet.LightningChannel
chanSnapshotReqs chan *chanSnapshotReq chanSnapshotReqs chan *chanSnapshotReq
htlcManMtx sync.RWMutex
htlcManagers map[wire.OutPoint]chan lnwire.Message htlcManagers map[wire.OutPoint]chan lnwire.Message
// newChanBarriers is a map from a channel point to a 'barrier' which // newChanBarriers is a map from a channel point to a 'barrier' which
@ -225,7 +227,10 @@ func (p *peer) loadActiveChannels(chans []*channeldb.OpenChannel) error {
Hash: chanID.Hash, Hash: chanID.Hash,
Index: chanID.Index, Index: chanID.Index,
} }
p.activeChanMtx.Lock()
p.activeChannels[chanPoint] = lnChan p.activeChannels[chanPoint] = lnChan
p.activeChanMtx.Unlock()
peerLog.Infof("peerID(%v) loaded ChannelPoint(%v)", p.id, chanPoint) peerLog.Infof("peerID(%v) loaded ChannelPoint(%v)", p.id, chanPoint)
// Notify the routing table of this newly loaded channel. // Notify the routing table of this newly loaded channel.
@ -249,7 +254,10 @@ func (p *peer) loadActiveChannels(chans []*channeldb.OpenChannel) error {
dbChan.Snapshot(), downstreamLink) dbChan.Snapshot(), downstreamLink)
upstreamLink := make(chan lnwire.Message, 10) upstreamLink := make(chan lnwire.Message, 10)
p.htlcManMtx.Lock()
p.htlcManagers[chanPoint] = upstreamLink p.htlcManagers[chanPoint] = upstreamLink
p.htlcManMtx.Unlock()
p.wg.Add(1) p.wg.Add(1)
go p.htlcManager(lnChan, plexChan, downstreamLink, upstreamLink) go p.htlcManager(lnChan, plexChan, downstreamLink, upstreamLink)
} }
@ -440,7 +448,9 @@ out:
// Dispatch the commitment update message to the proper // Dispatch the commitment update message to the proper
// active goroutine dedicated to this channel. // active goroutine dedicated to this channel.
p.htlcManMtx.Lock()
targetChan, ok := p.htlcManagers[*targetChan] targetChan, ok := p.htlcManagers[*targetChan]
p.htlcManMtx.Unlock()
if !ok { if !ok {
peerLog.Errorf("recv'd update for unknown channel %v", peerLog.Errorf("recv'd update for unknown channel %v",
targetChan) targetChan)
@ -623,11 +633,13 @@ out:
for { for {
select { select {
case req := <-p.chanSnapshotReqs: case req := <-p.chanSnapshotReqs:
p.activeChanMtx.RLock()
snapshots := make([]*channeldb.ChannelSnapshot, 0, len(p.activeChannels)) snapshots := make([]*channeldb.ChannelSnapshot, 0, len(p.activeChannels))
for _, activeChan := range p.activeChannels { for _, activeChan := range p.activeChannels {
snapshot := activeChan.StateSnapshot() snapshot := activeChan.StateSnapshot()
snapshots = append(snapshots, snapshot) snapshots = append(snapshots, snapshot)
} }
p.activeChanMtx.RUnlock()
req.resp <- snapshots req.resp <- snapshots
case pendingChanPoint := <-p.barrierInits: case pendingChanPoint := <-p.barrierInits:
@ -644,7 +656,10 @@ out:
case newChan := <-p.newChannels: case newChan := <-p.newChannels:
chanPoint := *newChan.ChannelPoint() chanPoint := *newChan.ChannelPoint()
p.activeChanMtx.Lock()
p.activeChannels[chanPoint] = newChan p.activeChannels[chanPoint] = newChan
p.activeChanMtx.Unlock()
peerLog.Infof("New channel active ChannelPoint(%v) "+ peerLog.Infof("New channel active ChannelPoint(%v) "+
"with peerId(%v)", chanPoint, p.id) "with peerId(%v)", chanPoint, p.id)
@ -660,7 +675,10 @@ out:
// a goroutine to handle commitment updates for this // a goroutine to handle commitment updates for this
// new channel. // new channel.
upstreamLink := make(chan lnwire.Message, 10) upstreamLink := make(chan lnwire.Message, 10)
p.htlcManMtx.Lock()
p.htlcManagers[chanPoint] = upstreamLink p.htlcManagers[chanPoint] = upstreamLink
p.htlcManMtx.Unlock()
p.wg.Add(1) p.wg.Add(1)
go p.htlcManager(newChan, plexChan, downstreamLink, upstreamLink) go p.htlcManager(newChan, plexChan, downstreamLink, upstreamLink)
@ -761,7 +779,9 @@ func (p *peer) handleLocalClose(req *closeLinkReq) {
closingTxid *wire.ShaHash closingTxid *wire.ShaHash
) )
p.activeChanMtx.RLock()
channel := p.activeChannels[*req.chanPoint] channel := p.activeChannels[*req.chanPoint]
p.activeChanMtx.RUnlock()
if req.forceClose { if req.forceClose {
closingTxid, err = p.executeForceClose(channel) closingTxid, err = p.executeForceClose(channel)
@ -778,7 +798,8 @@ func (p *peer) handleLocalClose(req *closeLinkReq) {
return return
} }
// Update the caller w.r.t the current pending state of this request. // Update the caller with a new event detailing the current pending
// state of this request.
req.updates <- &lnrpc.CloseStatusUpdate{ req.updates <- &lnrpc.CloseStatusUpdate{
Update: &lnrpc.CloseStatusUpdate_ClosePending{ Update: &lnrpc.CloseStatusUpdate_ClosePending{
ClosePending: &lnrpc.PendingUpdate{ ClosePending: &lnrpc.PendingUpdate{
@ -842,7 +863,10 @@ func (p *peer) handleRemoteClose(req *lnwire.CloseRequest) {
Hash: chanPoint.Hash, Hash: chanPoint.Hash,
Index: chanPoint.Index, Index: chanPoint.Index,
} }
p.activeChanMtx.RLock()
channel := p.activeChannels[key] channel := p.activeChannels[key]
p.activeChanMtx.RUnlock()
// Now that we have their signature for the closure transaction, we // Now that we have their signature for the closure transaction, we
// can assemble the final closure transaction, complete with our // can assemble the final closure transaction, complete with our
@ -883,17 +907,26 @@ func (p *peer) handleRemoteClose(req *lnwire.CloseRequest) {
func wipeChannel(p *peer, channel *lnwallet.LightningChannel) error { func wipeChannel(p *peer, channel *lnwallet.LightningChannel) error {
chanID := channel.ChannelPoint() chanID := channel.ChannelPoint()
p.activeChanMtx.Lock()
delete(p.activeChannels, *chanID) delete(p.activeChannels, *chanID)
p.activeChanMtx.Unlock()
// Instruct the Htlc Switch to close this link as the channel is no // Instruct the Htlc Switch to close this link as the channel is no
// longer active. // longer active.
p.server.htlcSwitch.UnregisterLink(p.addr.IdentityKey, chanID) p.server.htlcSwitch.UnregisterLink(p.addr.IdentityKey, chanID)
p.htlcManMtx.RLock()
htlcWireLink, ok := p.htlcManagers[*chanID] htlcWireLink, ok := p.htlcManagers[*chanID]
if !ok { if !ok {
p.htlcManMtx.RUnlock()
return nil return nil
} }
p.htlcManMtx.RUnlock()
p.htlcManMtx.RLock()
delete(p.htlcManagers, *chanID) delete(p.htlcManagers, *chanID)
p.htlcManMtx.RUnlock()
close(htlcWireLink) close(htlcWireLink)
if err := channel.DeleteState(); err != nil { if err := channel.DeleteState(); err != nil {
@ -1013,13 +1046,14 @@ out:
for { for {
select { select {
case <-channel.UnilateralCloseSignal: case <-channel.UnilateralCloseSignal:
// TODO(roasbeef): eliminate false positive via local close
peerLog.Warnf("Remote peer has closed ChannelPoint(%v) on-chain", peerLog.Warnf("Remote peer has closed ChannelPoint(%v) on-chain",
state.chanPoint) state.chanPoint)
if err := wipeChannel(p, channel); err != nil { if err := wipeChannel(p, channel); err != nil {
peerLog.Errorf("unable to wipe channel %v", err) peerLog.Errorf("unable to wipe channel %v", err)
} }
// TODO(roasbeef): send info about current HTLC's to
// utxoNursery
break out break out
case <-channel.ForceCloseSignal: case <-channel.ForceCloseSignal:
peerLog.Warnf("ChannelPoint(%v) has been force "+ peerLog.Warnf("ChannelPoint(%v) has been force "+
@ -1327,7 +1361,7 @@ func (p *peer) handleUpstreamMsg(state *commitmentState, msg lnwire.Message) {
for _, htlc := range htlcsToForward { for _, htlc := range htlcsToForward {
// We don't need to forward any HTLC's that we // We don't need to forward any HTLC's that we
// just settled above. // just settled above.
// TODO(roasbeef): key by index insteaad? // TODO(roasbeef): key by index instead?
if _, ok := settledPayments[htlc.RHash]; ok { if _, ok := settledPayments[htlc.RHash]; ok {
continue continue
} }