not saying it works, but it works better

redo how db and header sync works, somewhat simpler but a little
recursve-ish.  There is still an off by 1 error somewhere with headers.
This commit is contained in:
Tadge Dryja 2016-01-28 19:35:49 -08:00
parent d9afd623eb
commit 6ef9dc3d4a
4 changed files with 178 additions and 84 deletions

@ -45,6 +45,8 @@ type SPVCon struct {
// mBlockQueue is for keeping track of what height we've requested.
mBlockQueue chan HashAndHeight
// fPositives is a channel to keep track of bloom filter false positives.
fPositives chan int32
}
func OpenSPV(remoteNode string, hfn, tsfn string,
@ -123,11 +125,13 @@ func OpenSPV(remoteNode string, hfn, tsfn string,
}
s.WBytes += uint64(n)
s.inMsgQueue = make(chan wire.Message, 1)
s.inMsgQueue = make(chan wire.Message)
go s.incomingMessageHandler()
s.outMsgQueue = make(chan wire.Message, 1)
s.outMsgQueue = make(chan wire.Message)
go s.outgoingMessageHandler()
s.mBlockQueue = make(chan HashAndHeight, 32) // queue depth 32 is a thing
s.fPositives = make(chan int32, 4000) // a block full, approx
go s.fPositiveHandler()
return s, nil
}
@ -223,18 +227,10 @@ func (s *SPVCon) AskForTx(txid wire.ShaHash) {
// appending and checking the header, and checking spv proofs
func (s *SPVCon) AskForBlock(hsh wire.ShaHash) {
fmt.Printf("mBlockQueue len %d\n", len(s.mBlockQueue))
// wait until all mblocks are done before adding
for len(s.mBlockQueue) != 0 {
// fmt.Printf("mBlockQueue len %d\n", len(s.mBlockQueue))
}
gdata := wire.NewMsgGetData()
inv := wire.NewInvVect(wire.InvTypeFilteredBlock, &hsh)
gdata.AddInvVect(inv)
// TODO - wait until headers are sync'd before checking height
info, err := s.headerFile.Stat() // get
if err != nil {
log.Fatal(err) // crash if header file disappears
@ -242,10 +238,9 @@ func (s *SPVCon) AskForBlock(hsh wire.ShaHash) {
nextHeight := int32(info.Size() / 80)
hah := NewRootAndHeight(hsh, nextHeight)
fmt.Printf("AskForBlock - %s height %d\n", hsh.String(), nextHeight)
s.mBlockQueue <- hah // push height and mroot of requested block on queue
s.outMsgQueue <- gdata // push request to outbox
}
func (s *SPVCon) AskForHeaders() error {
@ -304,19 +299,25 @@ func (s *SPVCon) IngestMerkleBlock(m *wire.MsgMerkleBlock) error {
if !hah.blockhash.IsEqual(&newMerkBlockSha) {
return fmt.Errorf("merkle block out of order error")
}
for _, txid := range txids {
err := s.TS.AddTxid(txid, hah.height)
if err != nil {
return fmt.Errorf("Txid store error: %s\n", err.Error())
}
}
err = s.TS.SetDBSyncHeight(hah.height)
if err != nil {
return err
}
return nil
}
// IngestHeaders takes in a bunch of headers and appends them to the
// local header file, checking that they fit. If there's no headers,
// it assumes we're done and returns false. If it worked it assumes there's
// more to request and returns true.9
// more to request and returns true.
func (s *SPVCon) IngestHeaders(m *wire.MsgHeaders) (bool, error) {
var err error
// seek to last header
@ -361,7 +362,7 @@ func (s *SPVCon) IngestHeaders(m *wire.MsgHeaders) (bool, error) {
return false, fmt.Errorf("couldn't truncate header file")
}
}
return false, fmt.Errorf("Truncated header file to try again")
return true, fmt.Errorf("Truncated header file to try again")
}
for _, resphdr := range m.Headers {
@ -393,7 +394,19 @@ func (s *SPVCon) IngestHeaders(m *wire.MsgHeaders) (bool, error) {
}
}
log.Printf("Headers to height %d OK.", tip)
// if we got post DB syncheight headers, get merkleblocks for them
// this is always true except for first pre-birthday sync
syncTip, err := s.TS.GetDBSyncHeight()
if err != nil {
return false, err
}
if syncTip < tip {
err = s.AskForMerkBlocks(syncTip, tip)
if err != nil {
return false, err
}
}
return true, nil
}
@ -421,7 +434,7 @@ func (s *SPVCon) PushTx(tx *wire.MsgTx) error {
if err != nil {
return err
}
err = s.TS.AckTx(tx)
_, err = s.TS.AckTx(tx) // our own tx so don't need to track relevance
if err != nil {
return err
}
@ -429,27 +442,32 @@ func (s *SPVCon) PushTx(tx *wire.MsgTx) error {
return nil
}
func (s *SPVCon) GetNextHeaderHeight() (int32, error) {
info, err := s.headerFile.Stat() // get
if err != nil {
return 0, err // crash if header file disappears
}
nextHeight := int32(info.Size() / 80)
return nextHeight, nil
}
// AskForMerkBlocks requests blocks from current to last
// right now this asks for 1 block per getData message.
// Maybe it's faster to ask for many in a each message?
func (s *SPVCon) AskForMerkBlocks(current, last int32) error {
var hdr wire.BlockHeader
info, err := s.headerFile.Stat() // get
nextHeight, err := s.GetNextHeaderHeight()
if err != nil {
return err // crash if header file disappears
return err
}
nextHeight := int32(info.Size() / 80)
fmt.Printf("have headers up to height %d\n", nextHeight-1)
// if last is 0, that means go as far as we can
if last == 0 {
last = nextHeight - 1
}
fmt.Printf("will request merkleblocks %d to %d\n", current, last)
// track number of utxos
track, err := s.TS.NumUtxos()
if err != nil {
return err
}
// create initial filter
filt, err := s.TS.GimmeFilter()
@ -467,21 +485,6 @@ func (s *SPVCon) AskForMerkBlocks(current, last int32) error {
// loop through all heights where we want merkleblocks.
for current < last {
// check if we need to update filter... diff of 5 utxos...?
nTrack, err := s.TS.NumUtxos()
if err != nil {
return err
}
if track < nTrack-4 || track > nTrack+4 {
track = nTrack
filt, err := s.TS.GimmeFilter()
if err != nil {
return err
}
s.SendFilter(filt)
fmt.Printf("sent %d byte filter\n", len(filt.MsgFilterLoad().Filter))
}
// load header from file
err = hdr.Deserialize(s.headerFile)
@ -503,5 +506,8 @@ func (s *SPVCon) AskForMerkBlocks(current, last int32) error {
s.mBlockQueue <- hah // push height and mroot of requested block on queue
current++
}
// done syncing blocks known in header file, ask for new headers we missed
s.AskForHeaders()
return nil
}

@ -1,6 +1,7 @@
package uspv
import (
"fmt"
"log"
"github.com/btcsuite/btcd/wire"
@ -25,7 +26,7 @@ func (s *SPVCon) incomingMessageHandler() {
case *wire.MsgAddr:
log.Printf("got %d addresses.\n", len(m.AddrList))
case *wire.MsgPing:
log.Printf("Got a ping message. We should pong back or they will kick us off.")
// log.Printf("Got a ping message. We should pong back or they will kick us off.")
s.PongBack(m.Nonce)
case *wire.MsgPong:
log.Printf("Got a pong response. OK.\n")
@ -45,9 +46,15 @@ func (s *SPVCon) incomingMessageHandler() {
s.AskForHeaders()
}
case *wire.MsgTx:
err := s.TS.AckTx(m)
hits, err := s.TS.AckTx(m)
if err != nil {
log.Printf("Incoming Tx error: %s\n", err.Error())
continue
}
if hits == 0 {
log.Printf("tx %s had no hits, filter false positive.",
m.TxSha().String())
s.fPositives <- 1 // add one false positive to chan
}
// log.Printf("Got tx %s\n", m.TxSha().String())
case *wire.MsgReject:
@ -83,6 +90,34 @@ func (s *SPVCon) outgoingMessageHandler() {
return
}
// fPositiveHandler monitors false positives and when it gets enough of them,
//
func (s *SPVCon) fPositiveHandler() {
var fpAccumulator int32
for {
fpAccumulator += <-s.fPositives // blocks here
if fpAccumulator > 7 {
filt, err := s.TS.GimmeFilter()
if err != nil {
log.Printf("Filter creation error: %s\n", err.Error())
log.Printf("uhoh, crashing filter handler")
return
}
// send filter
s.SendFilter(filt)
fmt.Printf("sent filter %x\n", filt.MsgFilterLoad().Filter)
// clear the channel
for len(s.fPositives) != 0 {
fpAccumulator += <-s.fPositives
}
fmt.Printf("reset %d false positives\n", fpAccumulator)
// reset accumulator
fpAccumulator = 0
}
}
}
func (s *SPVCon) InvHandler(m *wire.MsgInv) {
log.Printf("got inv. Contains:\n")
for i, thing := range m.InvList {
@ -93,7 +128,10 @@ func (s *SPVCon) InvHandler(m *wire.MsgInv) {
s.AskForTx(thing.Hash)
}
if thing.Type == wire.InvTypeBlock { // new block, ingest
s.AskForBlock(thing.Hash)
if len(s.mBlockQueue) == 0 {
// don't ask directly; instead ask for header
s.AskForHeaders()
}
}
}
}

@ -77,7 +77,7 @@ func (t *TxStore) AddTxid(txid *wire.ShaHash, height int32) error {
if txid == nil {
return fmt.Errorf("tried to add nil txid")
}
log.Printf("added %s at height %d\n", txid.String(), height)
log.Printf("added %s to OKTxids at height %d\n", txid.String(), height)
t.OKTxids[*txid] = height
return nil
}
@ -94,7 +94,7 @@ func (t *TxStore) GimmeFilter() (*bloom.Filter, error) {
}
elem := uint32(len(t.Adrs)) + nutxo
f := bloom.NewFilter(elem, 0, 0.001, wire.BloomUpdateAll)
f := bloom.NewFilter(elem, 0, 0.000001, wire.BloomUpdateAll)
for _, a := range t.Adrs {
f.Add(a.PkhAdr.ScriptAddress())
}
@ -107,30 +107,35 @@ func (t *TxStore) GimmeFilter() (*bloom.Filter, error) {
}
// Ingest a tx into wallet, dealing with both gains and losses
func (t *TxStore) AckTx(tx *wire.MsgTx) error {
func (t *TxStore) AckTx(tx *wire.MsgTx) (uint32, error) {
var ioHits uint32 // number of utxos changed due to this tx
inTxid := tx.TxSha()
height, ok := t.OKTxids[inTxid]
if !ok {
log.Printf("False postive tx? %s", TxToString(tx))
return fmt.Errorf("we don't care about tx %s", inTxid.String())
log.Printf("%s", TxToString(tx))
return 0, fmt.Errorf("tx %s not in OKTxids.", inTxid.String())
}
delete(t.OKTxids, inTxid) // don't need anymore
hitsGained, err := t.AbsorbTx(tx, height)
if err != nil {
return 0, err
}
hitsLost, err := t.ExpellTx(tx, height)
if err != nil {
return 0, err
}
ioHits = hitsGained + hitsLost
err := t.AbsorbTx(tx, height)
if err != nil {
return err
}
err = t.ExpellTx(tx, height)
if err != nil {
return err
}
// fmt.Printf("ingested tx %s total amt %d\n", inTxid.String(), t.Sum)
return nil
return ioHits, nil
}
// Absorb money into wallet from a tx
func (t *TxStore) AbsorbTx(tx *wire.MsgTx, height int32) error {
// AbsorbTx Absorbs money into wallet from a tx. returns number of
// new utxos absorbed.
func (t *TxStore) AbsorbTx(tx *wire.MsgTx, height int32) (uint32, error) {
if tx == nil {
return fmt.Errorf("Tried to add nil tx")
return 0, fmt.Errorf("Tried to add nil tx")
}
newTxid := tx.TxSha()
var hits uint32 // how many outputs of this tx are ours
@ -145,11 +150,12 @@ func (t *TxStore) AbsorbTx(tx *wire.MsgTx, height int32) error {
dup = OutPointsEqual(*newOp, u.Op) // is this outpoint known?
if dup { // found dupe
fmt.Printf(" %s is dupe\t", newOp.String())
hits++ // thought a dupe, still a hit
u.AtHeight = height // ONLY difference is height
// save modified utxo to db, overwriting old one
err := t.SaveUtxo(u)
if err != nil {
return err
return 0, err
}
break // out of the t.Utxo range loop
}
@ -164,7 +170,7 @@ func (t *TxStore) AbsorbTx(tx *wire.MsgTx, height int32) error {
// check for full script to eliminate false positives
aPKscript, err := txscript.PayToAddrScript(a.PkhAdr)
if err != nil {
return err
return 0, err
}
if bytes.Equal(out.PkScript, aPKscript) { // hit
// already checked for dupes, so this must be a new outpoint
@ -179,7 +185,7 @@ func (t *TxStore) AbsorbTx(tx *wire.MsgTx, height int32) error {
newu.Op = newop
err = t.SaveUtxo(&newu)
if err != nil {
return err
return 0, err
}
acq += out.Value
@ -189,15 +195,15 @@ func (t *TxStore) AbsorbTx(tx *wire.MsgTx, height int32) error {
}
}
}
log.Printf("%d hits, acquired %d", hits, acq)
// log.Printf("%d hits, acquired %d", hits, acq)
t.Sum += acq
return nil
return hits, nil
}
// Expell money from wallet due to a tx
func (t *TxStore) ExpellTx(tx *wire.MsgTx, height int32) error {
func (t *TxStore) ExpellTx(tx *wire.MsgTx, height int32) (uint32, error) {
if tx == nil {
return fmt.Errorf("Tried to add nil tx")
return 0, fmt.Errorf("Tried to add nil tx")
}
var hits uint32
var loss int64
@ -209,16 +215,16 @@ func (t *TxStore) ExpellTx(tx *wire.MsgTx, height int32) error {
loss += myutxo.Value
err := t.MarkSpent(*myutxo, height, tx)
if err != nil {
return err
return 0, err
}
// delete from my in-ram utxo set
t.Utxos = append(t.Utxos[:i], t.Utxos[i+1:]...)
}
}
}
log.Printf("%d hits, lost %d", hits, loss)
// log.Printf("%d hits, lost %d", hits, loss)
t.Sum -= loss
return nil
return hits, nil
}
// need this because before I was comparing pointers maybe?

@ -17,8 +17,9 @@ var (
BKTStxos = []byte("SpentTxs") // for bookkeeping
BKTTxns = []byte("Txns") // all txs we care about, for replays
BKTState = []byte("MiscState") // last state of DB
KEYNumKeys = []byte("NumKeys") // number of keys used
// these are in the state bucket
KEYNumKeys = []byte("NumKeys") // number of keys used
KEYTipHeight = []byte("TipHeight") // height synced to
)
func (ts *TxStore) OpenDB(filename string) error {
@ -75,8 +76,8 @@ func (ts *TxStore) NewAdr() (*btcutil.AddressPubKeyHash, error) {
// write to db file
err = ts.StateDB.Update(func(btx *bolt.Tx) error {
stt := btx.Bucket(BKTState)
return stt.Put(KEYNumKeys, buf.Bytes())
sta := btx.Bucket(BKTState)
return sta.Put(KEYNumKeys, buf.Bytes())
})
if err != nil {
return nil, err
@ -86,6 +87,44 @@ func (ts *TxStore) NewAdr() (*btcutil.AddressPubKeyHash, error) {
return newAdr, nil
}
// SetBDay sets the birthday (birth height) of the db (really keyfile)
func (ts *TxStore) SetDBSyncHeight(n int32) error {
var buf bytes.Buffer
_ = binary.Write(&buf, binary.BigEndian, n)
return ts.StateDB.Update(func(btx *bolt.Tx) error {
sta := btx.Bucket(BKTState)
return sta.Put(KEYTipHeight, buf.Bytes())
})
}
// SyncHeight returns the chain height to which the db has synced
func (ts *TxStore) GetDBSyncHeight() (int32, error) {
var n int32
err := ts.StateDB.View(func(btx *bolt.Tx) error {
sta := btx.Bucket(BKTState)
if sta == nil {
return fmt.Errorf("no state")
}
t := sta.Get(KEYTipHeight)
if t == nil { // no height written, so 0
return nil
}
// read 4 byte tip height to n
err := binary.Read(bytes.NewBuffer(t), binary.BigEndian, &n)
if err != nil {
return err
}
return nil
})
if err != nil {
return 0, err
}
return n, nil
}
// NumUtxos returns the number of utxos in the DB.
func (ts *TxStore) NumUtxos() (uint32, error) {
var n uint32
@ -125,17 +164,22 @@ func (ts *TxStore) PopulateAdrs(lastKey uint32) error {
// SaveToDB write a utxo to disk, overwriting an old utxo of the same outpoint
func (ts *TxStore) SaveUtxo(u *Utxo) error {
err := ts.StateDB.Update(func(btx *bolt.Tx) error {
b, err := u.ToBytes()
if err != nil {
return err
}
err = ts.StateDB.Update(func(btx *bolt.Tx) error {
duf := btx.Bucket(BKTUtxos)
b, err := u.ToBytes()
if err != nil {
return err
sta := btx.Bucket(BKTState)
// kindof hack, height is 36:40
// also not really tip height...
if u.AtHeight > 0 { // if confirmed
err = sta.Put(KEYTipHeight, b[36:40])
if err != nil {
return err
}
}
// don't check for dupes here, check in AbsorbTx(). here overwrite.
// if duf.Get(b[:36]) != nil { // already have tx
// dupe = true
// return nil
// }
// key : val is txid:everything else
return duf.Put(b[:36], b[36:])
@ -207,12 +251,12 @@ func (ts *TxStore) LoadFromDB() error {
if spent == nil {
return fmt.Errorf("no spenttx bucket")
}
state := btx.Bucket(BKTState)
if state == nil {
sta := btx.Bucket(BKTState)
if sta == nil {
return fmt.Errorf("no state bucket")
}
// first populate addresses from state bucket
numKeysBytes := state.Get(KEYNumKeys)
numKeysBytes := sta.Get(KEYNumKeys)
if numKeysBytes != nil { // NumKeys exists, read into uint32
buf := bytes.NewBuffer(numKeysBytes)
var numKeys uint32