routing/chainview: make filter updates synchronous for neutrino

This commit fixes a possible race condition wherein a call to
FilterBlock after a call to UpdateFilter would result in the call to
FilterBlock not yet using the updated filter. We fix this by ensuring
the internal chain filter is updated by the time the call to
FilterBlock returns.
This commit is contained in:
Olaoluwa Osuntokun 2017-06-09 12:12:01 -07:00
parent 0c134a8cb3
commit e4563ca13b
No known key found for this signature in database
GPG Key ID: 9CC5B105D03521A2
2 changed files with 22 additions and 8 deletions

@ -361,6 +361,7 @@ func (b *BtcdFilteredChainView) chainFilterer() {
type filterUpdate struct { type filterUpdate struct {
newUtxos []wire.OutPoint newUtxos []wire.OutPoint
updateHeight uint32 updateHeight uint32
done chan struct{}
} }
// UpdateFilter updates the UTXO filter which is to be consulted when creating // UpdateFilter updates the UTXO filter which is to be consulted when creating

@ -225,6 +225,10 @@ func (c *CfFilteredChainView) chainFilterer() {
log.Errorf("unable to update rescan: %v", err) log.Errorf("unable to update rescan: %v", err)
} }
if update.done != nil {
close(update.done)
}
case <-c.quit: case <-c.quit:
return return
} }
@ -253,11 +257,11 @@ func (c *CfFilteredChainView) FilterBlock(blockHash *chainhash.Hash) (*FilteredB
// If we don't have any items within our current chain filter, then we // If we don't have any items within our current chain filter, then we
// can exit early as we don't need to fetch the filter. // can exit early as we don't need to fetch the filter.
c.filterMtx.RLock() c.filterMtx.RLock()
numPoints := len(c.chainFilter) if len(c.chainFilter) == 0 {
c.filterMtx.RUnlock() c.filterMtx.RUnlock()
if numPoints == 0 {
return filteredBlock, nil return filteredBlock, nil
} }
c.filterMtx.RUnlock()
// Next, using the block, hash, we'll fetch the compact filter for this // Next, using the block, hash, we'll fetch the compact filter for this
// block. We only require the regular filter as we're just looking for // block. We only require the regular filter as we're just looking for
@ -338,17 +342,26 @@ func (c *CfFilteredChainView) FilterBlock(blockHash *chainhash.Hash) (*FilteredB
// //
// NOTE: This is part of the FilteredChainView interface. // NOTE: This is part of the FilteredChainView interface.
func (c *CfFilteredChainView) UpdateFilter(ops []wire.OutPoint, updateHeight uint32) error { func (c *CfFilteredChainView) UpdateFilter(ops []wire.OutPoint, updateHeight uint32) error {
select { doneChan := make(chan struct{})
update := filterUpdate{
case c.filterUpdates <- filterUpdate{
newUtxos: ops, newUtxos: ops,
updateHeight: updateHeight, updateHeight: updateHeight,
}: done: doneChan,
return nil }
select {
case c.filterUpdates <- update:
case <-c.quit: case <-c.quit:
return fmt.Errorf("chain filter shutting down") return fmt.Errorf("chain filter shutting down")
} }
select {
case <-doneChan:
return nil
case <-c.quit:
return fmt.Errorf("chain filter shutting down")
}
} }
// FilteredBlocks returns the channel that filtered blocks are to be sent over. // FilteredBlocks returns the channel that filtered blocks are to be sent over.