move elkrem, uspv libs to plasma repo

Tadge Dryja 2016-01-14 19:56:25 -08:00 committed by Olaoluwa Osuntokun
parent b0ce9a06dc
commit 65c7d1c40c
7 changed files with 1026 additions and 0 deletions

157
elkrem/elkrem.go Normal file

@@ -0,0 +1,157 @@
package elkrem
import (
"fmt"
"github.com/btcsuite/btcd/wire"
)
/* elkrem is a simpler alternative to the 64 dimensional sha-chain.
it's basically a reverse merkle tree. If we want to provide 2**64 possible
hashes, this requires a worst case computation of 63 hashes for the
sender, and worst-case storage of 64 hashes for the receiver.
The operations are left hash L() and right hash R(), which are
hash(parent) and hash(parent, 1) respectively. (concatenate one byte)
Here is a shorter example of a tree with 8 leaves and 15 total nodes.
The sender first computes the bottom left leaf 0b0000. This is
L(L(L(root))). The receiver stores leaf 0.
Next the sender computes 0b0001. R(L(L(root))). Receiver stores.
Next sender computes 0b1000 (8). L(L(root)). Receiver stores this, and
discards leaves 0b0000 and 0b0001, as node 8 is their parent.
Providing (2**h)-1 total hashes requires a tree of height h.
Sender:
as state, must store 1 hash (root) and current index (h bits)
to move to the next index, compute at most h hashes.
Receiver:
as state, must store at most h+1 hashes and the index of each hash (h*(h+1) bits total)
to compute a previous index, compute at most h hashes.
*/
// You can calculate h from i but I can't figure out how without taking
// O(i) ops. Feels like there should be a clever O(h) way. 1 byte, whatever.
type ElkremNode struct {
i uint64 // index (ith node)
h uint8 // height of this node
sha *wire.ShaHash // hash
}
type ElkremSender struct {
current uint64 // index of the next hash to send
treeHeight uint8 // height of tree (2**height leaves, 2**(height+1) - 1 nodes)
maxIndex uint64 // index of the root node (top of the tree)
root *wire.ShaHash // root hash of the tree
}
type ElkremReceiver struct {
current uint64 // number of hashes received; also the next expected index
treeHeight uint8 // height of tree (2**height leaves, 2**(height+1) - 1 nodes)
s []ElkremNode // store of received hashes, max size = height
}
func LeftSha(in wire.ShaHash) wire.ShaHash {
return wire.DoubleSha256SH(in.Bytes()) // left is sha(sha(in))
}
func RightSha(in wire.ShaHash) wire.ShaHash {
return wire.DoubleSha256SH(append(in.Bytes(), 0x01)) // sha(sha(in, 1))
}
// iterative descent of sub-tree. w = hash number you want. i = input index
// h = height of input index. sha = input hash
func descend(w, i uint64, h uint8, sha wire.ShaHash) (wire.ShaHash, error) {
for w < i {
if w <= i-(1<<h) { // left
sha = LeftSha(sha)
i = i - (1 << h) // left descent reduces index by 2**h
} else { // right
sha = RightSha(sha)
i-- // right descent reduces index by 1
}
if h == 0 { // avoid underflowing h
break
}
h-- // either descent reduces height by 1
}
if w != i { // somehow couldn't / didn't end up where we wanted to go
return sha, fmt.Errorf("can't generate index %d from %d", w, i)
}
return sha, nil
}
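// Quick sanity check of the numbering descend uses (the seed string is an
// assumption for the example, not a value used anywhere else): with
// treeHeight 3 the root sits at index 2**4 - 2 = 14, and index 1 should come
// out as R(L(L(root))).
func exampleDescend() error {
	root := wire.DoubleSha256SH([]byte("assumed descend seed"))
	byDescent, err := descend(1, 14, 3, root)
	if err != nil {
		return err
	}
	byHand := RightSha(LeftSha(LeftSha(root)))
	fmt.Printf("index 1 matches hand derivation: %v\n", byDescent.IsEqual(&byHand))
	return nil
}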
// Creates an Elkrem Sender from a root hash and tree height
func NewElkremSender(th uint8, r wire.ShaHash) ElkremSender {
var e ElkremSender
e.root = &r
e.treeHeight = th
// set max index based on tree height
for j := uint8(0); j <= e.treeHeight; j++ {
e.maxIndex = e.maxIndex<<1 | 1
}
e.maxIndex-- // root index is 2**(treeHeight+1) - 2
return e
}
// Creates an Elkrem Receiver from a tree height
func NewElkremReceiver(th uint8) ElkremReceiver {
var e ElkremReceiver
e.treeHeight = th
return e
}
// Next() returns the hash at the current index and then advances the index
func (e *ElkremSender) Next() (*wire.ShaHash, error) {
sha, err := e.AtIndex(e.current)
if err != nil {
return nil, err
}
e.current++ // only advance once the hash has been generated
return sha, nil
}
// w is the wanted index, i is the root index
func (e *ElkremSender) AtIndex(w uint64) (*wire.ShaHash, error) {
out, err := descend(w, e.maxIndex, e.treeHeight, *e.root)
return &out, err
}
func (e *ElkremReceiver) AddNext(sha *wire.ShaHash) error {
// note: careful about atomicity / disk writes here
var n ElkremNode
n.sha = sha
t := len(e.s) - 1 // top of stack
if t > 0 && e.s[t-1].h == e.s[t].h { // top 2 elements are equal height
// next node must be parent; verify and remove children
n.h = e.s[t].h + 1 // assign height
l := LeftSha(*sha) // calc l child
r := RightSha(*sha) // calc r child
if !e.s[t-1].sha.IsEqual(&l) { // test l child
return fmt.Errorf("left child doesn't match, expect %s got %s",
e.s[t-1].sha.String(), l.String())
}
if !e.s[t].sha.IsEqual(&r) { // test r child
return fmt.Errorf("right child doesn't match, expect %s got %s",
e.s[t].sha.String(), r.String())
}
e.s = e.s[:len(e.s)-2] // l and r children OK, remove them
} // if that didn't happen, height defaults to 0
n.i = e.current // this hash gets the next expected index
e.current++ // advance the expected index
e.s = append(e.s, n) // append new node to stack
return nil
}
func (e *ElkremReceiver) AtIndex(w uint64) (*wire.ShaHash, error) {
var out ElkremNode // node we will eventually return
for _, n := range e.s { // go through stack
if w <= n.i { // found one bigger than or equal to what we want
out = n
break
}
}
if out.sha == nil { // didn't find anything
if len(e.s) == 0 {
return nil, fmt.Errorf("nothing stored, can't generate index %d", w)
}
return nil, fmt.Errorf("receiver has max %d, less than requested %d",
e.s[len(e.s)-1].i, w)
}
sha, err := descend(w, out.i, out.h, *out.sha)
return &sha, err
}
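// Illustrative round trip (the seed is an assumed test value): a height 3
// sender hands over leaf 0, leaf 1, and their parent, after which the
// receiver can re-derive leaf 1 on its own from the stored parent.
func elkremExample() error {
	root := wire.DoubleSha256SH([]byte("assumed elkrem seed"))
	sender := NewElkremSender(3, root)
	receiver := NewElkremReceiver(3)
	for j := 0; j < 3; j++ { // send hashes 0, 1, and 2
		sha, err := sender.Next()
		if err != nil {
			return err
		}
		if err := receiver.AddNext(sha); err != nil {
			return err
		}
	}
	got, err := receiver.AtIndex(1) // derived from the stored parent (index 2)
	if err != nil {
		return err
	}
	want, err := sender.AtIndex(1)
	if err != nil {
		return err
	}
	fmt.Printf("leaf 1 re-derived correctly: %v\n", got.IsEqual(want))
	return nil
}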

263
uspv/eight333.go Normal file

@@ -0,0 +1,263 @@
package uspv
import (
"bytes"
"fmt"
"io/ioutil"
"log"
"net"
"os"
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcutil/bloom"
)
const (
keyFileName = "testseed.hex"
headerFileName = "headers.bin"
// Except hash-160s, those aren't backwards. But anything that's 32 bytes is.
// because, cmon, 32? Gotta reverse that. But 20? 20 is OK.
NETVERSION = wire.TestNet3
VERSION = 70011
)
var (
params = &chaincfg.TestNet3Params
)
type SPVCon struct {
con net.Conn // the (probably tcp) connection to the node
headerFile *os.File // file for SPV headers
localVersion uint32 // version we report
remoteVersion uint32 // protocol version reported by the remote node
remoteHeight int32 // block height they're on
netType wire.BitcoinNet
// what's the point of the input queue? remove? leave for now...
inMsgQueue chan wire.Message // Messages coming in from remote node
outMsgQueue chan wire.Message // Messages going out to remote node
WBytes uint64 // total bytes written
RBytes uint64 // total bytes read
}
func (s *SPVCon) Open(remoteNode string, hfn string) error {
// open header file
err := s.openHeaderFile(hfn)
if err != nil {
return err
}
// open TCP connection
s.con, err = net.Dial("tcp", remoteNode)
if err != nil {
return err
}
s.localVersion = VERSION
s.netType = NETVERSION
myMsgVer, err := wire.NewMsgVersionFromConn(s.con, 0, 0)
if err != nil {
return err
}
err = myMsgVer.AddUserAgent("test", "zero")
if err != nil {
return err
}
// must set this to enable SPV stuff
myMsgVer.AddService(wire.SFNodeBloom)
// this actually sends
n, err := wire.WriteMessageN(s.con, myMsgVer, s.localVersion, s.netType)
if err != nil {
return err
}
s.WBytes += uint64(n)
log.Printf("wrote %d byte version message to %s\n",
n, s.con.RemoteAddr().String())
n, m, b, err := wire.ReadMessageN(s.con, s.localVersion, s.netType)
if err != nil {
return err
}
s.RBytes += uint64(n)
log.Printf("got %d byte response %x\n command: %s\n", n, b, m.Command())
mv, ok := m.(*wire.MsgVersion)
if ok {
log.Printf("connected to %s", mv.UserAgent)
log.Printf("remote reports version %x (dec %d)\n",
mv.ProtocolVersion, mv.ProtocolVersion)
}
mva := wire.NewMsgVerAck()
n, err = wire.WriteMessageN(s.con, mva, s.localVersion, s.netType)
if err != nil {
return err
}
s.WBytes += uint64(n)
s.inMsgQueue = make(chan wire.Message)
go s.incomingMessageHandler()
s.outMsgQueue = make(chan wire.Message)
go s.outgoingMessageHandler()
return nil
}
func (s *SPVCon) openHeaderFile(hfn string) error {
_, err := os.Stat(hfn)
if err != nil {
if os.IsNotExist(err) {
var b bytes.Buffer
err = params.GenesisBlock.Header.Serialize(&b)
if err != nil {
return err
}
err = ioutil.WriteFile(hfn, b.Bytes(), 0600)
if err != nil {
return err
}
log.Printf("created hardcoded genesis header at %s\n",
hfn)
}
}
s.headerFile, err = os.OpenFile(hfn, os.O_RDWR, 0600)
if err != nil {
return err
}
log.Printf("opened header file %s\n", s.headerFile.Name())
return nil
}
func (s *SPVCon) PongBack(nonce uint64) {
mpong := wire.NewMsgPong(nonce)
s.outMsgQueue <- mpong
return
}
func (s *SPVCon) SendFilter(f *bloom.Filter) {
s.outMsgQueue <- f.MsgFilterLoad()
return
}
func (s *SPVCon) GrabHeaders() error {
var hdr wire.BlockHeader
ghdr := wire.NewMsgGetHeaders()
ghdr.ProtocolVersion = s.localVersion
info, err := s.headerFile.Stat()
if err != nil {
return err
}
headerFileSize := info.Size()
if headerFileSize == 0 || headerFileSize%80 != 0 { // header file broken
return fmt.Errorf("Header file not a multiple of 80 bytes")
}
// seek to 80 bytes from end of file
ns, err := s.headerFile.Seek(-80, os.SEEK_END)
if err != nil {
log.Printf("can't seek\n")
return err
}
log.Printf("suk to offset %d (should be near the end\n", ns)
// get header from last 80 bytes of file
err = hdr.Deserialize(s.headerFile)
if err != nil {
log.Printf("can't Deserialize")
return err
}
cHash := hdr.BlockSha()
err = ghdr.AddBlockLocatorHash(&cHash)
if err != nil {
return err
}
fmt.Printf("get headers message has %d header hashes, first one is %s\n",
len(ghdr.BlockLocatorHashes), ghdr.BlockLocatorHashes[0].String())
s.outMsgQueue <- ghdr
return nil
// =============================================================
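// NOTE: nothing below this point currently runs; the function returns above,
// and the headers response will arrive on incomingMessageHandler instead
// (which has no MsgHeaders case yet).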
// ask for headers. probably will get 2000.
log.Printf("getheader version %d \n", ghdr.ProtocolVersion)
n, m, _, err := wire.ReadMessageN(s.con, VERSION, NETVERSION)
if err != nil {
return err
}
log.Printf("4got %d byte response\n command: %s\n", n, m.Command())
hdrresponse, ok := m.(*wire.MsgHeaders)
if !ok {
log.Printf("got non-header message.")
return nil
// this can actually happen and we should deal with / ignore it
// also pings, they don't like it when you don't respond to pings.
// invs and the rest we can ignore for now until filters are up.
}
_, err = s.headerFile.Seek(-80, os.SEEK_END)
if err != nil {
return err
}
var last wire.BlockHeader
err = last.Deserialize(s.headerFile)
if err != nil {
return err
}
prevHash := last.BlockSha()
gotNum := int64(len(hdrresponse.Headers))
if gotNum > 0 {
fmt.Printf("got %d headers. Range:\n%s - %s\n",
gotNum, hdrresponse.Headers[0].BlockSha().String(),
hdrresponse.Headers[len(hdrresponse.Headers)-1].BlockSha().String())
}
_, err = s.headerFile.Seek(0, os.SEEK_END)
if err != nil {
return err
}
for i, resphdr := range hdrresponse.Headers {
// check first header returned to make sure it fits on the end
// of our header file
if i == 0 && !resphdr.PrevBlock.IsEqual(&prevHash) {
return fmt.Errorf("header doesn't fit. points to %s, expect %s",
resphdr.PrevBlock.String(), prevHash.String())
}
err = resphdr.Serialize(s.headerFile)
if err != nil {
return err
}
}
endPos, _ := s.headerFile.Seek(0, os.SEEK_END)
tip := endPos / 80
go CheckRange(s.headerFile, tip-gotNum, tip-1, params)
return nil
}
func sendMBReq(cn net.Conn, blkhash wire.ShaHash) error {
iv1 := wire.NewInvVect(wire.InvTypeFilteredBlock, &blkhash)
gdataB := wire.NewMsgGetData()
_ = gdataB.AddInvVect(iv1)
n, err := wire.WriteMessageN(cn, gdataB, VERSION, NETVERSION)
if err != nil {
return err
}
log.Printf("sent %d byte block request\n", n)
return nil
}
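// Illustrative usage sketch; the reachable testnet3 node and the idea of
// calling GrabHeaders right after Open are assumptions of the example, not
// something this package configures.
func exampleSync(remoteNode string) error {
	var s SPVCon
	// connect, handshake, and start the read/write handlers
	if err := s.Open(remoteNode, headerFileName); err != nil {
		return err
	}
	// ask the remote node for headers following our local tip
	return s.GrabHeaders()
}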

56
uspv/filter.go Normal file

@@ -0,0 +1,56 @@
package uspv
import (
"log"
"net"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcutil/bloom"
)
//func (w *Eight3Con) SendFilter() error {
// filter := bloom.NewFilter(10, 0, 0.001, wire.BloomUpdateAll)
// // add addresses.
// filter.Add(adrBytes1)
// filter.Add(adrBytes2)
// filter.Add(adrBytes3)
// // filter.Add(adrBytes4)
// fmt.Printf("hf: %d filter %d bytes: %x\n",
// filter.MsgFilterLoad().HashFuncs,
// len(filter.MsgFilterLoad().Filter), filter.MsgFilterLoad().Filter)
// n, err := wire.WriteMessageN(cn,
// filter.MsgFilterLoad(), myversion, whichnet)
// if err != nil {
// return err
// }
// log.Printf("sent %d byte filter message\n", n)
// return nil
//}
func sendFilter(cn net.Conn) error {
// adrBytes1, _ := hex.DecodeString(adrHex1)
// adrBytes2, _ := hex.DecodeString(adrHex2)
// adrBytes3, _ := hex.DecodeString(adrHex3)
// ------------------- load a filter
// make a new filter. floats ew. hardcode.
filter := bloom.NewFilter(10, 0, 0.001, wire.BloomUpdateNone)
// add addresses.
// filter.Add(adrBytes1)
// filter.Add(adrBytes2)
// filter.Add(adrBytes3)
// filter.Add(adrBytes4)
n, err := wire.WriteMessageN(cn,
filter.MsgFilterLoad(), VERSION, NETVERSION)
if err != nil {
return err
}
log.Printf("sent %d byte filter message\n", n)
return nil
}
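// Variation sketch: the same filterload send as above, but with a
// caller-supplied payload (e.g. a 20-byte pubkey hash) instead of the
// hardcoded hex strings. The 10 element / 0.001 false-positive parameters
// just mirror sendFilter; the payload itself is an assumption of the example.
func sendFilterFor(cn net.Conn, payload []byte) error {
	f := bloom.NewFilter(10, 0, 0.001, wire.BloomUpdateNone)
	f.Add(payload)
	n, err := wire.WriteMessageN(cn, f.MsgFilterLoad(), VERSION, NETVERSION)
	if err != nil {
		return err
	}
	log.Printf("sent %d byte filter message\n", n)
	return nil
}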

194
uspv/header.go Normal file

@@ -0,0 +1,194 @@
/* this is blockchain technology. Well, except without the blocks.
Really it's header chain technology.
The blocks themselves don't really make a chain. Just the headers do.
*/
package uspv
import (
"io"
"log"
"math/big"
"os"
"time"
"github.com/btcsuite/btcd/blockchain"
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/wire"
)
// blockchain settings. These are kind of bitcoin specific, but not contained in
// chaincfg.Params so they'll go here. If you're into the [ANN]altcoin scene,
// you may want to parameterize these constants.
const (
targetTimespan = time.Hour * 24 * 14
targetSpacing = time.Minute * 10
epochLength = int64(targetTimespan / targetSpacing)
maxDiffAdjust = 4
minRetargetTimespan = int64(targetTimespan / maxDiffAdjust)
maxRetargetTimespan = int64(targetTimespan * maxDiffAdjust)
)
/* checkProofOfWork verifies that the header hashes to something
lower than the target specified by its 4-byte bits field. */
func checkProofOfWork(header wire.BlockHeader, p *chaincfg.Params) bool {
target := blockchain.CompactToBig(header.Bits)
// The target must be more than 0. Why can you even encode negative...
if target.Sign() <= 0 {
log.Printf("block target %064x is negative(??)\n", target.Bytes())
return false
}
// The target must be less than the maximum allowed (difficulty 1)
if target.Cmp(p.PowLimit) > 0 {
log.Printf("block target %064x is "+
"higher than max of %064x", target, p.PowLimit.Bytes())
return false
}
// The header hash must be less than the claimed target in the header.
blockHash := header.BlockSha()
hashNum := blockchain.ShaHashToBig(&blockHash)
if hashNum.Cmp(target) > 0 {
log.Printf("block hash %064x is higher than "+
"required target of %064x", hashNum, target)
return false
}
return true
}
/* calcDiffAdjust calculates the difficulty adjustment for an epoch, given its
first and last block headers. Only feed it headers n-2016 and n-1; otherwise it
will calculate an adjustment where none should take place.
Note that the epoch is actually 2015 blocks long, which is confusing. */
func calcDiffAdjust(start, end wire.BlockHeader, p *chaincfg.Params) uint32 {
duration := end.Timestamp.UnixNano() - start.Timestamp.UnixNano()
if duration < minRetargetTimespan {
log.Printf("whoa there, block %s off-scale high 4X diff adjustment!",
end.BlockSha().String())
duration = minRetargetTimespan
} else if duration > maxRetargetTimespan {
log.Printf("Uh-oh! block %s off-scale low 0.25X diff adjustment!\n",
end.BlockSha().String())
duration = maxRetargetTimespan
}
// calculation of new 32-byte difficulty target
// first turn the previous target into a big int
prevTarget := blockchain.CompactToBig(start.Bits)
// new target is old * duration...
newTarget := new(big.Int).Mul(prevTarget, big.NewInt(duration))
// divided by 2 weeks
newTarget.Div(newTarget, big.NewInt(int64(targetTimespan)))
// clip again if above the maximum allowed target (too easy)
if newTarget.Cmp(p.PowLimit) > 0 {
newTarget.Set(p.PowLimit)
}
// calculate and return 4-byte 'bits' difficulty from 32-byte target
return blockchain.BigToCompact(newTarget)
}
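// Worked example with assumed numbers: if an epoch took exactly one week
// instead of the targeted two, the new target is half the old one, i.e. the
// difficulty doubles. The headers are zero-valued except for the fields
// calcDiffAdjust reads.
func exampleRetarget(p *chaincfg.Params) uint32 {
	var start, end wire.BlockHeader
	start.Bits = p.PowLimitBits // epoch started at difficulty 1
	start.Timestamp = time.Unix(0, 0)
	end.Timestamp = start.Timestamp.Add(targetTimespan / 2) // blocks came in 2x too fast
	return calcDiffAdjust(start, end, p)                    // compact bits for twice the difficulty
}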
func CheckHeader(r io.ReadSeeker, height int64, p *chaincfg.Params) bool {
var err error
var cur, prev, epochStart wire.BlockHeader
// don't try to verify the genesis block. That way madness lies.
if height == 0 {
return true
}
// initial load of headers
// load epochstart, previous and current.
// get the header from the epoch start, up to 2016 blocks ago
_, err = r.Seek(80*(height-(height%epochLength)), os.SEEK_SET)
if err != nil {
log.Printf(err.Error())
return false
}
err = epochStart.Deserialize(r)
if err != nil {
log.Printf(err.Error())
return false
}
log.Printf("start epoch at height %d ", height-(height%epochLength))
// seek to n-1 header
_, err = r.Seek(80*(height-1), os.SEEK_SET)
if err != nil {
log.Printf(err.Error())
return false
}
// read in n-1
err = prev.Deserialize(r)
if err != nil {
log.Printf(err.Error())
return false
}
// seek to curHeight header and read in
_, err = r.Seek(80*(height), os.SEEK_SET)
if err != nil {
log.Printf(err.Error())
return false
}
err = cur.Deserialize(r)
if err != nil {
log.Printf(err.Error())
return false
}
// get hash of n-1 header
prevHash := prev.BlockSha()
// check if headers link together. That whole 'blockchain' thing.
if prevHash.IsEqual(&cur.PrevBlock) == false {
log.Printf("Headers %d and %d don't link.\n",
height-1, height)
log.Printf("%s - %s",
prev.BlockSha().String(), cur.BlockSha().String())
return false
}
rightBits := epochStart.Bits // normal, no adjustment; Dn = Dn-1
// see if we're on a difficulty adjustment block
if (height)%epochLength == 0 {
// if so, check if difficulty adjustment is valid.
// That whole "controlled supply" thing.
// calculate diff n based on n-2016 ... n-1
rightBits = calcDiffAdjust(epochStart, prev, p)
// done with adjustment, save new epochStart header
epochStart = cur
log.Printf("Update epoch at height %d", height)
} else { // not a new epoch
// if on testnet, check for difficulty nerfing
if p.ResetMinDifficulty && cur.Timestamp.After(
prev.Timestamp.Add(targetSpacing*2)) {
// fmt.Printf("nerf %d ", curHeight)
rightBits = p.PowLimitBits // difficulty 1
}
if cur.Bits != rightBits {
log.Printf("Block %d %s incorrect difficuly. Read %x, expect %x\n",
height, cur.BlockSha().String(), cur.Bits, rightBits)
return false
}
}
// check if there's a valid proof of work. That whole "Bitcoin" thing.
if !checkProofOfWork(cur, p) {
log.Printf("Block %d Bad proof of work.\n", height)
return false
}
return true // it must have worked if there were no errors and we got to the end.
}
/* CheckRange verifies a range of headers. It checks their proof of work,
difficulty adjustments, and that they all link in to each other properly.
This is the only blockchain technology in the whole code base.
Returns false if anything bad happens. Returns true if the range checks
out with no errors. */
func CheckRange(r io.ReadSeeker, first, last int64, p *chaincfg.Params) bool {
for i := first; i <= last; i++ {
if !CheckHeader(r, i, p) {
return false
}
}
return true // all good.
}
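// Illustrative sketch: verify the first 100 stored headers against testnet3
// parameters. The range is an arbitrary choice for the example; the file is
// the same one openHeaderFile creates and GrabHeaders appends to.
func exampleCheckRange() bool {
	f, err := os.Open(headerFileName)
	if err != nil {
		log.Printf(err.Error())
		return false
	}
	defer f.Close()
	return CheckRange(f, 1, 100, &chaincfg.TestNet3Params)
}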

169
uspv/mblock.go Normal file

@@ -0,0 +1,169 @@
package uspv
import (
"fmt"
"github.com/btcsuite/btcd/wire"
)
func MakeMerkleParent(left *wire.ShaHash, right *wire.ShaHash) *wire.ShaHash {
// this can screw things up; CVE-2012-2459
if left != nil && right != nil && left.IsEqual(right) {
fmt.Printf("DUP HASH CRASH")
return nil
}
// if the left child is nil, output nil. Shouldn't need this?
if left == nil {
fmt.Printf("L CRASH")
return nil
}
// if right is nil, hash left with itself
if right == nil {
right = left
}
// Concatenate the left and right nodes
var sha [wire.HashSize * 2]byte
copy(sha[:wire.HashSize], left[:])
copy(sha[wire.HashSize:], right[:])
newSha := wire.DoubleSha256SH(sha[:])
return &newSha
}
type merkleNode struct {
p uint32 // position in the binary tree
h *wire.ShaHash // hash
}
// given n merkle leaves, how deep is the tree?
// iterate, shifting left until 1<<e is no longer less than n
func treeDepth(n uint32) (e uint8) {
for ; (1 << e) < n; e++ {
}
return
}
// smallest power of 2 that can contain n
func nextPowerOfTwo(n uint32) uint32 {
return 1 << treeDepth(n) // 2^exponent
}
// check if a node is populated based on node position and size of tree
func inDeadZone(pos, size uint32) bool {
msb := nextPowerOfTwo(size)
last := size - 1 // last valid position is 1 less than size
if pos > (msb<<1)-2 { // greater than root; not even in the tree
fmt.Printf(" ?? greater than root ")
return true
}
h := msb
for pos >= h {
h = h>>1 | msb
last = last>>1 | msb
}
return pos > last
}
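// Small worked example of the position numbering (a 5-tx block is an assumed
// size): msb is 8, so real leaves sit at positions 0-4 and 5-7 are dead; the
// next row is 8-11 with 11 dead (both of its children would be dead); then
// 12-13, and 14 is the root.
func exampleDeadZone() {
	for _, pos := range []uint32{4, 5, 10, 11, 13} {
		fmt.Printf("pos %d dead: %v\n", pos, inDeadZone(pos, 5))
	}
}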
// take in a merkle block, parse through it, and return the
// txids that they're trying to tell us about. If there's any problem
// return an error.
// doing it with a stack instead of recursion. Because...
// OK I don't know why I'm just not in to recursion OK?
func checkMBlock(m *wire.MsgMerkleBlock) ([]*wire.ShaHash, error) {
if m.Transactions == 0 {
return nil, fmt.Errorf("No transactions in merkleblock")
}
if len(m.Flags) == 0 {
return nil, fmt.Errorf("No flag bits")
}
var s []merkleNode // the stack
var r []*wire.ShaHash // slice to return; txids we care about
// set initial position to root of merkle tree
msb := nextPowerOfTwo(m.Transactions) // most significant bit possible
pos := (msb << 1) - 2 // current position in tree
var i uint8 // position in the current flag byte
var tip int
// main loop
for {
tip = len(s) - 1 // slice position of stack tip
// First check if stack operations can be performed
// is stack one filled item? that's complete.
if tip == 0 && s[0].h != nil {
if s[0].h.IsEqual(&m.Header.MerkleRoot) {
return r, nil
}
return nil, fmt.Errorf("computed root %s but expect %s\n",
s[0].h.String(), m.Header.MerkleRoot.String())
}
// is current position in the tree's dead zone? partial parent
if inDeadZone(pos, m.Transactions) {
// create merkle parent from single side (left)
s[tip-1].h = MakeMerkleParent(s[tip].h, nil)
s = s[:tip] // remove 1 from stack
pos = s[tip-1].p | 1 // move position to parent's sibling
continue
}
// does stack have 3+ items? and are last 2 items filled?
if tip > 1 && s[tip-1].h != nil && s[tip].h != nil {
//fmt.Printf("nodes %d and %d combine into %d\n",
// s[tip-1].p, s[tip].p, s[tip-2].p)
// combine two filled nodes into parent node
s[tip-2].h = MakeMerkleParent(s[tip-1].h, s[tip].h)
// remove children
s = s[:tip-1]
// move position to parent's sibling
pos = s[tip-2].p | 1
continue
}
// no stack ops to perform, so make new node from message hashes
if len(m.Hashes) == 0 {
return nil, fmt.Errorf("Ran out of hashes at position %d.", pos)
}
if len(m.Flags) == 0 {
return nil, fmt.Errorf("Ran out of flag bits.")
}
var n merkleNode // make new node
n.p = pos // set current position for new node
if pos&msb != 0 { // upper non-txid hash
if m.Flags[0]&(1<<i) == 0 { // flag bit says fill node
n.h = m.Hashes[0] // copy hash from message
m.Hashes = m.Hashes[1:] // pop off message
if pos&1 != 0 { // right side; ascend
pos = pos>>1 | msb
} else { // left side, go to sibling
pos |= 1
}
} else { // flag bit says skip; put empty on stack and descend
pos = (pos ^ msb) << 1 // descend to left
}
s = append(s, n) // push new node on stack
} else { // bottom row txid; flag bit indicates tx of interest
if pos >= m.Transactions {
// this can't happen because we check deadzone above...
return nil, fmt.Errorf("got into an invalid txid node")
}
n.h = m.Hashes[0] // copy hash from message
m.Hashes = m.Hashes[1:] // pop off message
if m.Flags[0]&(1<<i) != 0 { //txid of interest
r = append(r, n.h)
}
if pos&1 == 0 { // left side, go to sibling
pos |= 1
} // if on right side we don't move; stack ops will move next
s = append(s, n) // push new node onto the stack
}
// done with pushing onto stack; advance flag bit
i++
if i == 8 { // move to next byte
i = 0
m.Flags = m.Flags[1:]
}
}
return nil, fmt.Errorf("ran out of things to do?")
}
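// Minimal illustrative call (the txid is made up for the example): a block
// with a single transaction has a merkle root equal to that txid, one hash,
// and one set flag bit, so checkMBlock should hand the txid straight back.
func exampleCheckMBlock() error {
	txid := wire.DoubleSha256SH([]byte("assumed txid preimage"))
	var m wire.MsgMerkleBlock
	m.Header.MerkleRoot = txid
	m.Transactions = 1
	m.Hashes = []*wire.ShaHash{&txid}
	m.Flags = []byte{0x01}
	txids, err := checkMBlock(&m)
	if err != nil {
		return err
	}
	fmt.Printf("merkle block handed back %d txid(s)\n", len(txids))
	return nil
}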

68
uspv/msghandler.go Normal file

@@ -0,0 +1,68 @@
package uspv
import (
"fmt"
"log"
"github.com/btcsuite/btcd/wire"
)
func (e3c *SPVCon) incomingMessageHandler() {
for {
n, xm, _, err := wire.ReadMessageN(e3c.con, e3c.localVersion, e3c.netType)
if err != nil {
log.Printf("ReadMessageN error. Disconnecting: %s\n", err.Error())
return
}
e3c.RBytes += uint64(n)
// log.Printf("Got %d byte %s message\n", n, xm.Command())
switch m := xm.(type) {
case *wire.MsgVersion:
log.Printf("Got version message. Agent %s, version %d, at height %d\n",
m.UserAgent, m.ProtocolVersion, m.LastBlock)
e3c.remoteVersion = uint32(m.ProtocolVersion) // weird cast! bug?
case *wire.MsgVerAck:
log.Printf("Got verack. Whatever.\n")
case *wire.MsgAddr:
log.Printf("got %d addresses.\n", len(m.AddrList))
case *wire.MsgPing:
log.Printf("Got a ping message. We should pong back or they will kick us off.")
e3c.PongBack(m.Nonce)
case *wire.MsgPong:
log.Printf("Got a pong response. OK.\n")
case *wire.MsgMerkleBlock:
log.Printf("Got merkle block message. Will verify.\n")
fmt.Printf("%d flag bytes, %d txs, %d hashes",
len(m.Flags), m.Transactions, len(m.Hashes))
txids, err := checkMBlock(m)
if err != nil {
log.Printf("Merkle block error: %s\n", err.Error())
return
// continue
}
fmt.Printf(" = got %d txs from block %s\n",
len(txids), m.Header.BlockSha().String())
// nextReq <- true
case *wire.MsgTx:
log.Printf("Got tx %s\n", m.TxSha().String())
default:
log.Printf("Got unknown message type %s\n", m.Command())
}
}
return
}
// this one seems kind of pointless? could get rid of it and let
// functions call WriteMessageN themselves...
func (e3c *SPVCon) outgoingMessageHandler() {
for {
msg := <-e3c.outMsgQueue
n, err := wire.WriteMessageN(e3c.con, msg, e3c.localVersion, e3c.netType)
if err != nil {
log.Printf("Write message error: %s", err.Error())
}
e3c.WBytes += uint64(n)
}
return
}

119
uspv/txstore.go Normal file

@@ -0,0 +1,119 @@
package uspv
import (
"bytes"
"fmt"
"log"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcutil"
"github.com/btcsuite/btcutil/bloom"
)
type TxStore struct {
KnownTxids []wire.ShaHash
Utxos []Utxo // stacks on stacks
Sum int64 // racks on racks
Adrs []MyAdr // endeavouring to acquire capital
}
type Utxo struct { // cash money.
// combo of outpoint and txout which has all the info needed to spend
Op wire.OutPoint
Txo wire.TxOut
KeyIdx uint32 // index for private key needed to sign / spend
}
type MyAdr struct { // an address I have the private key for
btcutil.Address
KeyIdx uint32 // index for private key needed to sign / spend
}
func (t *TxStore) AddAdr(a btcutil.Address, kidx uint32) {
var ma MyAdr
ma.Address = a
ma.KeyIdx = kidx
t.Adrs = append(t.Adrs, ma)
return
}
func (t *TxStore) GimmeFilter() (*bloom.Filter, error) {
if len(t.Adrs) == 0 {
return nil, fmt.Errorf("no addresses to filter for")
}
f := bloom.NewFilter(uint32(len(t.Adrs)), 0, 0.001, wire.BloomUpdateNone)
for _, a := range t.Adrs {
f.Add(a.ScriptAddress())
}
return f, nil
}
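// Sketch tying GimmeFilter to SPVCon.SendFilter; the wiring is an assumption
// of the example (nothing in this commit calls them together yet), and both
// the store and the connection are assumed to be set up by the caller.
func exampleLoadFilter(ts *TxStore, s *SPVCon) error {
	f, err := ts.GimmeFilter()
	if err != nil {
		return err
	}
	s.SendFilter(f) // queues the filterload message for the remote node
	return nil
}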
// Ingest a tx into wallet, dealing with both gains and losses
func (t *TxStore) IngestTx(tx *wire.MsgTx) error {
err := t.AbsorbTx(tx)
if err != nil {
return err
}
err = t.ExpellTx(tx)
if err != nil {
return err
}
return nil
}
// Absorb money into wallet from a tx
func (t *TxStore) AbsorbTx(tx *wire.MsgTx) error {
if tx == nil {
return fmt.Errorf("Tried to add nil tx")
}
var hits uint32
var acq int64
// check if any of the tx's outputs match my adrs
for i, out := range tx.TxOut { // in each output of tx
for _, a := range t.Adrs { // compare to each adr we have
// more correct would be to check for full script
// contains could have false positive? (p2sh/p2pkh same hash ..?)
if bytes.Contains(out.PkScript, a.ScriptAddress()) { // hit
hits++
acq += out.Value
var newu Utxo
newu.KeyIdx = a.KeyIdx
newu.Txo = *out
var newop wire.OutPoint
newop.Hash = tx.TxSha()
newop.Index = uint32(i)
newu.Op = newop
t.Utxos = append(t.Utxos, newu)
break
}
}
}
log.Printf("%d hits, acquired %d", hits, acq)
t.Sum += acq
return nil
}
// Expell money from wallet due to a tx
func (t *TxStore) ExpellTx(tx *wire.MsgTx) error {
if tx == nil {
return fmt.Errorf("Tried to add nil tx")
}
var hits uint32
var loss int64
for _, in := range tx.TxIn {
for i, myutxo := range t.Utxos {
if myutxo.Op == in.PreviousOutPoint {
hits++
loss += myutxo.Txo.Value
// delete from my utxo set and stop scanning;
// an outpoint can only match one utxo
t.Utxos = append(t.Utxos[:i], t.Utxos[i+1:]...)
break
}
}
}
log.Printf("%d hits, lost %d", hits, loss)
t.Sum -= loss
return nil
}
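// Illustrative sketch with made-up values: one address, one tx paying it
// 50000 satoshis, then a tx spending that output again. The 20-byte hash and
// the amounts are assumptions for the example.
func exampleTxStore() error {
	var ts TxStore
	fakeHash160 := make([]byte, 20) // assumed pubkey hash (all zeroes)
	adr, err := btcutil.NewAddressPubKeyHash(fakeHash160, params)
	if err != nil {
		return err
	}
	ts.AddAdr(adr, 0)
	// absorb: a tx with one p2pkh-style output embedding that hash
	var gain wire.MsgTx
	pkScript := append([]byte{0x76, 0xa9, 0x14}, adr.ScriptAddress()...)
	pkScript = append(pkScript, 0x88, 0xac)
	gain.AddTxOut(&wire.TxOut{Value: 50000, PkScript: pkScript})
	if err := ts.IngestTx(&gain); err != nil {
		return err
	}
	log.Printf("balance after gain: %d\n", ts.Sum)
	// expel: a tx spending the utxo we just picked up
	var loss wire.MsgTx
	loss.AddTxIn(&wire.TxIn{PreviousOutPoint: ts.Utxos[0].Op})
	if err := ts.IngestTx(&loss); err != nil {
		return err
	}
	log.Printf("balance after loss: %d\n", ts.Sum)
	return nil
}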