reorganize lots of files, add rebroadcast

move methods to new files to keep things a bit organized.
add rebroadcast of unconfirmed txs after sync
mutex on OKtxid map
deal with doublespends next
This commit is contained in:
Tadge Dryja 2016-02-05 01:16:45 -08:00
parent 9eccb0638a
commit 25d90f5345
7 changed files with 390 additions and 305 deletions

@ -302,7 +302,7 @@ func SendCoins(s uspv.SPVCon, adr btcutil.Address, sendAmt int64) error {
// send it out on the wire. hope it gets there. // send it out on the wire. hope it gets there.
// we should deal with rejects. Don't yet. // we should deal with rejects. Don't yet.
err = s.PushTx(tx) err = s.NewOutgoingTx(tx)
if err != nil { if err != nil {
return err return err
} }

@ -1,17 +1,13 @@
package uspv package uspv
import ( import (
"bytes"
"fmt" "fmt"
"io/ioutil"
"log" "log"
"net" "net"
"os" "os"
"sync" "sync"
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcutil/bloom"
) )
const ( const (
@ -53,172 +49,6 @@ type SPVCon struct {
inWaitState chan bool inWaitState chan bool
} }
func OpenSPV(remoteNode string, hfn, dbfn string,
inTs *TxStore, p *chaincfg.Params) (SPVCon, error) {
// create new SPVCon
var s SPVCon
// I should really merge SPVCon and TxStore, they're basically the same
inTs.Param = p
s.TS = inTs // copy pointer of txstore into spvcon
// open header file
err := s.openHeaderFile(hfn)
if err != nil {
return s, err
}
// open TCP connection
s.con, err = net.Dial("tcp", remoteNode)
if err != nil {
return s, err
}
// assign version bits for local node
s.localVersion = VERSION
// transaction store for this SPV connection
err = inTs.OpenDB(dbfn)
if err != nil {
return s, err
}
myMsgVer, err := wire.NewMsgVersionFromConn(s.con, 0, 0)
if err != nil {
return s, err
}
err = myMsgVer.AddUserAgent("test", "zero")
if err != nil {
return s, err
}
// must set this to enable SPV stuff
myMsgVer.AddService(wire.SFNodeBloom)
// this actually sends
n, err := wire.WriteMessageN(s.con, myMsgVer, s.localVersion, s.TS.Param.Net)
if err != nil {
return s, err
}
s.WBytes += uint64(n)
log.Printf("wrote %d byte version message to %s\n",
n, s.con.RemoteAddr().String())
n, m, b, err := wire.ReadMessageN(s.con, s.localVersion, s.TS.Param.Net)
if err != nil {
return s, err
}
s.RBytes += uint64(n)
log.Printf("got %d byte response %x\n command: %s\n", n, b, m.Command())
mv, ok := m.(*wire.MsgVersion)
if ok {
log.Printf("connected to %s", mv.UserAgent)
}
log.Printf("remote reports version %x (dec %d)\n",
mv.ProtocolVersion, mv.ProtocolVersion)
// set remote height
s.remoteHeight = mv.LastBlock
mva := wire.NewMsgVerAck()
n, err = wire.WriteMessageN(s.con, mva, s.localVersion, s.TS.Param.Net)
if err != nil {
return s, err
}
s.WBytes += uint64(n)
s.inMsgQueue = make(chan wire.Message)
go s.incomingMessageHandler()
s.outMsgQueue = make(chan wire.Message)
go s.outgoingMessageHandler()
s.mBlockQueue = make(chan HashAndHeight, 32) // queue depth 32 is a thing
s.fPositives = make(chan int32, 4000) // a block full, approx
s.inWaitState = make(chan bool, 1)
go s.fPositiveHandler()
return s, nil
}
func (s *SPVCon) openHeaderFile(hfn string) error {
_, err := os.Stat(hfn)
if err != nil {
if os.IsNotExist(err) {
var b bytes.Buffer
err = s.TS.Param.GenesisBlock.Header.Serialize(&b)
if err != nil {
return err
}
err = ioutil.WriteFile(hfn, b.Bytes(), 0600)
if err != nil {
return err
}
log.Printf("created hardcoded genesis header at %s\n",
hfn)
}
}
s.headerFile, err = os.OpenFile(hfn, os.O_RDWR, 0600)
if err != nil {
return err
}
log.Printf("opened header file %s\n", s.headerFile.Name())
return nil
}
func (s *SPVCon) PongBack(nonce uint64) {
mpong := wire.NewMsgPong(nonce)
s.outMsgQueue <- mpong
return
}
func (s *SPVCon) SendFilter(f *bloom.Filter) {
s.outMsgQueue <- f.MsgFilterLoad()
return
}
// HeightFromHeader gives you the block height given a 80 byte block header
// seems like looking for the merkle root is the best way to do this
func (s *SPVCon) HeightFromHeader(query wire.BlockHeader) (uint32, error) {
// start from the most recent and work back in time; even though that's
// kind of annoying it's probably a lot faster since things tend to have
// happened recently.
// seek to last header
s.headerMutex.Lock()
defer s.headerMutex.Unlock()
lastPos, err := s.headerFile.Seek(-80, os.SEEK_END)
if err != nil {
return 0, err
}
height := lastPos / 80
var current wire.BlockHeader
for height > 0 {
// grab header from disk
err = current.Deserialize(s.headerFile)
if err != nil {
return 0, err
}
// check if merkle roots match
if current.MerkleRoot.IsEqual(&query.MerkleRoot) {
// if they do, great, return height
return uint32(height), nil
}
// skip back one header (2 because we just read one)
_, err = s.headerFile.Seek(-160, os.SEEK_CUR)
if err != nil {
return 0, err
}
// decrement height
height--
}
// finished for loop without finding match
return 0, fmt.Errorf("Header not found on disk")
}
// AskForTx requests a tx we heard about from an inv message. // AskForTx requests a tx we heard about from an inv message.
// It's one at a time but should be fast enough. // It's one at a time but should be fast enough.
// I don't like this function because SPV shouldn't even ask... // I don't like this function because SPV shouldn't even ask...
@ -229,49 +59,34 @@ func (s *SPVCon) AskForTx(txid wire.ShaHash) {
s.outMsgQueue <- gdata s.outMsgQueue <- gdata
} }
func (s *SPVCon) AskForHeaders() error { // HashAndHeight is needed instead of just height in case a fullnode
var hdr wire.BlockHeader // responds abnormally (?) by sending out of order merkleblocks.
ghdr := wire.NewMsgGetHeaders() // we cache a merkleroot:height pair in the queue so we don't have to
ghdr.ProtocolVersion = s.localVersion // look them up from the disk.
// Also used when inv messages indicate blocks so we can add the header
// and parse the txs in one request instead of requesting headers first.
type HashAndHeight struct {
blockhash wire.ShaHash
height int32
final bool // indicates this is the last merkleblock requested
}
s.headerMutex.Lock() // start header file ops // NewRootAndHeight saves like 2 lines.
info, err := s.headerFile.Stat() func NewRootAndHeight(b wire.ShaHash, h int32) (hah HashAndHeight) {
hah.blockhash = b
hah.height = h
return
}
func (s *SPVCon) RemoveHeaders(r int32) error {
endPos, err := s.headerFile.Seek(0, os.SEEK_END)
if err != nil { if err != nil {
return err return err
} }
headerFileSize := info.Size() err = s.headerFile.Truncate(endPos - int64(r*80))
if headerFileSize == 0 || headerFileSize%80 != 0 { // header file broken
return fmt.Errorf("Header file not a multiple of 80 bytes")
}
// seek to 80 bytes from end of file
ns, err := s.headerFile.Seek(-80, os.SEEK_END)
if err != nil { if err != nil {
log.Printf("can't seek\n") return fmt.Errorf("couldn't truncate header file")
return err
} }
log.Printf("suk to offset %d (should be near the end\n", ns)
// get header from last 80 bytes of file
err = hdr.Deserialize(s.headerFile)
if err != nil {
log.Printf("can't Deserialize")
return err
}
s.headerMutex.Unlock() // done with header file
cHash := hdr.BlockSha()
err = ghdr.AddBlockLocatorHash(&cHash)
if err != nil {
return err
}
fmt.Printf("get headers message has %d header hashes, first one is %s\n",
len(ghdr.BlockLocatorHashes), ghdr.BlockLocatorHashes[0].String())
s.outMsgQueue <- ghdr
return nil return nil
} }
@ -301,7 +116,9 @@ func (s *SPVCon) IngestMerkleBlock(m *wire.MsgMerkleBlock) error {
return fmt.Errorf("Txid store error: %s\n", err.Error()) return fmt.Errorf("Txid store error: %s\n", err.Error())
} }
} }
// write to db that we've sync'd to the height indicated in the
// merkle block. This isn't QUITE true since we haven't actually gotten
// the txs yet but if there are problems with the txs we should backtrack.
err = s.TS.SetDBSyncHeight(hah.height) err = s.TS.SetDBSyncHeight(hah.height)
if err != nil { if err != nil {
return err return err
@ -406,59 +223,49 @@ func (s *SPVCon) IngestHeaders(m *wire.MsgHeaders) (bool, error) {
return true, nil return true, nil
} }
// HashAndHeight is needed instead of just height in case a fullnode func (s *SPVCon) AskForHeaders() error {
// responds abnormally (?) by sending out of order merkleblocks. var hdr wire.BlockHeader
// we cache a merkleroot:height pair in the queue so we don't have to ghdr := wire.NewMsgGetHeaders()
// look them up from the disk. ghdr.ProtocolVersion = s.localVersion
// Also used when inv messages indicate blocks so we can add the header
// and parse the txs in one request instead of requesting headers first.
type HashAndHeight struct {
blockhash wire.ShaHash
height int32
final bool // indicates this is the last merkleblock requested
}
// NewRootAndHeight saves like 2 lines. s.headerMutex.Lock() // start header file ops
func NewRootAndHeight(b wire.ShaHash, h int32) (hah HashAndHeight) { info, err := s.headerFile.Stat()
hah.blockhash = b
hah.height = h
return
}
func (s *SPVCon) PushTx(tx *wire.MsgTx) error {
txid := tx.TxSha()
err := s.TS.AddTxid(&txid, 0)
if err != nil { if err != nil {
return err return err
} }
_, err = s.TS.Ingest(tx) // our own tx so don't need to track relevance headerFileSize := info.Size()
if headerFileSize == 0 || headerFileSize%80 != 0 { // header file broken
return fmt.Errorf("Header file not a multiple of 80 bytes")
}
// seek to 80 bytes from end of file
ns, err := s.headerFile.Seek(-80, os.SEEK_END)
if err != nil {
log.Printf("can't seek\n")
return err
}
log.Printf("suk to offset %d (should be near the end\n", ns)
// get header from last 80 bytes of file
err = hdr.Deserialize(s.headerFile)
if err != nil {
log.Printf("can't Deserialize")
return err
}
s.headerMutex.Unlock() // done with header file
cHash := hdr.BlockSha()
err = ghdr.AddBlockLocatorHash(&cHash)
if err != nil { if err != nil {
return err return err
} }
s.outMsgQueue <- tx
return nil
}
func (s *SPVCon) GetNextHeaderHeight() (int32, error) { fmt.Printf("get headers message has %d header hashes, first one is %s\n",
s.headerMutex.Lock() len(ghdr.BlockLocatorHashes), ghdr.BlockLocatorHashes[0].String())
defer s.headerMutex.Unlock()
info, err := s.headerFile.Stat() // get s.outMsgQueue <- ghdr
if err != nil {
return 0, err // crash if header file disappears
}
nextHeight := int32(info.Size() / 80)
return nextHeight, nil
}
func (s *SPVCon) RemoveHeaders(r int32) error {
endPos, err := s.headerFile.Seek(0, os.SEEK_END)
if err != nil {
return err
}
err = s.headerFile.Truncate(endPos - int64(r*80))
if err != nil {
return fmt.Errorf("couldn't truncate header file")
}
return nil return nil
} }
@ -487,6 +294,8 @@ func (s *SPVCon) AskForMerkBlocks() error {
// nothing to ask for; set wait state and return // nothing to ask for; set wait state and return
fmt.Printf("no merkle blocks to request, entering wait state\n") fmt.Printf("no merkle blocks to request, entering wait state\n")
s.inWaitState <- true s.inWaitState <- true
// also advertise any unconfirmed txs here
s.Rebroadcast()
return nil return nil
} }
@ -535,6 +344,5 @@ func (s *SPVCon) AskForMerkBlocks() error {
s.mBlockQueue <- hah // push height and mroot of requested block on queue s.mBlockQueue <- hah // push height and mroot of requested block on queue
dbTip++ dbTip++
} }
return nil return nil
} }

