fix dumb loop break error

break was outside the if bytes.Equal {}, oops.  only checked first output.
Works now.  Concurrency also seems OK but need to test more.
This commit is contained in:
Tadge Dryja 2016-01-31 02:08:39 -08:00
parent 3b774ef361
commit cf01e02d64
4 changed files with 27 additions and 148 deletions

@ -230,7 +230,7 @@ func (s *SPVCon) AskForTx(txid wire.ShaHash) {
// AskForBlock requests a merkle block we heard about from an inv message. // AskForBlock requests a merkle block we heard about from an inv message.
// We don't have it in our header file so when we get it we do both operations: // We don't have it in our header file so when we get it we do both operations:
// appending and checking the header, and checking spv proofs // appending and checking the header, and checking spv proofs
func (s *SPVCon) AskForBlock(hsh wire.ShaHash) { func (s *SPVCon) AskForBlockx(hsh wire.ShaHash) {
s.headerMutex.Lock() s.headerMutex.Lock()
defer s.headerMutex.Unlock() defer s.headerMutex.Unlock()
@ -493,7 +493,6 @@ func (s *SPVCon) AskForMerkBlocks(current, last int32) error {
// send filter // send filter
s.SendFilter(filt) s.SendFilter(filt)
fmt.Printf("sent filter %x\n", filt.MsgFilterLoad().Filter) fmt.Printf("sent filter %x\n", filt.MsgFilterLoad().Filter)
s.headerMutex.Lock() s.headerMutex.Lock()
defer s.headerMutex.Unlock() defer s.headerMutex.Unlock()

@ -37,9 +37,9 @@ func (s *SPVCon) incomingMessageHandler() {
log.Printf("Merkle block error: %s\n", err.Error()) log.Printf("Merkle block error: %s\n", err.Error())
continue continue
} }
case *wire.MsgHeaders: case *wire.MsgHeaders: // concurrent because we keep asking for blocks
go s.HeaderHandler(m) go s.HeaderHandler(m)
case *wire.MsgTx: // can't be concurrent! out of order kills case *wire.MsgTx: // not concurrent! txs must be in order
s.TxHandler(m) s.TxHandler(m)
case *wire.MsgReject: case *wire.MsgReject:
log.Printf("Rejected! cmd: %s code: %s tx: %s reason: %s", log.Printf("Rejected! cmd: %s code: %s tx: %s reason: %s",
@ -109,31 +109,34 @@ func (s *SPVCon) HeaderHandler(m *wire.MsgHeaders) {
} }
// if we got post DB syncheight headers, get merkleblocks for them // if we got post DB syncheight headers, get merkleblocks for them
// this is always true except for first pre-birthday sync // this is always true except for first pre-birthday sync
syncTip, err := s.TS.GetDBSyncHeight()
if err != nil { // checked header length, start req for more if needed
log.Printf("Header error: %s", err.Error()) if moar {
return s.AskForHeaders()
} } else { // no moar, done w/ headers, get merkleblocks
fmt.Printf("locks here...?? ")
s.headerMutex.Lock()
endPos, err := s.headerFile.Seek(0, os.SEEK_END) endPos, err := s.headerFile.Seek(0, os.SEEK_END)
if err != nil { if err != nil {
log.Printf("Header error: %s", err.Error()) log.Printf("Header error: %s", err.Error())
return return
} }
s.headerMutex.Unlock()
tip := int32(endPos/80) - 1 // move back 1 header length to read tip := int32(endPos/80) - 1 // move back 1 header length to read
syncTip, err := s.TS.GetDBSyncHeight()
// checked header lenght, start req for more if needed if err != nil {
if moar { log.Printf("syncTip error: %s", err.Error())
s.AskForHeaders() return
} }
if syncTip < tip { if syncTip < tip {
fmt.Printf("syncTip %d headerTip %d\n", syncTip, tip) fmt.Printf("syncTip %d headerTip %d\n", syncTip, tip)
err = s.AskForMerkBlocks(syncTip, tip) err = s.AskForMerkBlocks(syncTip+1, tip)
if err != nil { if err != nil {
log.Printf("AskForMerkBlocks error: %s", err.Error()) log.Printf("AskForMerkBlocks error: %s", err.Error())
return return
} }
} }
}
} }
func (s *SPVCon) TxHandler(m *wire.MsgTx) { func (s *SPVCon) TxHandler(m *wire.MsgTx) {
@ -164,7 +167,7 @@ func (s *SPVCon) InvHandler(m *wire.MsgInv) {
if len(s.mBlockQueue) == 0 { if len(s.mBlockQueue) == 0 {
// don't ask directly; instead ask for header // don't ask directly; instead ask for header
fmt.Printf("asking for headers due to inv block\n") fmt.Printf("asking for headers due to inv block\n")
// s.AskForHeaders() s.AskForHeaders()
} else { } else {
fmt.Printf("inv block but ignoring, not synched\n") fmt.Printf("inv block but ignoring, not synched\n")
} }

@ -1,12 +1,9 @@
package uspv package uspv
import ( import (
"bytes"
"fmt" "fmt"
"log" "log"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg"
"github.com/boltdb/bolt" "github.com/boltdb/bolt"
@ -112,127 +109,6 @@ func (t *TxStore) GimmeFilter() (*bloom.Filter, error) {
return f, nil return f, nil
} }
// Ingest a tx into wallet, dealing with both gains and losses
func (t *TxStore) AckTxz(tx *wire.MsgTx) (uint32, error) {
var ioHits uint32 // number of utxos changed due to this tx
inTxid := tx.TxSha()
height, ok := t.OKTxids[inTxid]
if !ok {
log.Printf("%s", TxToString(tx))
return 0, fmt.Errorf("tx %s not in OKTxids.", inTxid.String())
}
delete(t.OKTxids, inTxid) // don't need anymore
hitsGained, err := t.AbsorbTx(tx, height)
if err != nil {
return 0, err
}
hitsLost, err := t.ExpellTx(tx, height)
if err != nil {
return 0, err
}
ioHits = hitsGained + hitsLost
// fmt.Printf("ingested tx %s total amt %d\n", inTxid.String(), t.Sum)
return ioHits, nil
}
// AbsorbTx Absorbs money into wallet from a tx. returns number of
// new utxos absorbed.
func (t *TxStore) AbsorbTx(tx *wire.MsgTx, height int32) (uint32, error) {
if tx == nil {
return 0, fmt.Errorf("Tried to add nil tx")
}
newTxid := tx.TxSha()
var hits uint32 // how many outputs of this tx are ours
var acq int64 // total acquirement from this tx
// check if any of the tx's outputs match my known outpoints
for i, out := range tx.TxOut { // in each output of tx
dup := false // start by assuming its new until found duplicate
newOp := wire.NewOutPoint(&newTxid, uint32(i))
// first look for dupes -- already known outpoints.
// if we find a dupe here overwrite it to the DB.
for _, u := range t.Utxos {
dup = OutPointsEqual(*newOp, u.Op) // is this outpoint known?
if dup { // found dupe
fmt.Printf(" %s is dupe\t", newOp.String())
hits++ // thought a dupe, still a hit
u.AtHeight = height // ONLY difference is height
// save modified utxo to db, overwriting old one
err := t.SaveUtxo(u)
if err != nil {
return 0, err
}
break // out of the t.Utxo range loop
}
}
if dup {
// if we found the outpoint to be a dup above, don't add it again
// when it matches an address, just go to the next outpoint
continue
}
// check if this is a new txout matching one of my addresses
for _, a := range t.Adrs { // compare to each adr we have
// check for full script to eliminate false positives
aPKscript, err := txscript.PayToAddrScript(a.PkhAdr)
if err != nil {
return 0, err
}
if bytes.Equal(out.PkScript, aPKscript) { // hit
// already checked for dupes, so this must be a new outpoint
var newu Utxo
newu.AtHeight = height
newu.KeyIdx = a.KeyIdx
newu.Value = out.Value
var newop wire.OutPoint
newop.Hash = tx.TxSha()
newop.Index = uint32(i)
newu.Op = newop
err = t.SaveUtxo(&newu)
if err != nil {
return 0, err
}
acq += out.Value
hits++
t.Utxos = append(t.Utxos, &newu) // always add new utxo
break
}
}
}
// log.Printf("%d hits, acquired %d", hits, acq)
t.Sum += acq
return hits, nil
}
// Expell money from wallet due to a tx
func (t *TxStore) ExpellTx(tx *wire.MsgTx, height int32) (uint32, error) {
if tx == nil {
return 0, fmt.Errorf("Tried to add nil tx")
}
var hits uint32
var loss int64
for _, in := range tx.TxIn {
for i, myutxo := range t.Utxos {
if OutPointsEqual(myutxo.Op, in.PreviousOutPoint) {
hits++
loss += myutxo.Value
err := t.MarkSpent(*myutxo, height, tx)
if err != nil {
return 0, err
}
// delete from my in-ram utxo set
t.Utxos = append(t.Utxos[:i], t.Utxos[i+1:]...)
}
}
}
// log.Printf("%d hits, lost %d", hits, loss)
t.Sum -= loss
return hits, nil
}
// need this because before I was comparing pointers maybe? // need this because before I was comparing pointers maybe?
// so they were the same outpoint but stored in 2 places so false negative? // so they were the same outpoint but stored in 2 places so false negative?
func OutPointsEqual(a, b wire.OutPoint) bool { func OutPointsEqual(a, b wire.OutPoint) bool {

@ -246,9 +246,10 @@ func (ts *TxStore) Ingest(tx *wire.MsgTx) (uint32, error) {
nUtxoBytes = append(nUtxoBytes, b) nUtxoBytes = append(nUtxoBytes, b)
ts.Sum += newu.Value ts.Sum += newu.Value
hits++ hits++
}
break // only one match break // only one match
} }
}
} }
err = ts.StateDB.Update(func(btx *bolt.Tx) error { err = ts.StateDB.Update(func(btx *bolt.Tx) error {