make SPVCon version agnostic

This commit is contained in:
Tadge Dryja 2016-01-18 23:43:41 -08:00
parent 26d9ae7f2b
commit 229e34b326
2 changed files with 29 additions and 36 deletions

@ -18,12 +18,9 @@ const (
headerFileName = "headers.bin"
// Except hash-160s, those aren't backwards. But anything that's 32 bytes is.
// because, cmon, 32? Gotta reverse that. But 20? 20 is OK.
NETVERSION = wire.TestNet3
VERSION = 70011
)
var (
params = &chaincfg.TestNet3Params
// version hardcoded for now, probably ok...?
VERSION = 70011
)
type SPVCon struct {
@ -33,7 +30,6 @@ type SPVCon struct {
localVersion uint32 // version we report
remoteVersion uint32 // version remote node
remoteHeight int32 // block height they're on
netType wire.BitcoinNet
// what's the point of the input queue? remove? leave for now...
inMsgQueue chan wire.Message // Messages coming in from remote node
@ -42,50 +38,59 @@ type SPVCon struct {
WBytes uint64 // total bytes written
RBytes uint64 // total bytes read
TS *TxStore
TS *TxStore
param *chaincfg.Params
}
func (s *SPVCon) Open(remoteNode string, hfn string, inTs *TxStore) error {
func OpenSPV(remoteNode string, hfn string,
inTs *TxStore, p *chaincfg.Params) (SPVCon, error) {
// create new SPVCon
var s SPVCon
// assign network parameters to SPVCon
s.param = p
// open header file
err := s.openHeaderFile(headerFileName)
if err != nil {
return err
return s, err
}
// open TCP connection
s.con, err = net.Dial("tcp", remoteNode)
if err != nil {
return err
return s, err
}
// assign version bits for local node
s.localVersion = VERSION
s.netType = NETVERSION
// transaction store for this SPV connection
s.TS = inTs
myMsgVer, err := wire.NewMsgVersionFromConn(s.con, 0, 0)
if err != nil {
return err
return s, err
}
err = myMsgVer.AddUserAgent("test", "zero")
if err != nil {
return err
return s, err
}
// must set this to enable SPV stuff
myMsgVer.AddService(wire.SFNodeBloom)
// this actually sends
n, err := wire.WriteMessageN(s.con, myMsgVer, s.localVersion, s.netType)
n, err := wire.WriteMessageN(s.con, myMsgVer, s.localVersion, s.param.Net)
if err != nil {
return err
return s, err
}
s.WBytes += uint64(n)
log.Printf("wrote %d byte version message to %s\n",
n, s.con.RemoteAddr().String())
n, m, b, err := wire.ReadMessageN(s.con, s.localVersion, s.netType)
n, m, b, err := wire.ReadMessageN(s.con, s.localVersion, s.param.Net)
if err != nil {
return err
return s, err
}
s.RBytes += uint64(n)
log.Printf("got %d byte response %x\n command: %s\n", n, b, m.Command())
@ -99,9 +104,9 @@ func (s *SPVCon) Open(remoteNode string, hfn string, inTs *TxStore) error {
mv.ProtocolVersion, mv.ProtocolVersion)
mva := wire.NewMsgVerAck()
n, err = wire.WriteMessageN(s.con, mva, s.localVersion, s.netType)
n, err = wire.WriteMessageN(s.con, mva, s.localVersion, s.param.Net)
if err != nil {
return err
return s, err
}
s.WBytes += uint64(n)
@ -110,7 +115,7 @@ func (s *SPVCon) Open(remoteNode string, hfn string, inTs *TxStore) error {
s.outMsgQueue = make(chan wire.Message)
go s.outgoingMessageHandler()
return nil
return s, nil
}
func (s *SPVCon) openHeaderFile(hfn string) error {
@ -118,7 +123,7 @@ func (s *SPVCon) openHeaderFile(hfn string) error {
if err != nil {
if os.IsNotExist(err) {
var b bytes.Buffer
err = params.GenesisBlock.Header.Serialize(&b)
err = s.param.GenesisBlock.Header.Serialize(&b)
if err != nil {
return err
}
@ -252,7 +257,7 @@ func (s *SPVCon) IngestHeaders(m *wire.MsgHeaders) (bool, error) {
// advance chain tip
tip++
// check last header
worked := CheckHeader(s.headerFile, tip, params)
worked := CheckHeader(s.headerFile, tip, s.param)
if !worked {
if endPos < 8080 {
// jeez I give up, back to genesis
@ -299,15 +304,3 @@ func (s *SPVCon) AskForMerkBlocks(current, last uint32) error {
return nil
}
func sendMBReq(cn net.Conn, blkhash wire.ShaHash) error {
iv1 := wire.NewInvVect(wire.InvTypeFilteredBlock, &blkhash)
gdataB := wire.NewMsgGetData()
_ = gdataB.AddInvVect(iv1)
n, err := wire.WriteMessageN(cn, gdataB, VERSION, NETVERSION)
if err != nil {
return err
}
log.Printf("sent %d byte block request\n", n)
return nil
}

@ -9,7 +9,7 @@ import (
func (s *SPVCon) incomingMessageHandler() {
for {
n, xm, _, err := wire.ReadMessageN(s.con, s.localVersion, s.netType)
n, xm, _, err := wire.ReadMessageN(s.con, s.localVersion, s.param.Net)
if err != nil {
log.Printf("ReadMessageN error. Disconnecting: %s\n", err.Error())
return
@ -79,7 +79,7 @@ func (s *SPVCon) incomingMessageHandler() {
func (s *SPVCon) outgoingMessageHandler() {
for {
msg := <-s.outMsgQueue
n, err := wire.WriteMessageN(s.con, msg, s.localVersion, s.netType)
n, err := wire.WriteMessageN(s.con, msg, s.localVersion, s.param.Net)
if err != nil {
log.Printf("Write message error: %s", err.Error())
}