117
uspv/init.go Normal file

@ -0,0 +1,117 @@
package uspv
import (
"bytes"
"io/ioutil"
"log"
"net"
"os"
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/wire"
)
// OpenPV starts a
func OpenSPV(remoteNode string, hfn, dbfn string,
inTs *TxStore, p *chaincfg.Params) (SPVCon, error) {
// create new SPVCon
var s SPVCon
// I should really merge SPVCon and TxStore, they're basically the same
inTs.Param = p
s.TS = inTs // copy pointer of txstore into spvcon
// open header file
err := s.openHeaderFile(hfn)
if err != nil {
return s, err
}
// open TCP connection
s.con, err = net.Dial("tcp", remoteNode)
if err != nil {
return s, err
}
// assign version bits for local node
s.localVersion = VERSION
// transaction store for this SPV connection
err = inTs.OpenDB(dbfn)
if err != nil {
return s, err
}
myMsgVer, err := wire.NewMsgVersionFromConn(s.con, 0, 0)
if err != nil {
return s, err
}
err = myMsgVer.AddUserAgent("test", "zero")
if err != nil {
return s, err
}
// must set this to enable SPV stuff
myMsgVer.AddService(wire.SFNodeBloom)
// this actually sends
n, err := wire.WriteMessageN(s.con, myMsgVer, s.localVersion, s.TS.Param.Net)
if err != nil {
return s, err
}
s.WBytes += uint64(n)
log.Printf("wrote %d byte version message to %s\n",
n, s.con.RemoteAddr().String())
n, m, b, err := wire.ReadMessageN(s.con, s.localVersion, s.TS.Param.Net)
if err != nil {
return s, err
}
s.RBytes += uint64(n)
log.Printf("got %d byte response %x\n command: %s\n", n, b, m.Command())
mv, ok := m.(*wire.MsgVersion)
if ok {
log.Printf("connected to %s", mv.UserAgent)
}
log.Printf("remote reports version %x (dec %d)\n",
mv.ProtocolVersion, mv.ProtocolVersion)
// set remote height
s.remoteHeight = mv.LastBlock
mva := wire.NewMsgVerAck()
n, err = wire.WriteMessageN(s.con, mva, s.localVersion, s.TS.Param.Net)
if err != nil {
return s, err
}
s.WBytes += uint64(n)
s.inMsgQueue = make(chan wire.Message)
go s.incomingMessageHandler()
s.outMsgQueue = make(chan wire.Message)
go s.outgoingMessageHandler()
s.mBlockQueue = make(chan HashAndHeight, 32) // queue depth 32 is a thing
s.fPositives = make(chan int32, 4000) // a block full, approx
s.inWaitState = make(chan bool, 1)
go s.fPositiveHandler()
return s, nil
}
func (s *SPVCon) openHeaderFile(hfn string) error {
_, err := os.Stat(hfn)
if err != nil {
if os.IsNotExist(err) {
var b bytes.Buffer
err = s.TS.Param.GenesisBlock.Header.Serialize(&b)
if err != nil {
return err
}
err = ioutil.WriteFile(hfn, b.Bytes(), 0600)
if err != nil {
return err
}
log.Printf("created hardcoded genesis header at %s\n",
hfn)
}
}
s.headerFile, err = os.OpenFile(hfn, os.O_RDWR, 0600)
if err != nil {
return err
}
log.Printf("opened header file %s\n", s.headerFile.Name())
return nil
}

@ -50,7 +50,8 @@ func (s *SPVCon) incomingMessageHandler() {
for i, thing := range m.InvList { for i, thing := range m.InvList {
log.Printf("\t$d) %s: %s", i, thing.Type, thing.Hash) log.Printf("\t$d) %s: %s", i, thing.Type, thing.Hash)
} }
case *wire.MsgGetData:
s.GetDataHandler(m)
default: default:
log.Printf("Got unknown message type %s\n", m.Command()) log.Printf("Got unknown message type %s\n", m.Command())
} }
@ -114,7 +115,6 @@ func (s *SPVCon) HeaderHandler(m *wire.MsgHeaders) {
} }
return return
} }
// no moar, done w/ headers, get merkleblocks // no moar, done w/ headers, get merkleblocks
err = s.AskForMerkBlocks() err = s.AskForMerkBlocks()
if err != nil { if err != nil {
@ -123,19 +123,52 @@ func (s *SPVCon) HeaderHandler(m *wire.MsgHeaders) {
} }
} }
// TxHandler takes in transaction messages that come in from either a request
// after an inv message or after a merkle block message.
func (s *SPVCon) TxHandler(m *wire.MsgTx) { func (s *SPVCon) TxHandler(m *wire.MsgTx) {
hits, err := s.TS.Ingest(m) s.TS.OKMutex.Lock()
height, ok := s.TS.OKTxids[m.TxSha()]
s.TS.OKMutex.Unlock()
if !ok {
log.Printf("Tx %s unknown, will not ingest\n")
return
}
hits, err := s.TS.Ingest(m, height)
if err != nil { if err != nil {
log.Printf("Incoming Tx error: %s\n", err.Error()) log.Printf("Incoming Tx error: %s\n", err.Error())
return
} }
if hits == 0 { if hits == 0 {
log.Printf("tx %s had no hits, filter false positive.", log.Printf("tx %s had no hits, filter false positive.",
m.TxSha().String()) m.TxSha().String())
s.fPositives <- 1 // add one false positive to chan s.fPositives <- 1 // add one false positive to chan
} else { return
log.Printf("tx %s ingested and matches %d utxo/adrs.",
m.TxSha().String(), hits)
} }
log.Printf("tx %s ingested and matches %d utxo/adrs.",
m.TxSha().String(), hits)
}
// GetDataHandler responds to requests for tx data, which happen after
// advertising our txs via an inv message
func (s *SPVCon) GetDataHandler(m *wire.MsgGetData) {
log.Printf("got GetData. Contains:\n")
var sent int32
for i, thing := range m.InvList {
log.Printf("\t%d)%s : %s",
i, thing.Type.String(), thing.Hash.String())
if thing.Type != wire.InvTypeTx { // refuse non-tx reqs
log.Printf("We only respond to tx requests, ignoring")
continue
}
tx, err := s.TS.GetTx(&thing.Hash)
if err != nil {
log.Printf("error getting tx %s: %s",
thing.Hash.String(), err.Error())
}
s.outMsgQueue <- tx
sent++
}
log.Printf("sent %d of %d requested items", sent, len(m.InvList))
} }
func (s *SPVCon) InvHandler(m *wire.MsgInv) { func (s *SPVCon) InvHandler(m *wire.MsgInv) {
@ -143,8 +176,8 @@ func (s *SPVCon) InvHandler(m *wire.MsgInv) {
for i, thing := range m.InvList { for i, thing := range m.InvList {
log.Printf("\t%d)%s : %s", log.Printf("\t%d)%s : %s",
i, thing.Type.String(), thing.Hash.String()) i, thing.Type.String(), thing.Hash.String())
if thing.Type == wire.InvTypeTx { // new tx, ingest if thing.Type == wire.InvTypeTx { // new tx, OK it at 0 and request
s.TS.OKTxids[thing.Hash] = 0 // unconfirmed s.TS.AddTxid(&thing.Hash, 0) // unconfirmed
s.AskForTx(thing.Hash) s.AskForTx(thing.Hash)
} }
if thing.Type == wire.InvTypeBlock { // new block what to do? if thing.Type == wire.InvTypeBlock { // new block what to do?

@ -3,13 +3,66 @@ package uspv
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"log"
"github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcutil/bloom"
"github.com/btcsuite/btcutil/hdkeychain" "github.com/btcsuite/btcutil/hdkeychain"
"github.com/btcsuite/btcutil/txsort" "github.com/btcsuite/btcutil/txsort"
) )
func (s *SPVCon) PongBack(nonce uint64) {
mpong := wire.NewMsgPong(nonce)
s.outMsgQueue <- mpong
return
}
func (s *SPVCon) SendFilter(f *bloom.Filter) {
s.outMsgQueue <- f.MsgFilterLoad()
return
}
// Rebroadcast sends an inv message of all the unconfirmed txs the db is
// aware of. This is called after every sync. Only txids so hopefully not
// too annoying for nodes.
func (s *SPVCon) Rebroadcast() {
// get all unconfirmed txs
invMsg, err := s.TS.GetPendingInv()
if err != nil {
log.Printf("Rebroadcast error: %s", err.Error())
}
if len(invMsg.InvList) == 0 { // nothing to broadcast, so don't
return
}
s.outMsgQueue <- invMsg
return
}
func (s *SPVCon) NewOutgoingTx(tx *wire.MsgTx) error {
txid := tx.TxSha()
// assign height of zero for txs we create
err := s.TS.AddTxid(&txid, 0)
if err != nil {
return err
}
_, err = s.TS.Ingest(tx, 0) // our own tx; don't keep track of false positives
if err != nil {
return err
}
// make an inv message instead of a tx message to be polite
iv1 := wire.NewInvVect(wire.InvTypeTx, &txid)
invMsg := wire.NewMsgInv()
err = invMsg.AddInvVect(iv1)
if err != nil {
return err
}
s.outMsgQueue <- invMsg
return nil
}
func (t *TxStore) SignThis(tx *wire.MsgTx) error { func (t *TxStore) SignThis(tx *wire.MsgTx) error {
fmt.Printf("-= SignThis =-\n") fmt.Printf("-= SignThis =-\n")

@ -5,6 +5,7 @@ import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"log" "log"
"sync"
"github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg"
@ -17,6 +18,7 @@ import (
type TxStore struct { type TxStore struct {
OKTxids map[wire.ShaHash]int32 // known good txids and their heights OKTxids map[wire.ShaHash]int32 // known good txids and their heights
OKMutex sync.Mutex
Adrs []MyAdr // endeavouring to acquire capital Adrs []MyAdr // endeavouring to acquire capital
StateDB *bolt.DB // place to write all this down StateDB *bolt.DB // place to write all this down
@ -67,7 +69,9 @@ func (t *TxStore) AddTxid(txid *wire.ShaHash, height int32) error {
return fmt.Errorf("tried to add nil txid") return fmt.Errorf("tried to add nil txid")
} }
log.Printf("added %s to OKTxids at height %d\n", txid.String(), height) log.Printf("added %s to OKTxids at height %d\n", txid.String(), height)
t.OKMutex.Lock()
t.OKTxids[*txid] = height t.OKTxids[*txid] = height
t.OKMutex.Unlock()
return nil return nil
} }
@ -76,17 +80,6 @@ func (t *TxStore) GimmeFilter() (*bloom.Filter, error) {
if len(t.Adrs) == 0 { if len(t.Adrs) == 0 {
return nil, fmt.Errorf("no addresses to filter for") return nil, fmt.Errorf("no addresses to filter for")
} }
// add addresses to look for incoming
nutxo, err := t.NumUtxos()
if err != nil {
return nil, err
}
elem := uint32(len(t.Adrs)) + nutxo
f := bloom.NewFilter(elem, 0, 0.000001, wire.BloomUpdateAll)
for _, a := range t.Adrs {
f.Add(a.PkhAdr.ScriptAddress())
}
// get all utxos to add outpoints to filter // get all utxos to add outpoints to filter
allUtxos, err := t.GetAllUtxos() allUtxos, err := t.GetAllUtxos()
@ -94,6 +87,12 @@ func (t *TxStore) GimmeFilter() (*bloom.Filter, error) {
return nil, err return nil, err
} }
elem := uint32(len(t.Adrs) + len(allUtxos))
f := bloom.NewFilter(elem, 0, 0.000001, wire.BloomUpdateAll)
for _, a := range t.Adrs {
f.Add(a.PkhAdr.ScriptAddress())
}
for _, u := range allUtxos { for _, u := range allUtxos {
f.AddOutPoint(&u.Op) f.AddOutPoint(&u.Op)
} }

@ -121,7 +121,8 @@ func (ts *TxStore) NewAdr() (*btcutil.AddressPubKeyHash, error) {
return newAdr, nil return newAdr, nil
} }
// SetBDay sets the birthday (birth height) of the db (really keyfile) // SetDBSyncHeight sets sync height of the db, indicated the latest block
// of which it has ingested all the transactions.
func (ts *TxStore) SetDBSyncHeight(n int32) error { func (ts *TxStore) SetDBSyncHeight(n int32) error {
var buf bytes.Buffer var buf bytes.Buffer
_ = binary.Write(&buf, binary.BigEndian, n) _ = binary.Write(&buf, binary.BigEndian, n)
@ -159,24 +160,7 @@ func (ts *TxStore) GetDBSyncHeight() (int32, error) {
return n, nil return n, nil
} }
// NumUtxos returns the number of utxos in the DB. // GetAllUtxos returns a slice of all utxos known to the db. empty slice is OK.
func (ts *TxStore) NumUtxos() (uint32, error) {
var n uint32
err := ts.StateDB.View(func(btx *bolt.Tx) error {
duf := btx.Bucket(BKTUtxos)
if duf == nil {
return fmt.Errorf("no duffel bag")
}
stats := duf.Stats()
n = uint32(stats.KeyN)
return nil
})
if err != nil {
return 0, err
}
return n, nil
}
func (ts *TxStore) GetAllUtxos() ([]*Utxo, error) { func (ts *TxStore) GetAllUtxos() ([]*Utxo, error) {
var utxos []*Utxo var utxos []*Utxo
err := ts.StateDB.View(func(btx *bolt.Tx) error { err := ts.StateDB.View(func(btx *bolt.Tx) error {
@ -198,10 +182,8 @@ func (ts *TxStore) GetAllUtxos() ([]*Utxo, error) {
} }
// and add it to ram // and add it to ram
utxos = append(utxos, &newU) utxos = append(utxos, &newU)
return nil return nil
}) })
return nil return nil
}) })
if err != nil { if err != nil {
@ -210,6 +192,106 @@ func (ts *TxStore) GetAllUtxos() ([]*Utxo, error) {
return utxos, nil return utxos, nil
} }
// GetAllStxos returns a slice of all stxos known to the db. empty slice is OK.
func (ts *TxStore) GetAllStxos() ([]*Stxo, error) {
// this is almost the same as GetAllUtxos but whatever, it'd be more
// complicated to make one contain the other or something
var stxos []*Stxo
err := ts.StateDB.View(func(btx *bolt.Tx) error {
old := btx.Bucket(BKTStxos)
if old == nil {
return fmt.Errorf("no old txos")
}
return old.ForEach(func(k, v []byte) error {
// have to copy k and v here, otherwise append will crash it.
// not quite sure why but append does weird stuff I guess.
// create a new stxo
x := make([]byte, len(k)+len(v))
copy(x, k)
copy(x[len(k):], v)
newS, err := StxoFromBytes(x)
if err != nil {
return err
}
// and add it to ram
stxos = append(stxos, &newS)
return nil
})
return nil
})
if err != nil {
return nil, err
}
return stxos, nil
}
// GetTx takes a txid and returns the transaction. If we have it.
func (ts *TxStore) GetTx(txid *wire.ShaHash) (*wire.MsgTx, error) {
rtx := wire.NewMsgTx()
err := ts.StateDB.View(func(btx *bolt.Tx) error {
txns := btx.Bucket(BKTTxns)
if txns == nil {
return fmt.Errorf("no transactions in db")
}
txbytes := txns.Get(txid.Bytes())
if txbytes == nil {
return fmt.Errorf("tx %x not in db", txid.String())
}
buf := bytes.NewBuffer(txbytes)
return rtx.Deserialize(buf)
})
if err != nil {
return nil, err
}
return rtx, nil
}
// GetPendingInv returns an inv message containing all txs known to the
// db which are at height 0 (not known to be confirmed).
// This can be useful on startup or to rebroadcast unconfirmed txs.
func (ts *TxStore) GetPendingInv() (*wire.MsgInv, error) {
// use a map (really a set) do avoid dupes
txidMap := make(map[wire.ShaHash]struct{})
utxos, err := ts.GetAllUtxos() // get utxos from db
if err != nil {
return nil, err
}
stxos, err := ts.GetAllStxos() // get stxos from db
if err != nil {
return nil, err
}
// iterate through utxos, adding txids of anything with height 0
for _, utxo := range utxos {
if utxo.AtHeight == 0 {
txidMap[utxo.Op.Hash] = struct{}{} // adds to map
}
}
// do the same with stxos based on height at which spent
for _, stxo := range stxos {
if stxo.SpendHeight == 0 {
txidMap[stxo.SpendTxid] = struct{}{}
}
}
invMsg := wire.NewMsgInv()
for txid := range txidMap {
item := wire.NewInvVect(wire.InvTypeTx, &txid)
err = invMsg.AddInvVect(item)
if err != nil {
if err != nil {
return nil, err
}
}
}
// return inv message with all txids (maybe none)
return invMsg, nil
}
// PopulateAdrs just puts a bunch of adrs in ram; it doesn't touch the DB // PopulateAdrs just puts a bunch of adrs in ram; it doesn't touch the DB
func (ts *TxStore) PopulateAdrs(lastKey uint32) error { func (ts *TxStore) PopulateAdrs(lastKey uint32) error {
for k := uint32(0); k < lastKey; k++ { for k := uint32(0); k < lastKey; k++ {
@ -227,27 +309,18 @@ func (ts *TxStore) PopulateAdrs(lastKey uint32) error {
ma.PkhAdr = newAdr ma.PkhAdr = newAdr
ma.KeyIdx = k ma.KeyIdx = k
ts.Adrs = append(ts.Adrs, ma) ts.Adrs = append(ts.Adrs, ma)
} }
return nil return nil
} }
// Ingest puts a tx into the DB atomically. This can result in a // Ingest puts a tx into the DB atomically. This can result in a
// gain, a loss, or no result. Gain or loss in satoshis is returned. // gain, a loss, or no result. Gain or loss in satoshis is returned.
func (ts *TxStore) Ingest(tx *wire.MsgTx) (uint32, error) { func (ts *TxStore) Ingest(tx *wire.MsgTx, height int32) (uint32, error) {
var hits uint32 var hits uint32
var err error var err error
var spentOPs [][]byte var spentOPs [][]byte
var nUtxoBytes [][]byte var nUtxoBytes [][]byte
// first check that we have a height and tx has been SPV OK'd
inTxid := tx.TxSha()
height, ok := ts.OKTxids[inTxid]
if !ok {
return hits, fmt.Errorf("Ingest error: tx %s not in OKTxids.",
inTxid.String())
}
// tx has been OK'd by SPV; check tx sanity // tx has been OK'd by SPV; check tx sanity
utilTx := btcutil.NewTx(tx) // convert for validation utilTx := btcutil.NewTx(tx) // convert for validation
// checks stuff like inputs >= ouputs // checks stuff like inputs >= ouputs
@ -292,7 +365,6 @@ func (ts *TxStore) Ingest(tx *wire.MsgTx) (uint32, error) {
hits++ hits++
break // only one match break // only one match
} }
} }
} }
@ -302,6 +374,9 @@ func (ts *TxStore) Ingest(tx *wire.MsgTx) (uint32, error) {
// sta := btx.Bucket(BKTState) // sta := btx.Bucket(BKTState)
old := btx.Bucket(BKTStxos) old := btx.Bucket(BKTStxos)
txns := btx.Bucket(BKTTxns) txns := btx.Bucket(BKTTxns)
if duf == nil || old == nil || txns == nil {
return fmt.Errorf("error: db not initialized")
}
// first see if we lose utxos // first see if we lose utxos
// iterate through duffel bag and look for matches // iterate through duffel bag and look for matches