diff --git a/.gitignore b/.gitignore index feb149c7..b55095c7 100644 --- a/.gitignore +++ b/.gitignore @@ -37,3 +37,7 @@ test_wal/* # vim **.swp + +*.hex +*.db +*.bin diff --git a/cmd/lnshell/commands.go b/cmd/lnshell/commands.go index a883c9be..dc4e1506 100644 --- a/cmd/lnshell/commands.go +++ b/cmd/lnshell/commands.go @@ -22,8 +22,8 @@ func RpcConnect(args []string) error { } fmt.Printf("connection state: %s\n", state.String()) time.Sleep(time.Second * 2) - // lnClient := lnrpc.NewLightningClient(conn) - // lnClient.NewAddress(nil, nil, nil) // crashes + // lnClient := lnrpc.NewLightningClient(conn) + // lnClient.NewAddress(nil, nil, nil) // crashes state, err = conn.State() if err != nil { diff --git a/lnd.go b/lnd.go index afcc516f..5751379b 100644 --- a/lnd.go +++ b/lnd.go @@ -21,11 +21,17 @@ var ( rpcPort = flag.Int("rpcport", 10009, "The port for the rpc server") peerPort = flag.String("peerport", "10011", "The port to listen on for incoming p2p connections") dataDir = flag.String("datadir", "test_wal", "The directory to store lnd's data within") + spvMode = flag.Bool("spv", false, "assert to enter spv wallet mode") ) func main() { flag.Parse() + if *spvMode == true { + shell() + return + } + go func() { listenAddr := net.JoinHostPort("", "5009") profileRedirect := http.RedirectHandler("/debug/pprof", diff --git a/shell.go b/shell.go new file mode 100644 index 00000000..5602a18c --- /dev/null +++ b/shell.go @@ -0,0 +1,328 @@ +package main + +import ( + "bufio" + "bytes" + "fmt" + "log" + "os" + "strconv" + "strings" + + "github.com/btcsuite/btcd/txscript" + "github.com/btcsuite/btcd/wire" + + "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btcutil" + + "github.com/lightningnetwork/lnd/uspv" +) + +/* this is a CLI shell for testing out LND. Right now it's only for uspv +testing. It can send and receive coins. +*/ + +const ( + keyFileName = "testkey.hex" + headerFileName = "headers.bin" + dbFileName = "utxo.db" + // this is my local testnet node, replace it with your own close by. + // Random internet testnet nodes usually work but sometimes don't, so + // maybe I should test against different versions out there. + SPVHostAdr = "127.0.0.1:18333" +) + +var ( + Params = &chaincfg.TestNet3Params + SCon uspv.SPVCon // global here for now +) + +func shell() { + fmt.Printf("LND spv shell v0.0\n") + fmt.Printf("Not yet well integrated, but soon.\n") + + // read key file (generate if not found) + rootPriv, err := uspv.ReadKeyFileToECPriv(keyFileName, Params) + if err != nil { + log.Fatal(err) + } + // setup TxStore first (before spvcon) + Store := uspv.NewTxStore(rootPriv, Params) + // setup spvCon + + SCon, err = uspv.OpenSPV( + SPVHostAdr, headerFileName, dbFileName, &Store, Params) + if err != nil { + log.Fatal(err) + } + + tip, err := SCon.TS.GetDBSyncHeight() // ask for sync height + if err != nil { + log.Fatal(err) + } + if tip == 0 { // DB has never been used, set to birthday + tip = 675000 // hardcoded; later base on keyfile date? + err = SCon.TS.SetDBSyncHeight(tip) + if err != nil { + log.Fatal(err) + } + } + + // once we're connected, initiate headers sync + err = Hdr() + if err != nil { + log.Fatal(err) + } + + // main shell loop + for { + // setup reader with max 4K input chars + reader := bufio.NewReaderSize(os.Stdin, 4000) + fmt.Printf("LND# ") // prompt + msg, err := reader.ReadString('\n') // input finishes on enter key + if err != nil { + log.Fatal(err) + } + + cmdslice := strings.Fields(msg) // chop input up on whitespace + if len(cmdslice) < 1 { + continue // no input, just prompt again + } + fmt.Printf("entered command: %s\n", msg) // immediate feedback + err = Shellparse(cmdslice) + if err != nil { // only error should be user exit + log.Fatal(err) + } + } + return +} + +// Shellparse parses user input and hands it to command functions if matching +func Shellparse(cmdslice []string) error { + var err error + var args []string + cmd := cmdslice[0] + if len(cmdslice) > 1 { + args = cmdslice[1:] + } + if cmd == "exit" || cmd == "quit" { + return fmt.Errorf("User exit") + } + + // help gives you really terse help. Just a list of commands. + if cmd == "help" { + err = Help(args) + if err != nil { + fmt.Printf("help error: %s\n", err) + } + return nil + } + + // adr generates a new address and displays it + if cmd == "adr" { + err = Adr(args) + if err != nil { + fmt.Printf("adr error: %s\n", err) + } + return nil + } + + // bal shows the current set of utxos, addresses and score + if cmd == "bal" { + err = Bal(args) + if err != nil { + fmt.Printf("bal error: %s\n", err) + } + return nil + } + + // send sends coins to the address specified + if cmd == "send" { + err = Send(args) + if err != nil { + fmt.Printf("send error: %s\n", err) + } + return nil + } + + fmt.Printf("Command not recognized. type help for command list.\n") + return nil +} + +// Hdr asks for headers. +func Hdr() error { + if SCon.RBytes == 0 { + return fmt.Errorf("No SPV connection, can't get headers.") + } + err := SCon.AskForHeaders() + if err != nil { + return err + } + return nil +} + +// Bal prints out your score. +func Bal(args []string) error { + if SCon.TS == nil { + return fmt.Errorf("Can't get balance, spv connection broken") + } + fmt.Printf(" ----- Account Balance ----- \n") + allUtxos, err := SCon.TS.GetAllUtxos() + if err != nil { + return err + } + var score int64 + for i, u := range allUtxos { + fmt.Printf("\tutxo %d height %d %s key: %d amt %d\n", + i, u.AtHeight, u.Op.String(), u.KeyIdx, u.Value) + score += u.Value + } + height, _ := SCon.TS.GetDBSyncHeight() + + for i, a := range SCon.TS.Adrs { + fmt.Printf("address %d %s\n", i, a.PkhAdr.String()) + } + fmt.Printf("Total known utxos: %d\n", len(allUtxos)) + fmt.Printf("Total spendable coin: %d\n", score) + fmt.Printf("DB sync height: %d\n", height) + return nil +} + +// Adr makes a new address. +func Adr(args []string) error { + a, err := SCon.TS.NewAdr() + if err != nil { + return err + } + fmt.Printf("made new address %s, %d addresses total\n", + a.String(), len(SCon.TS.Adrs)) + + return nil +} + +// Send sends coins. +func Send(args []string) error { + // get all utxos from the database + allUtxos, err := SCon.TS.GetAllUtxos() + if err != nil { + return err + } + var score int64 // score is the sum of all utxo amounts. highest score wins. + // add all the utxos up to get the score + for _, u := range allUtxos { + score += u.Value + } + + // score is 0, cannot unlock 'send coins' acheivement + if score == 0 { + return fmt.Errorf("You don't have money. Work hard.") + } + // need args, fail + if len(args) < 2 { + return fmt.Errorf("need args: ssend amount(satoshis) address") + } + amt, err := strconv.ParseInt(args[0], 10, 64) + if err != nil { + return err + } + if amt < 1000 { + return fmt.Errorf("can't send %d, too small", amt) + } + adr, err := btcutil.DecodeAddress(args[1], SCon.TS.Param) + if err != nil { + fmt.Printf("error parsing %s as address\t", args[1]) + return err + } + fmt.Printf("send %d to address: %s \n", + amt, adr.String()) + err = SendCoins(SCon, adr, amt) + if err != nil { + return err + } + return nil +} + +// SendCoins does send coins, but it's very rudimentary +func SendCoins(s uspv.SPVCon, adr btcutil.Address, sendAmt int64) error { + var err error + var score int64 + allUtxos, err := s.TS.GetAllUtxos() + if err != nil { + return err + } + + for _, utxo := range allUtxos { + score += utxo.Value + } + // important rule in bitcoin, output total > input total is invalid. + if sendAmt > score { + return fmt.Errorf("trying to send %d but %d available.", + sendAmt, score) + } + + tx := wire.NewMsgTx() // make new tx + // make address script 76a914...88ac + adrScript, err := txscript.PayToAddrScript(adr) + if err != nil { + return err + } + // make user specified txout and add to tx + txout := wire.NewTxOut(sendAmt, adrScript) + tx.AddTxOut(txout) + + nokori := sendAmt // nokori is how much is needed on input side + for _, utxo := range allUtxos { + // generate pkscript to sign + prevPKscript, err := txscript.PayToAddrScript( + s.TS.Adrs[utxo.KeyIdx].PkhAdr) + if err != nil { + return err + } + // make new input from this utxo + thisInput := wire.NewTxIn(&utxo.Op, prevPKscript) + tx.AddTxIn(thisInput) + nokori -= utxo.Value + if nokori < -10000 { // minimum overage / fee is 1K now + break + } + } + // there's enough left to make a change output + if nokori < -200000 { + change, err := s.TS.NewAdr() + if err != nil { + return err + } + + changeScript, err := txscript.PayToAddrScript(change) + if err != nil { + return err + } + changeOut := wire.NewTxOut((-100000)-nokori, changeScript) + tx.AddTxOut(changeOut) + } + + // use txstore method to sign + err = s.TS.SignThis(tx) + if err != nil { + return err + } + + fmt.Printf("tx: %s", uspv.TxToString(tx)) + buf := bytes.NewBuffer(make([]byte, 0, tx.SerializeSize())) + tx.Serialize(buf) + fmt.Printf("tx: %x\n", buf.Bytes()) + + // send it out on the wire. hope it gets there. + // we should deal with rejects. Don't yet. + err = s.PushTx(tx) + if err != nil { + return err + } + return nil +} + +func Help(args []string) error { + fmt.Printf("commands:\n") + fmt.Printf("help adr bal send exit\n") + return nil +} diff --git a/uspv/README.md b/uspv/README.md new file mode 100644 index 00000000..08cade5e --- /dev/null +++ b/uspv/README.md @@ -0,0 +1,70 @@ +# uspv - micro-SPV library + +The uspv library implements simplified SPV wallet functionality. +It connects to full nodes using the standard port 8333 bitcoin protocol, +gets headers, uses bloom filters, gets blocks and transactions, and has +functions to send and receive coins. + +## Files + +Three files are used by the library: + +#### Key file (currently testkey.hex) + +This file contains the secret seed which creates all private keys used by the wallet. It is stored in ascii hexadecimal for easy copying and pasting. If you don't enter a password when prompted, you'll get a warning and the key file will be saved in the clear with no encryption. You shouldn't do that though. When using a password, the key file will be longer and use the scrypt KDF and nacl/secretbox to protect the secret seed. + +#### Header file (currently headers.bin) + +This is a file storing all the block headers. Headers are 80 bytes long, so this file's size will always be an even multiple of 80. All blockchain-technology verifications are performed when appending headers to the file. In the case of re-orgs, since it's so quick to get headers, it just truncates a bit and tries again. + +#### Database file (currently utxo.db) + +This file more complex. It uses bolt DB to store wallet information needed to send and receive bitcoins. The database file is organized into 4 main "buckets": + +* Utxos ("DuffelBag") + +This bucket stores all the utxos. The goal of bitcoin is to get lots of utxos, earning a high score. + +* Stxos ("SpentTxs") + +For record keeping, this bucket stores what used to be utxos, but are no longer "u"txos, and are spent outpoints. It references the spending txid. + +* Txns ("Txns") + +This bucket stores full serialized transactions which are refenced in the Stxos bucket. These can be used to re-play transactions in the case of re-orgs. + +* State ("MiscState") + +This has describes some miscellaneous global state variables of the database, such as what height it has synchronized up to, and how many addresses have been created. (Currently those are the only 2 things stored) + +## Synchronization overview + +Currently uspv only connects to one hard-coded node, as address messages and storage are not yet implemented. It first asks for headers, providing the last known header (writing the genesis header if needed). It loops through asking for headers until it receives an empty header message, which signals that headers are fully synchronized. + +After header synchronization is complete, it requests merkle blocks starting at the keyfile birthday. (This is currently hard-coded; add new db key?) Bloom filters are generated for the addresses and utxos known to the wallet. If too many false positives are received, a new filter is generated and sent. (This happens fairly often because the filter exponentially saturates with false positives when using BloomUpdateAll.) Once the merkle blocks have been received up to the header height, the wallet is considered synchronized and it will listen for new inv messages from the remote node. An inv message describing a block will trigger a request for headers, starting the same synchronization process of headers then merkle-blocks. + +## TODO + +There's still quite a bit left, though most of it hopefully won't be too hard. + +Problems / still to do: + +* Only connects to one node, and that node is hard-coded. +* Re-orgs affect only headers, and don't evict confirmed transactions. +* Double spends are not detected; Double spent txs will stay at height 0. +* Tx creation and signing is still very rudimentary. +* There may be wire-protocol irregularities which can get it kicked off. + +Hopefully I can get most of that list deleted soon. + +(Now sanity checks txs, but can't check sigs... because it's SPV. Right.) + +Later functionality to implement: + +* "Desktop Mode" SPV, or "Unfiltered" SPV or some other name + +This would be a mode where uspv doesn't use bloom filters and request merkle blocks, but instead grabs everything in the block and discards most of the data. This prevents nodes from learning about your utxo set. To further enhance this, it should connect to multiple nodes and relay txs and inv messages to blend in. + +* Ironman SPV + +Never request txs. Only merkleBlocks (or in above mode, blocks). No unconfirmed transactions are presented to the user, which makes a whole lot of sense as with unconfirmed SPV transactions you're relying completely on the honesty of the reporting node. \ No newline at end of file diff --git a/uspv/eight333.go b/uspv/eight333.go index 4b6ec382..a757c4ab 100644 --- a/uspv/eight333.go +++ b/uspv/eight333.go @@ -7,6 +7,7 @@ import ( "log" "net" "os" + "sync" "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/wire" @@ -18,22 +19,22 @@ const ( headerFileName = "headers.bin" // Except hash-160s, those aren't backwards. But anything that's 32 bytes is. // because, cmon, 32? Gotta reverse that. But 20? 20 is OK. - NETVERSION = wire.TestNet3 - VERSION = 70011 -) -var ( - params = &chaincfg.TestNet3Params + // version hardcoded for now, probably ok...? + VERSION = 70011 ) type SPVCon struct { - con net.Conn // the (probably tcp) connection to the node - headerFile *os.File // file for SPV headers + con net.Conn // the (probably tcp) connection to the node + headerMutex sync.Mutex + headerFile *os.File // file for SPV headers + + //[doesn't work without fancy mutexes, nevermind, just use header file] + // localHeight int32 // block height we're on + remoteHeight int32 // block height they're on localVersion uint32 // version we report remoteVersion uint32 // version remote node - remoteHeight int32 // block height they're on - netType wire.BitcoinNet // what's the point of the input queue? remove? leave for now... inMsgQueue chan wire.Message // Messages coming in from remote node @@ -42,50 +43,71 @@ type SPVCon struct { WBytes uint64 // total bytes written RBytes uint64 // total bytes read - TS *TxStore + TS *TxStore // transaction store to write to + + // mBlockQueue is for keeping track of what height we've requested. + mBlockQueue chan HashAndHeight + // fPositives is a channel to keep track of bloom filter false positives. + fPositives chan int32 + + // waitState is a channel that is empty while in the header and block + // sync modes, but when in the idle state has a "true" in it. + inWaitState chan bool } -func (s *SPVCon) Open(remoteNode string, hfn string, inTs *TxStore) error { +func OpenSPV(remoteNode string, hfn, dbfn string, + inTs *TxStore, p *chaincfg.Params) (SPVCon, error) { + // create new SPVCon + var s SPVCon + + // I should really merge SPVCon and TxStore, they're basically the same + inTs.Param = p + s.TS = inTs // copy pointer of txstore into spvcon + // open header file - err := s.openHeaderFile(headerFileName) + err := s.openHeaderFile(hfn) if err != nil { - return err + return s, err } // open TCP connection s.con, err = net.Dial("tcp", remoteNode) if err != nil { - return err + return s, err } + // assign version bits for local node s.localVersion = VERSION - s.netType = NETVERSION - s.TS = inTs + // transaction store for this SPV connection + err = inTs.OpenDB(dbfn) + if err != nil { + return s, err + } myMsgVer, err := wire.NewMsgVersionFromConn(s.con, 0, 0) if err != nil { - return err + return s, err } err = myMsgVer.AddUserAgent("test", "zero") if err != nil { - return err + return s, err } // must set this to enable SPV stuff myMsgVer.AddService(wire.SFNodeBloom) // this actually sends - n, err := wire.WriteMessageN(s.con, myMsgVer, s.localVersion, s.netType) + n, err := wire.WriteMessageN(s.con, myMsgVer, s.localVersion, s.TS.Param.Net) if err != nil { - return err + return s, err } s.WBytes += uint64(n) log.Printf("wrote %d byte version message to %s\n", n, s.con.RemoteAddr().String()) - n, m, b, err := wire.ReadMessageN(s.con, s.localVersion, s.netType) + n, m, b, err := wire.ReadMessageN(s.con, s.localVersion, s.TS.Param.Net) if err != nil { - return err + return s, err } s.RBytes += uint64(n) log.Printf("got %d byte response %x\n command: %s\n", n, b, m.Command()) @@ -98,10 +120,13 @@ func (s *SPVCon) Open(remoteNode string, hfn string, inTs *TxStore) error { log.Printf("remote reports version %x (dec %d)\n", mv.ProtocolVersion, mv.ProtocolVersion) + // set remote height + s.remoteHeight = mv.LastBlock + mva := wire.NewMsgVerAck() - n, err = wire.WriteMessageN(s.con, mva, s.localVersion, s.netType) + n, err = wire.WriteMessageN(s.con, mva, s.localVersion, s.TS.Param.Net) if err != nil { - return err + return s, err } s.WBytes += uint64(n) @@ -109,8 +134,12 @@ func (s *SPVCon) Open(remoteNode string, hfn string, inTs *TxStore) error { go s.incomingMessageHandler() s.outMsgQueue = make(chan wire.Message) go s.outgoingMessageHandler() + s.mBlockQueue = make(chan HashAndHeight, 32) // queue depth 32 is a thing + s.fPositives = make(chan int32, 4000) // a block full, approx + s.inWaitState = make(chan bool, 1) + go s.fPositiveHandler() - return nil + return s, nil } func (s *SPVCon) openHeaderFile(hfn string) error { @@ -118,7 +147,7 @@ func (s *SPVCon) openHeaderFile(hfn string) error { if err != nil { if os.IsNotExist(err) { var b bytes.Buffer - err = params.GenesisBlock.Header.Serialize(&b) + err = s.TS.Param.GenesisBlock.Header.Serialize(&b) if err != nil { return err } @@ -130,7 +159,6 @@ func (s *SPVCon) openHeaderFile(hfn string) error { hfn) } } - s.headerFile, err = os.OpenFile(hfn, os.O_RDWR, 0600) if err != nil { return err @@ -148,10 +176,88 @@ func (s *SPVCon) PongBack(nonce uint64) { func (s *SPVCon) SendFilter(f *bloom.Filter) { s.outMsgQueue <- f.MsgFilterLoad() + + return +} + +// HeightFromHeader gives you the block height given a 80 byte block header +// seems like looking for the merkle root is the best way to do this +func (s *SPVCon) HeightFromHeader(query wire.BlockHeader) (uint32, error) { + // start from the most recent and work back in time; even though that's + // kind of annoying it's probably a lot faster since things tend to have + // happened recently. + + // seek to last header + s.headerMutex.Lock() + defer s.headerMutex.Unlock() + lastPos, err := s.headerFile.Seek(-80, os.SEEK_END) + if err != nil { + return 0, err + } + height := lastPos / 80 + + var current wire.BlockHeader + + for height > 0 { + // grab header from disk + err = current.Deserialize(s.headerFile) + if err != nil { + return 0, err + } + // check if merkle roots match + if current.MerkleRoot.IsEqual(&query.MerkleRoot) { + // if they do, great, return height + return uint32(height), nil + } + // skip back one header (2 because we just read one) + _, err = s.headerFile.Seek(-160, os.SEEK_CUR) + if err != nil { + return 0, err + } + // decrement height + height-- + } + // finished for loop without finding match + return 0, fmt.Errorf("Header not found on disk") +} + +// AskForTx requests a tx we heard about from an inv message. +// It's one at a time but should be fast enough. +func (s *SPVCon) AskForTx(txid wire.ShaHash) { + gdata := wire.NewMsgGetData() + inv := wire.NewInvVect(wire.InvTypeTx, &txid) + gdata.AddInvVect(inv) + s.outMsgQueue <- gdata +} + +// AskForBlock requests a merkle block we heard about from an inv message. +// We don't have it in our header file so when we get it we do both operations: +// appending and checking the header, and checking spv proofs +func (s *SPVCon) AskForBlockx(hsh wire.ShaHash) { + s.headerMutex.Lock() + defer s.headerMutex.Unlock() + + gdata := wire.NewMsgGetData() + inv := wire.NewInvVect(wire.InvTypeFilteredBlock, &hsh) + gdata.AddInvVect(inv) + + info, err := s.headerFile.Stat() // get + if err != nil { + log.Fatal(err) // crash if header file disappears + } + nextHeight := int32(info.Size() / 80) + + hah := NewRootAndHeight(hsh, nextHeight) + fmt.Printf("AskForBlock - %s height %d\n", hsh.String(), nextHeight) + s.mBlockQueue <- hah // push height and mroot of requested block on queue + s.outMsgQueue <- gdata // push request to outbox return } func (s *SPVCon) AskForHeaders() error { + s.headerMutex.Lock() + defer s.headerMutex.Unlock() + var hdr wire.BlockHeader ghdr := wire.NewMsgGetHeaders() ghdr.ProtocolVersion = s.localVersion @@ -173,6 +279,7 @@ func (s *SPVCon) AskForHeaders() error { } log.Printf("suk to offset %d (should be near the end\n", ns) + // get header from last 80 bytes of file err = hdr.Deserialize(s.headerFile) if err != nil { @@ -194,8 +301,55 @@ func (s *SPVCon) AskForHeaders() error { return nil } +func (s *SPVCon) IngestMerkleBlock(m *wire.MsgMerkleBlock) error { + txids, err := checkMBlock(m) // check self-consistency + if err != nil { + return err + } + var hah HashAndHeight + select { + case hah = <-s.mBlockQueue: // pop height off mblock queue + // not super comfortable with this but it seems to work. + if len(s.mBlockQueue) == 0 { // done and fully sync'd + s.inWaitState <- true + } + break + default: + return fmt.Errorf("Unrequested merkle block") + } + + // this verifies order, and also that the returned header fits + // into our SPV header file + newMerkBlockSha := m.Header.BlockSha() + if !hah.blockhash.IsEqual(&newMerkBlockSha) { + return fmt.Errorf("merkle block out of order error") + } + + for _, txid := range txids { + err := s.TS.AddTxid(txid, hah.height) + if err != nil { + return fmt.Errorf("Txid store error: %s\n", err.Error()) + } + } + + err = s.TS.SetDBSyncHeight(hah.height) + if err != nil { + return err + } + + return nil +} + +// IngestHeaders takes in a bunch of headers and appends them to the +// local header file, checking that they fit. If there's no headers, +// it assumes we're done and returns false. If it worked it assumes there's +// more to request and returns true. func (s *SPVCon) IngestHeaders(m *wire.MsgHeaders) (bool, error) { + s.headerMutex.Lock() + defer s.headerMutex.Unlock() + var err error + // seek to last header _, err = s.headerFile.Seek(-80, os.SEEK_END) if err != nil { return false, err @@ -207,6 +361,12 @@ func (s *SPVCon) IngestHeaders(m *wire.MsgHeaders) (bool, error) { } prevHash := last.BlockSha() + endPos, err := s.headerFile.Seek(0, os.SEEK_END) + if err != nil { + return false, err + } + tip := int32(endPos/80) - 1 // move back 1 header length to read + gotNum := int64(len(m.Headers)) if gotNum > 0 { fmt.Printf("got %d headers. Range:\n%s - %s\n", @@ -216,17 +376,11 @@ func (s *SPVCon) IngestHeaders(m *wire.MsgHeaders) (bool, error) { log.Printf("got 0 headers, we're probably synced up") return false, nil } - - endPos, err := s.headerFile.Seek(0, os.SEEK_END) - if err != nil { - return false, err - } - // check first header returned to make sure it fits on the end // of our header file if !m.Headers[0].PrevBlock.IsEqual(&prevHash) { // delete 100 headers if this happens! Dumb reorg. - log.Printf("possible reorg; header msg doesn't fit. points to %s, expect %s", + log.Printf("reorg? header msg doesn't fit. points to %s, expect %s", m.Headers[0].PrevBlock.String(), prevHash.String()) if endPos < 8080 { // jeez I give up, back to genesis @@ -237,22 +391,19 @@ func (s *SPVCon) IngestHeaders(m *wire.MsgHeaders) (bool, error) { return false, fmt.Errorf("couldn't truncate header file") } } - return false, fmt.Errorf("Truncated header file to try again") + return true, fmt.Errorf("Truncated header file to try again") } - tip := endPos / 80 - tip-- // move back header length so it can read last header for _, resphdr := range m.Headers { // write to end of file err = resphdr.Serialize(s.headerFile) if err != nil { return false, err } - // advance chain tip tip++ // check last header - worked := CheckHeader(s.headerFile, tip, params) + worked := CheckHeader(s.headerFile, tip, s.TS.Param) if !worked { if endPos < 8080 { // jeez I give up, back to genesis @@ -265,7 +416,7 @@ func (s *SPVCon) IngestHeaders(m *wire.MsgHeaders) (bool, error) { } // probably should disconnect from spv node at this point, // since they're giving us invalid headers. - return false, fmt.Errorf( + return true, fmt.Errorf( "Header %d - %s doesn't fit, dropping 100 headers.", resphdr.BlockSha().String(), tip) } @@ -274,40 +425,119 @@ func (s *SPVCon) IngestHeaders(m *wire.MsgHeaders) (bool, error) { return true, nil } -func (s *SPVCon) AskForMerkBlocks(current, last uint32) error { - var hdr wire.BlockHeader - _, err := s.headerFile.Seek(int64(current*80), os.SEEK_SET) +// HashAndHeight is needed instead of just height in case a fullnode +// responds abnormally (?) by sending out of order merkleblocks. +// we cache a merkleroot:height pair in the queue so we don't have to +// look them up from the disk. +// Also used when inv messages indicate blocks so we can add the header +// and parse the txs in one request instead of requesting headers first. +type HashAndHeight struct { + blockhash wire.ShaHash + height int32 +} + +// NewRootAndHeight saves like 2 lines. +func NewRootAndHeight(b wire.ShaHash, h int32) (hah HashAndHeight) { + hah.blockhash = b + hah.height = h + return +} + +func (s *SPVCon) PushTx(tx *wire.MsgTx) error { + txid := tx.TxSha() + err := s.TS.AddTxid(&txid, 0) if err != nil { return err } - for current < last { + _, err = s.TS.Ingest(tx) // our own tx so don't need to track relevance + if err != nil { + return err + } + s.outMsgQueue <- tx + return nil +} + +func (s *SPVCon) GetNextHeaderHeight() (int32, error) { + s.headerMutex.Lock() + defer s.headerMutex.Unlock() + info, err := s.headerFile.Stat() // get + if err != nil { + return 0, err // crash if header file disappears + } + nextHeight := int32(info.Size() / 80) + return nextHeight, nil +} + +func (s *SPVCon) RemoveHeaders(r int32) error { + endPos, err := s.headerFile.Seek(0, os.SEEK_END) + if err != nil { + return err + } + err = s.headerFile.Truncate(endPos - int64(r*80)) + if err != nil { + return fmt.Errorf("couldn't truncate header file") + } + return nil +} + +// AskForMerkBlocks requests blocks from current to last +// right now this asks for 1 block per getData message. +// Maybe it's faster to ask for many in a each message? +func (s *SPVCon) AskForMerkBlocks(current, last int32) error { + var hdr wire.BlockHeader + + nextHeight, err := s.GetNextHeaderHeight() + if err != nil { + return err + } + + fmt.Printf("have headers up to height %d\n", nextHeight-1) + // if last is 0, that means go as far as we can + if last < current { + return fmt.Errorf("MBlock range %d < %d\n", last, current) + } + fmt.Printf("will request merkleblocks %d to %d\n", current, last) + + // create initial filter + filt, err := s.TS.GimmeFilter() + if err != nil { + return err + } + // send filter + s.SendFilter(filt) + fmt.Printf("sent filter %x\n", filt.MsgFilterLoad().Filter) + s.headerMutex.Lock() + defer s.headerMutex.Unlock() + + _, err = s.headerFile.Seek(int64((current-1)*80), os.SEEK_SET) + if err != nil { + return err + } + // loop through all heights where we want merkleblocks. + for current <= last { + // load header from file err = hdr.Deserialize(s.headerFile) if err != nil { + log.Printf("Deserialize err\n") return err } - current++ bHash := hdr.BlockSha() + // create inventory we're asking for iv1 := wire.NewInvVect(wire.InvTypeFilteredBlock, &bHash) gdataMsg := wire.NewMsgGetData() + // add inventory err = gdataMsg.AddInvVect(iv1) if err != nil { return err } + hah := NewRootAndHeight(hdr.BlockSha(), current) s.outMsgQueue <- gdataMsg + s.mBlockQueue <- hah // push height and mroot of requested block on queue + current++ } - - return nil -} - -func sendMBReq(cn net.Conn, blkhash wire.ShaHash) error { - iv1 := wire.NewInvVect(wire.InvTypeFilteredBlock, &blkhash) - gdataB := wire.NewMsgGetData() - _ = gdataB.AddInvVect(iv1) - n, err := wire.WriteMessageN(cn, gdataB, VERSION, NETVERSION) - if err != nil { - return err - } - log.Printf("sent %d byte block request\n", n) + // done syncing blocks known in header file, ask for new headers we missed + // s.AskForHeaders() + // don't need this -- will sync to end regardless return nil } diff --git a/uspv/header.go b/uspv/header.go index 6f9c18f9..40983e36 100644 --- a/uspv/header.go +++ b/uspv/header.go @@ -23,7 +23,7 @@ import ( const ( targetTimespan = time.Hour * 24 * 14 targetSpacing = time.Minute * 10 - epochLength = int64(targetTimespan / targetSpacing) + epochLength = int32(targetTimespan / targetSpacing) // 2016 maxDiffAdjust = 4 minRetargetTimespan = int64(targetTimespan / maxDiffAdjust) maxRetargetTimespan = int64(targetTimespan * maxDiffAdjust) @@ -90,7 +90,7 @@ func calcDiffAdjust(start, end wire.BlockHeader, p *chaincfg.Params) uint32 { return blockchain.BigToCompact(newTarget) } -func CheckHeader(r io.ReadSeeker, height int64, p *chaincfg.Params) bool { +func CheckHeader(r io.ReadSeeker, height int32, p *chaincfg.Params) bool { var err error var cur, prev, epochStart wire.BlockHeader // don't try to verfy the genesis block. That way madness lies. @@ -100,7 +100,7 @@ func CheckHeader(r io.ReadSeeker, height int64, p *chaincfg.Params) bool { // initial load of headers // load epochstart, previous and current. // get the header from the epoch start, up to 2016 blocks ago - _, err = r.Seek(80*(height-(height%epochLength)), os.SEEK_SET) + _, err = r.Seek(int64(80*(height-(height%epochLength))), os.SEEK_SET) if err != nil { log.Printf(err.Error()) return false @@ -113,7 +113,7 @@ func CheckHeader(r io.ReadSeeker, height int64, p *chaincfg.Params) bool { // log.Printf("start epoch at height %d ", height-(height%epochLength)) // seek to n-1 header - _, err = r.Seek(80*(height-1), os.SEEK_SET) + _, err = r.Seek(int64(80*(height-1)), os.SEEK_SET) if err != nil { log.Printf(err.Error()) return false @@ -125,7 +125,7 @@ func CheckHeader(r io.ReadSeeker, height int64, p *chaincfg.Params) bool { return false } // seek to curHeight header and read in - _, err = r.Seek(80*(height), os.SEEK_SET) + _, err = r.Seek(int64(80*(height)), os.SEEK_SET) if err != nil { log.Printf(err.Error()) return false @@ -184,7 +184,7 @@ difficulty adjustments, and that they all link in to each other properly. This is the only blockchain technology in the whole code base. Returns false if anything bad happens. Returns true if the range checks out with no errors. */ -func CheckRange(r io.ReadSeeker, first, last int64, p *chaincfg.Params) bool { +func CheckRange(r io.ReadSeeker, first, last int32, p *chaincfg.Params) bool { for i := first; i <= last; i++ { if !CheckHeader(r, i, p) { return false diff --git a/uspv/keyfileio.go b/uspv/keyfileio.go new file mode 100644 index 00000000..ec359edf --- /dev/null +++ b/uspv/keyfileio.go @@ -0,0 +1,192 @@ +package uspv + +import ( + "crypto/rand" + "encoding/hex" + "fmt" + "io/ioutil" + "os" + "strings" + + "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btcutil/hdkeychain" + "github.com/howeyc/gopass" + "golang.org/x/crypto/nacl/secretbox" + "golang.org/x/crypto/scrypt" +) + +// warning! look at those imports! crypto! hopefully this works! + +/* on-disk stored keys are 32bytes. This is good for ed25519 private keys, +for seeds for bip32, for individual secp256k1 priv keys, and so on. +32 bytes is enough for anyone. +If you want fewer bytes, put some zeroes at the end */ + +// LoadKeyFromFileInteractive opens the file 'filename' and presents a +// keyboard prompt for the passphrase to decrypt it. It returns the +// key if decryption works, or errors out. +func LoadKeyFromFileInteractive(filename string) (*[32]byte, error) { + a, err := os.Stat(filename) + if err != nil { + return new([32]byte), err + } + if a.Size() < 80 { // there can't be a password... + return LoadKeyFromFileArg(filename, nil) + } + fmt.Printf("passphrase: ") + pass := gopass.GetPasswd() + fmt.Printf("\n") + return LoadKeyFromFileArg(filename, pass) +} + +// LoadKeyFromFileArg opens the file and returns the key. If the key is +// unencrypted it will ignore the password argument. +func LoadKeyFromFileArg(filename string, pass []byte) (*[32]byte, error) { + priv32 := new([32]byte) + keyhex, err := ioutil.ReadFile(filename) + if err != nil { + return priv32, err + } + keyhex = []byte(strings.TrimSpace(string(keyhex))) + enckey, err := hex.DecodeString(string(keyhex)) + if err != nil { + return priv32, err + } + + if len(enckey) == 32 { // UNencrypted key, length 32 + fmt.Printf("WARNING!! Key file not encrypted!!\n") + fmt.Printf("Anyone who can read the key file can take everything!\n") + fmt.Printf("You should start over and use a good passphrase!\n") + copy(priv32[:], enckey[:]) + return priv32, nil + } + // enckey should be 72 bytes. 24 for scrypt salt/box nonce, + // 16 for box auth + if len(enckey) != 72 { + return priv32, fmt.Errorf("Key length error for %s ", filename) + } + // enckey is actually encrypted, get derived key from pass and salt + // first extract salt + salt := new([24]byte) // salt (also nonce for secretbox) + dk32 := new([32]byte) // derived key array + copy(salt[:], enckey[:24]) // first 24 bytes are scrypt salt/box nonce + + dk, err := scrypt.Key(pass, salt[:], 16384, 8, 1, 32) // derive key + if err != nil { + return priv32, err + } + copy(dk32[:], dk[:]) // copy into fixed size array + + // nonce for secretbox is the same as scrypt salt. Seems fine. Really. + priv, worked := secretbox.Open(nil, enckey[24:], salt, dk32) + if worked != true { + return priv32, fmt.Errorf("Decryption failed for %s ", filename) + } + copy(priv32[:], priv[:]) //copy decrypted private key into array + + priv = nil // this probably doesn't do anything but... eh why not + return priv32, nil +} + +// saves a 32 byte key to file, prompting for passphrase. +// if user enters empty passphrase (hits enter twice), will be saved +// in the clear. +func SaveKeyToFileInteractive(filename string, priv32 *[32]byte) error { + var match bool + var pass1, pass2 []byte + for match != true { + fmt.Printf("passphrase: ") + pass1 = gopass.GetPasswd() + fmt.Printf("repeat passphrase: ") + pass2 = gopass.GetPasswd() + if string(pass1) == string(pass2) { + match = true + } else { + fmt.Printf("user input error. Try again gl hf dd.\n") + } + } + fmt.Printf("\n") + return SaveKeyToFileArg(filename, priv32, pass1) +} + +// saves a 32 byte key to a file, encrypting with pass. +// if pass is nil or zero length, doesn't encrypt and just saves in hex. +func SaveKeyToFileArg(filename string, priv32 *[32]byte, pass []byte) error { + if len(pass) == 0 { // zero-length pass, save unencrypted + keyhex := fmt.Sprintf("%x\n", priv32[:]) + err := ioutil.WriteFile(filename, []byte(keyhex), 0600) + if err != nil { + return err + } + fmt.Printf("WARNING!! Key file not encrypted!!\n") + fmt.Printf("Anyone who can read the key file can take everything!\n") + fmt.Printf("You should start over and use a good passphrase!\n") + fmt.Printf("Saved unencrypted key at %s\n", filename) + return nil + } + + salt := new([24]byte) // salt for scrypt / nonce for secretbox + dk32 := new([32]byte) // derived key from scrypt + + //get 24 random bytes for scrypt salt (and secretbox nonce) + _, err := rand.Read(salt[:]) + if err != nil { + return err + } + // next use the pass and salt to make a 32-byte derived key + dk, err := scrypt.Key(pass, salt[:], 16384, 8, 1, 32) + if err != nil { + return err + } + copy(dk32[:], dk[:]) + + enckey := append(salt[:], secretbox.Seal(nil, priv32[:], salt, dk32)...) + // enckey = append(salt, enckey...) + keyhex := fmt.Sprintf("%x\n", enckey) + + err = ioutil.WriteFile(filename, []byte(keyhex), 0600) + if err != nil { + return err + } + fmt.Printf("Wrote encrypted key to %s\n", filename) + return nil +} + +// ReadKeyFileToECPriv returns an extendedkey from a file. +// If there's no file there, it'll make one. If there's a password needed, +// it'll prompt for one. One stop function. +func ReadKeyFileToECPriv( + filename string, p *chaincfg.Params) (*hdkeychain.ExtendedKey, error) { + key32 := new([32]byte) + _, err := os.Stat(filename) + if err != nil { + if os.IsNotExist(err) { + // no key found, generate and save one + fmt.Printf("No file %s, generating.\n", filename) + rn, err := hdkeychain.GenerateSeed(32) + if err != nil { + return nil, err + } + copy(key32[:], rn[:]) + err = SaveKeyToFileInteractive(filename, key32) + if err != nil { + return nil, err + } + } else { + // unknown error, crash + fmt.Printf("unknown\n") + return nil, err + } + } + + key, err := LoadKeyFromFileInteractive(filename) + if err != nil { + return nil, err + } + + rootpriv, err := hdkeychain.NewMaster(key[:], p) + if err != nil { + return nil, err + } + return rootpriv, nil +} diff --git a/uspv/mblock.go b/uspv/mblock.go index dc71608c..823ec805 100644 --- a/uspv/mblock.go +++ b/uspv/mblock.go @@ -65,9 +65,8 @@ func inDeadZone(pos, size uint32) bool { return pos > last } -// take in a merkle block, parse through it, and return the -// txids that they're trying to tell us about. If there's any problem -// return an error. +// take in a merkle block, parse through it, and return txids indicated +// If there's any problem return an error. Checks self-consistency only. // doing it with a stack instead of recursion. Because... // OK I don't know why I'm just not in to recursion OK? func checkMBlock(m *wire.MsgMerkleBlock) ([]*wire.ShaHash, error) { diff --git a/uspv/msghandler.go b/uspv/msghandler.go index 76e7f1b9..926ae5da 100644 --- a/uspv/msghandler.go +++ b/uspv/msghandler.go @@ -3,13 +3,14 @@ package uspv import ( "fmt" "log" + "os" "github.com/btcsuite/btcd/wire" ) func (s *SPVCon) incomingMessageHandler() { for { - n, xm, _, err := wire.ReadMessageN(s.con, s.localVersion, s.netType) + n, xm, _, err := wire.ReadMessageN(s.con, s.localVersion, s.TS.Param.Net) if err != nil { log.Printf("ReadMessageN error. Disconnecting: %s\n", err.Error()) return @@ -26,47 +27,31 @@ func (s *SPVCon) incomingMessageHandler() { case *wire.MsgAddr: log.Printf("got %d addresses.\n", len(m.AddrList)) case *wire.MsgPing: - log.Printf("Got a ping message. We should pong back or they will kick us off.") - s.PongBack(m.Nonce) + // log.Printf("Got a ping message. We should pong back or they will kick us off.") + go s.PongBack(m.Nonce) case *wire.MsgPong: log.Printf("Got a pong response. OK.\n") case *wire.MsgMerkleBlock: - - // log.Printf("Got merkle block message. Will verify.\n") - // fmt.Printf("%d flag bytes, %d txs, %d hashes", - // len(m.Flags), m.Transactions, len(m.Hashes)) - txids, err := checkMBlock(m) + err = s.IngestMerkleBlock(m) if err != nil { log.Printf("Merkle block error: %s\n", err.Error()) - return - // continue + continue } - fmt.Printf(" got %d txs ", len(txids)) - // fmt.Printf(" = got %d txs from block %s\n", - // len(txids), m.Header.BlockSha().String()) - for _, txid := range txids { - err := s.TS.AddTxid(txid) - if err != nil { - log.Printf("Txid store error: %s\n", err.Error()) - } + case *wire.MsgHeaders: // concurrent because we keep asking for blocks + go s.HeaderHandler(m) + case *wire.MsgTx: // not concurrent! txs must be in order + s.TxHandler(m) + case *wire.MsgReject: + log.Printf("Rejected! cmd: %s code: %s tx: %s reason: %s", + m.Cmd, m.Code.String(), m.Hash.String(), m.Reason) + case *wire.MsgInv: + s.InvHandler(m) + case *wire.MsgNotFound: + log.Printf("Got not found response from remote:") + for i, thing := range m.InvList { + log.Printf("\t$d) %s: %s", i, thing.Type, thing.Hash) } - // nextReq <- true - case *wire.MsgHeaders: - moar, err := s.IngestHeaders(m) - if err != nil { - log.Printf("Header error: %s\n", err.Error()) - return - } - if moar { - s.AskForHeaders() - } - case *wire.MsgTx: - err := s.TS.IngestTx(m) - if err != nil { - log.Printf("Incoming Tx error: %s\n", err.Error()) - } - // log.Printf("Got tx %s\n", m.TxSha().String()) default: log.Printf("Got unknown message type %s\n", m.Command()) } @@ -79,7 +64,7 @@ func (s *SPVCon) incomingMessageHandler() { func (s *SPVCon) outgoingMessageHandler() { for { msg := <-s.outMsgQueue - n, err := wire.WriteMessageN(s.con, msg, s.localVersion, s.netType) + n, err := wire.WriteMessageN(s.con, msg, s.localVersion, s.TS.Param.Net) if err != nil { log.Printf("Write message error: %s", err.Error()) } @@ -87,3 +72,109 @@ func (s *SPVCon) outgoingMessageHandler() { } return } + +// fPositiveHandler monitors false positives and when it gets enough of them, +// +func (s *SPVCon) fPositiveHandler() { + var fpAccumulator int32 + for { + fpAccumulator += <-s.fPositives // blocks here + if fpAccumulator > 7 { + filt, err := s.TS.GimmeFilter() + if err != nil { + log.Printf("Filter creation error: %s\n", err.Error()) + log.Printf("uhoh, crashing filter handler") + return + } + + // send filter + s.SendFilter(filt) + fmt.Printf("sent filter %x\n", filt.MsgFilterLoad().Filter) + // clear the channel + for len(s.fPositives) != 0 { + fpAccumulator += <-s.fPositives + } + fmt.Printf("reset %d false positives\n", fpAccumulator) + // reset accumulator + fpAccumulator = 0 + } + } +} + +func (s *SPVCon) HeaderHandler(m *wire.MsgHeaders) { + moar, err := s.IngestHeaders(m) + if err != nil { + log.Printf("Header error: %s\n", err.Error()) + return + } + // if we got post DB syncheight headers, get merkleblocks for them + // this is always true except for first pre-birthday sync + + // checked header length, start req for more if needed + if moar { + s.AskForHeaders() + } else { // no moar, done w/ headers, get merkleblocks + s.headerMutex.Lock() + endPos, err := s.headerFile.Seek(0, os.SEEK_END) + if err != nil { + log.Printf("Header error: %s", err.Error()) + return + } + s.headerMutex.Unlock() + tip := int32(endPos/80) - 1 // move back 1 header length to read + syncTip, err := s.TS.GetDBSyncHeight() + if err != nil { + log.Printf("syncTip error: %s", err.Error()) + return + } + if syncTip < tip { + fmt.Printf("syncTip %d headerTip %d\n", syncTip, tip) + err = s.AskForMerkBlocks(syncTip+1, tip) + if err != nil { + log.Printf("AskForMerkBlocks error: %s", err.Error()) + return + } + } + } +} + +func (s *SPVCon) TxHandler(m *wire.MsgTx) { + hits, err := s.TS.Ingest(m) + if err != nil { + log.Printf("Incoming Tx error: %s\n", err.Error()) + } + if hits == 0 { + log.Printf("tx %s had no hits, filter false positive.", + m.TxSha().String()) + s.fPositives <- 1 // add one false positive to chan + } else { + log.Printf("tx %s ingested and matches %d utxo/adrs.", + m.TxSha().String(), hits) + } +} + +func (s *SPVCon) InvHandler(m *wire.MsgInv) { + log.Printf("got inv. Contains:\n") + for i, thing := range m.InvList { + log.Printf("\t%d)%s : %s", + i, thing.Type.String(), thing.Hash.String()) + if thing.Type == wire.InvTypeTx { // new tx, ingest + s.TS.OKTxids[thing.Hash] = 0 // unconfirmed + s.AskForTx(thing.Hash) + } + if thing.Type == wire.InvTypeBlock { // new block what to do? + select { + case <-s.inWaitState: + // start getting headers + fmt.Printf("asking for headers due to inv block\n") + err := s.AskForHeaders() + if err != nil { + log.Printf("AskForHeaders error: %s", err.Error()) + } + default: + // drop it as if its component particles had high thermal energies + fmt.Printf("inv block but ignoring; not synched\n") + } + } + } +} diff --git a/uspv/sortsignsend.go b/uspv/sortsignsend.go new file mode 100644 index 00000000..4aa9bc26 --- /dev/null +++ b/uspv/sortsignsend.go @@ -0,0 +1,56 @@ +package uspv + +import ( + "bytes" + "fmt" + + "github.com/btcsuite/btcd/txscript" + "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btcutil/hdkeychain" + "github.com/btcsuite/btcutil/txsort" +) + +func (t *TxStore) SignThis(tx *wire.MsgTx) error { + fmt.Printf("-= SignThis =-\n") + + // sort tx before signing. + txsort.InPlaceSort(tx) + + sigs := make([][]byte, len(tx.TxIn)) + // first iterate over each input + for j, in := range tx.TxIn { + for k := uint32(0); k < uint32(len(t.Adrs)); k++ { + child, err := t.rootPrivKey.Child(k + hdkeychain.HardenedKeyStart) + if err != nil { + return err + } + myadr, err := child.Address(t.Param) + if err != nil { + return err + } + adrScript, err := txscript.PayToAddrScript(myadr) + if err != nil { + return err + } + if bytes.Equal(adrScript, in.SignatureScript) { + fmt.Printf("Hit; key %d matches input %d. Signing.\n", k, j) + priv, err := child.ECPrivKey() + if err != nil { + return err + } + sigs[j], err = txscript.SignatureScript( + tx, j, in.SignatureScript, txscript.SigHashAll, priv, true) + if err != nil { + return err + } + break + } + } + } + for i, s := range sigs { + if s != nil { + tx.TxIn[i].SignatureScript = s + } + } + return nil +} diff --git a/uspv/txstore.go b/uspv/txstore.go index 806966a1..a5062a3b 100644 --- a/uspv/txstore.go +++ b/uspv/txstore.go @@ -2,49 +2,72 @@ package uspv import ( "bytes" + "encoding/binary" "fmt" "log" + "github.com/btcsuite/btcd/chaincfg" + + "github.com/boltdb/bolt" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" "github.com/btcsuite/btcutil/bloom" + "github.com/btcsuite/btcutil/hdkeychain" ) type TxStore struct { - KnownTxids []*wire.ShaHash - Utxos []Utxo // stacks on stacks - Sum int64 // racks on racks - Adrs []MyAdr // endeavouring to acquire capital + OKTxids map[wire.ShaHash]int32 // known good txids and their heights + + Adrs []MyAdr // endeavouring to acquire capital + StateDB *bolt.DB // place to write all this down + + // Params live here, not SCon + Param *chaincfg.Params // network parameters (testnet3, testnetL) + + // From here, comes everything. It's a secret to everybody. + rootPrivKey *hdkeychain.ExtendedKey } type Utxo struct { // cash money. - // combo of outpoint and txout which has all the info needed to spend - Op wire.OutPoint - Txo wire.TxOut - AtHeight uint32 // block height where this tx was confirmed, 0 for unconf + Op wire.OutPoint // where + + // all the info needed to spend + AtHeight int32 // block height where this tx was confirmed, 0 for unconf KeyIdx uint32 // index for private key needed to sign / spend + Value int64 // higher is better + + // IsCoinbase bool // can't spend for a while +} + +// Stxo is a utxo that has moved on. +type Stxo struct { + Utxo // when it used to be a utxo + SpendHeight int32 // height at which it met its demise + SpendTxid wire.ShaHash // the tx that consumed it } type MyAdr struct { // an address I have the private key for - btcutil.Address + PkhAdr btcutil.Address KeyIdx uint32 // index for private key needed to sign / spend + // ^^ this is kindof redundant because it'll just be their position + // inside the Adrs slice, right? leave for now } -// add addresses into the TxStore -func (t *TxStore) AddAdr(a btcutil.Address, kidx uint32) { - var ma MyAdr - ma.Address = a - ma.KeyIdx = kidx - t.Adrs = append(t.Adrs, ma) - return +func NewTxStore(rootkey *hdkeychain.ExtendedKey, p *chaincfg.Params) TxStore { + var txs TxStore + txs.rootPrivKey = rootkey + txs.Param = p + txs.OKTxids = make(map[wire.ShaHash]int32) + return txs } // add txid of interest -func (t *TxStore) AddTxid(txid *wire.ShaHash) error { +func (t *TxStore) AddTxid(txid *wire.ShaHash, height int32) error { if txid == nil { return fmt.Errorf("tried to add nil txid") } - t.KnownTxids = append(t.KnownTxids, txid) + log.Printf("added %s to OKTxids at height %d\n", txid.String(), height) + t.OKTxids[*txid] = height return nil } @@ -53,93 +76,235 @@ func (t *TxStore) GimmeFilter() (*bloom.Filter, error) { if len(t.Adrs) == 0 { return nil, fmt.Errorf("no addresses to filter for") } - f := bloom.NewFilter(uint32(len(t.Adrs)), 0, 0.001, wire.BloomUpdateNone) - for _, a := range t.Adrs { - f.Add(a.ScriptAddress()) + // add addresses to look for incoming + nutxo, err := t.NumUtxos() + if err != nil { + return nil, err } + + elem := uint32(len(t.Adrs)) + nutxo + f := bloom.NewFilter(elem, 0, 0.000001, wire.BloomUpdateAll) + for _, a := range t.Adrs { + f.Add(a.PkhAdr.ScriptAddress()) + } + + // get all utxos to add outpoints to filter + allUtxos, err := t.GetAllUtxos() + if err != nil { + return nil, err + } + + for _, u := range allUtxos { + f.AddOutPoint(&u.Op) + } + return f, nil } -// Ingest a tx into wallet, dealing with both gains and losses -func (t *TxStore) IngestTx(tx *wire.MsgTx) error { - var match bool - inTxid := tx.TxSha() - for _, ktxid := range t.KnownTxids { - if inTxid.IsEqual(ktxid) { - match = true - break // found tx match, +// TxToString prints out some info about a transaction. for testing / debugging +func TxToString(tx *wire.MsgTx) string { + str := "\t\t\t - Tx - \n" + for i, in := range tx.TxIn { + str += fmt.Sprintf("Input %d: %s\n", i, in.PreviousOutPoint.String()) + str += fmt.Sprintf("SigScript for input %d: %x\n", i, in.SignatureScript) + } + for i, out := range tx.TxOut { + if out != nil { + str += fmt.Sprintf("\toutput %d script: %x amt: %d\n", + i, out.PkScript, out.Value) + } else { + str += fmt.Sprintf("output %d nil (WARNING)\n", i) } } - if !match { - return fmt.Errorf("we don't care about tx %s", inTxid.String()) - } + return str +} - err := t.AbsorbTx(tx) +// need this because before I was comparing pointers maybe? +// so they were the same outpoint but stored in 2 places so false negative? +func OutPointsEqual(a, b wire.OutPoint) bool { + if !a.Hash.IsEqual(&b.Hash) { + return false + } + return a.Index == b.Index +} + +/*----- serialization for tx outputs ------- */ + +// outPointToBytes turns an outpoint into 36 bytes. +func outPointToBytes(op *wire.OutPoint) ([]byte, error) { + var buf bytes.Buffer + _, err := buf.Write(op.Hash.Bytes()) if err != nil { - return err + return nil, err } - err = t.ExpellTx(tx) + // write 4 byte outpoint index within the tx to spend + err = binary.Write(&buf, binary.BigEndian, op.Index) if err != nil { - return err + return nil, err } - // fmt.Printf("ingested tx %s total amt %d\n", inTxid.String(), t.Sum) - return nil + return buf.Bytes(), nil } -// Absorb money into wallet from a tx -func (t *TxStore) AbsorbTx(tx *wire.MsgTx) error { - if tx == nil { - return fmt.Errorf("Tried to add nil tx") +// ToBytes turns a Utxo into some bytes. +// note that the txid is the first 36 bytes and in our use cases will be stripped +// off, but is left here for other applications +func (u *Utxo) ToBytes() ([]byte, error) { + var buf bytes.Buffer + // write 32 byte txid of the utxo + _, err := buf.Write(u.Op.Hash.Bytes()) + if err != nil { + return nil, err } - var hits uint32 - var acq int64 - // check if any of the tx's outputs match my adrs - for i, out := range tx.TxOut { // in each output of tx - for _, a := range t.Adrs { // compare to each adr we have - // more correct would be to check for full script - // contains could have false positive? (p2sh/p2pkh same hash ..?) - if bytes.Contains(out.PkScript, a.ScriptAddress()) { // hit - hits++ - acq += out.Value - var newu Utxo - newu.KeyIdx = a.KeyIdx - newu.Txo = *out - - var newop wire.OutPoint - newop.Hash = tx.TxSha() - newop.Index = uint32(i) - newu.Op = newop - - t.Utxos = append(t.Utxos, newu) - break - } - } + // write 4 byte outpoint index within the tx to spend + err = binary.Write(&buf, binary.BigEndian, u.Op.Index) + if err != nil { + return nil, err } - log.Printf("%d hits, acquired %d", hits, acq) - t.Sum += acq - return nil + // write 4 byte height of utxo + err = binary.Write(&buf, binary.BigEndian, u.AtHeight) + if err != nil { + return nil, err + } + // write 4 byte key index of utxo + err = binary.Write(&buf, binary.BigEndian, u.KeyIdx) + if err != nil { + return nil, err + } + // write 8 byte amount of money at the utxo + err = binary.Write(&buf, binary.BigEndian, u.Value) + if err != nil { + return nil, err + } + + return buf.Bytes(), nil } -// Expell money from wallet due to a tx -func (t *TxStore) ExpellTx(tx *wire.MsgTx) error { - if tx == nil { - return fmt.Errorf("Tried to add nil tx") +// UtxoFromBytes turns bytes into a Utxo. Note it wants the txid and outindex +// in the first 36 bytes, which isn't stored that way in the boldDB, +// but can be easily appended. +func UtxoFromBytes(b []byte) (Utxo, error) { + var u Utxo + if b == nil { + return u, fmt.Errorf("nil input slice") } - var hits uint32 - var loss int64 - - for _, in := range tx.TxIn { - for i, myutxo := range t.Utxos { - if myutxo.Op == in.PreviousOutPoint { - - hits++ - loss += myutxo.Txo.Value - // delete from my utxo set - t.Utxos = append(t.Utxos[:i], t.Utxos[i+1:]...) - } - } + buf := bytes.NewBuffer(b) + if buf.Len() < 52 { // utxos are 52 bytes + return u, fmt.Errorf("Got %d bytes for utxo, expect 52", buf.Len()) } - log.Printf("%d hits, lost %d", hits, loss) - t.Sum -= loss - return nil + // read 32 byte txid + err := u.Op.Hash.SetBytes(buf.Next(32)) + if err != nil { + return u, err + } + // read 4 byte outpoint index within the tx to spend + err = binary.Read(buf, binary.BigEndian, &u.Op.Index) + if err != nil { + return u, err + } + // read 4 byte height of utxo + err = binary.Read(buf, binary.BigEndian, &u.AtHeight) + if err != nil { + return u, err + } + // read 4 byte key index of utxo + err = binary.Read(buf, binary.BigEndian, &u.KeyIdx) + if err != nil { + return u, err + } + // read 8 byte amount of money at the utxo + err = binary.Read(buf, binary.BigEndian, &u.Value) + if err != nil { + return u, err + } + return u, nil +} + +// ToBytes turns an Stxo into some bytes. +// outpoint txid, outpoint idx, height, key idx, amt, spendheight, spendtxid +func (s *Stxo) ToBytes() ([]byte, error) { + var buf bytes.Buffer + // write 32 byte txid of the utxo + _, err := buf.Write(s.Op.Hash.Bytes()) + if err != nil { + return nil, err + } + // write 4 byte outpoint index within the tx to spend + err = binary.Write(&buf, binary.BigEndian, s.Op.Index) + if err != nil { + return nil, err + } + // write 4 byte height of utxo + err = binary.Write(&buf, binary.BigEndian, s.AtHeight) + if err != nil { + return nil, err + } + // write 4 byte key index of utxo + err = binary.Write(&buf, binary.BigEndian, s.KeyIdx) + if err != nil { + return nil, err + } + // write 8 byte amount of money at the utxo + err = binary.Write(&buf, binary.BigEndian, s.Value) + if err != nil { + return nil, err + } + // write 4 byte height where the txo was spent + err = binary.Write(&buf, binary.BigEndian, s.SpendHeight) + if err != nil { + return nil, err + } + // write 32 byte txid of the spending transaction + _, err = buf.Write(s.SpendTxid.Bytes()) + if err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +// StxoFromBytes turns bytes into a Stxo. +func StxoFromBytes(b []byte) (Stxo, error) { + var s Stxo + if b == nil { + return s, fmt.Errorf("nil input slice") + } + buf := bytes.NewBuffer(b) + if buf.Len() < 88 { // stxos are 88 bytes + return s, fmt.Errorf("Got %d bytes for stxo, expect 88", buf.Len()) + } + // read 32 byte txid + err := s.Op.Hash.SetBytes(buf.Next(32)) + if err != nil { + return s, err + } + // read 4 byte outpoint index within the tx to spend + err = binary.Read(buf, binary.BigEndian, &s.Op.Index) + if err != nil { + return s, err + } + // read 4 byte height of utxo + err = binary.Read(buf, binary.BigEndian, &s.AtHeight) + if err != nil { + return s, err + } + // read 4 byte key index of utxo + err = binary.Read(buf, binary.BigEndian, &s.KeyIdx) + if err != nil { + return s, err + } + // read 8 byte amount of money at the utxo + err = binary.Read(buf, binary.BigEndian, &s.Value) + if err != nil { + return s, err + } + // read 4 byte spend height + err = binary.Read(buf, binary.BigEndian, &s.SpendHeight) + if err != nil { + return s, err + } + // read 32 byte txid + err = s.SpendTxid.SetBytes(buf.Next(32)) + if err != nil { + return s, err + } + return s, nil } diff --git a/uspv/utxodb.go b/uspv/utxodb.go new file mode 100644 index 00000000..28475470 --- /dev/null +++ b/uspv/utxodb.go @@ -0,0 +1,363 @@ +package uspv + +import ( + "bytes" + "encoding/binary" + "fmt" + + "github.com/btcsuite/btcd/blockchain" + + "github.com/btcsuite/btcd/txscript" + + "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btcutil" + "github.com/btcsuite/btcutil/hdkeychain" + + "github.com/boltdb/bolt" +) + +var ( + BKTUtxos = []byte("DuffelBag") // leave the rest to collect interest + BKTStxos = []byte("SpentTxs") // for bookkeeping + BKTTxns = []byte("Txns") // all txs we care about, for replays + BKTState = []byte("MiscState") // last state of DB + // these are in the state bucket + KEYNumKeys = []byte("NumKeys") // number of keys used + KEYTipHeight = []byte("TipHeight") // height synced to +) + +func (ts *TxStore) OpenDB(filename string) error { + var err error + var numKeys uint32 + ts.StateDB, err = bolt.Open(filename, 0644, nil) + if err != nil { + return err + } + // create buckets if they're not already there + err = ts.StateDB.Update(func(btx *bolt.Tx) error { + _, err = btx.CreateBucketIfNotExists(BKTUtxos) + if err != nil { + return err + } + _, err = btx.CreateBucketIfNotExists(BKTStxos) + if err != nil { + return err + } + _, err = btx.CreateBucketIfNotExists(BKTTxns) + if err != nil { + return err + } + sta, err := btx.CreateBucketIfNotExists(BKTState) + if err != nil { + return err + } + + numKeysBytes := sta.Get(KEYNumKeys) + if numKeysBytes != nil { // NumKeys exists, read into uint32 + buf := bytes.NewBuffer(numKeysBytes) + err := binary.Read(buf, binary.BigEndian, &numKeys) + if err != nil { + return err + } + fmt.Printf("db says %d keys\n", numKeys) + } else { // no adrs yet, make it 1 (why...?) + numKeys = 1 + var buf bytes.Buffer + err = binary.Write(&buf, binary.BigEndian, numKeys) + if err != nil { + return err + } + err = sta.Put(KEYNumKeys, buf.Bytes()) + if err != nil { + return err + } + } + return nil + }) + if err != nil { + return err + } + return ts.PopulateAdrs(numKeys) +} + +// NewAdr creates a new, never before seen address, and increments the +// DB counter as well as putting it in the ram Adrs store, and returns it +func (ts *TxStore) NewAdr() (*btcutil.AddressPubKeyHash, error) { + if ts.Param == nil { + return nil, fmt.Errorf("nil param") + } + n := uint32(len(ts.Adrs)) + priv, err := ts.rootPrivKey.Child(n + hdkeychain.HardenedKeyStart) + if err != nil { + return nil, err + } + + newAdr, err := priv.Address(ts.Param) + if err != nil { + return nil, err + } + + // total number of keys (now +1) into 4 bytes + var buf bytes.Buffer + err = binary.Write(&buf, binary.BigEndian, n+1) + if err != nil { + return nil, err + } + + // write to db file + err = ts.StateDB.Update(func(btx *bolt.Tx) error { + sta := btx.Bucket(BKTState) + return sta.Put(KEYNumKeys, buf.Bytes()) + }) + if err != nil { + return nil, err + } + // add in to ram. + var ma MyAdr + ma.PkhAdr = newAdr + ma.KeyIdx = n + ts.Adrs = append(ts.Adrs, ma) + + return newAdr, nil +} + +// SetBDay sets the birthday (birth height) of the db (really keyfile) +func (ts *TxStore) SetDBSyncHeight(n int32) error { + var buf bytes.Buffer + _ = binary.Write(&buf, binary.BigEndian, n) + + return ts.StateDB.Update(func(btx *bolt.Tx) error { + sta := btx.Bucket(BKTState) + return sta.Put(KEYTipHeight, buf.Bytes()) + }) +} + +// SyncHeight returns the chain height to which the db has synced +func (ts *TxStore) GetDBSyncHeight() (int32, error) { + var n int32 + err := ts.StateDB.View(func(btx *bolt.Tx) error { + sta := btx.Bucket(BKTState) + if sta == nil { + return fmt.Errorf("no state") + } + t := sta.Get(KEYTipHeight) + + if t == nil { // no height written, so 0 + return nil + } + + // read 4 byte tip height to n + err := binary.Read(bytes.NewBuffer(t), binary.BigEndian, &n) + if err != nil { + return err + } + return nil + }) + if err != nil { + return 0, err + } + return n, nil +} + +// NumUtxos returns the number of utxos in the DB. +func (ts *TxStore) NumUtxos() (uint32, error) { + var n uint32 + err := ts.StateDB.View(func(btx *bolt.Tx) error { + duf := btx.Bucket(BKTUtxos) + if duf == nil { + return fmt.Errorf("no duffel bag") + } + stats := duf.Stats() + n = uint32(stats.KeyN) + return nil + }) + if err != nil { + return 0, err + } + return n, nil +} + +func (ts *TxStore) GetAllUtxos() ([]*Utxo, error) { + var utxos []*Utxo + err := ts.StateDB.View(func(btx *bolt.Tx) error { + duf := btx.Bucket(BKTUtxos) + if duf == nil { + return fmt.Errorf("no duffel bag") + } + return duf.ForEach(func(k, v []byte) error { + // have to copy k and v here, otherwise append will crash it. + // not quite sure why but append does weird stuff I guess. + + // create a new utxo + x := make([]byte, len(k)+len(v)) + copy(x, k) + copy(x[len(k):], v) + newU, err := UtxoFromBytes(x) + if err != nil { + return err + } + // and add it to ram + utxos = append(utxos, &newU) + + return nil + }) + + return nil + }) + if err != nil { + return nil, err + } + return utxos, nil +} + +// PopulateAdrs just puts a bunch of adrs in ram; it doesn't touch the DB +func (ts *TxStore) PopulateAdrs(lastKey uint32) error { + for k := uint32(0); k < lastKey; k++ { + + priv, err := ts.rootPrivKey.Child(k + hdkeychain.HardenedKeyStart) + if err != nil { + return err + } + + newAdr, err := priv.Address(ts.Param) + if err != nil { + return err + } + var ma MyAdr + ma.PkhAdr = newAdr + ma.KeyIdx = k + ts.Adrs = append(ts.Adrs, ma) + + } + return nil +} + +// Ingest puts a tx into the DB atomically. This can result in a +// gain, a loss, or no result. Gain or loss in satoshis is returned. +func (ts *TxStore) Ingest(tx *wire.MsgTx) (uint32, error) { + var hits uint32 + var err error + var spentOPs [][]byte + var nUtxoBytes [][]byte + + // first check that we have a height and tx has been SPV OK'd + inTxid := tx.TxSha() + height, ok := ts.OKTxids[inTxid] + if !ok { + return hits, fmt.Errorf("Ingest error: tx %s not in OKTxids.", + inTxid.String()) + } + + // tx has been OK'd by SPV; check tx sanity + utilTx := btcutil.NewTx(tx) // convert for validation + // checks stuff like inputs >= ouputs + err = blockchain.CheckTransactionSanity(utilTx) + if err != nil { + return hits, err + } + // note that you can't check signatures; this is SPV. + // 0 conf SPV means pretty much nothing. Anyone can say anything. + + // before entering into db, serialize all inputs of the ingested tx + for _, txin := range tx.TxIn { + nOP, err := outPointToBytes(&txin.PreviousOutPoint) + if err != nil { + return hits, err + } + spentOPs = append(spentOPs, nOP) + } + // also generate PKscripts for all addresses (maybe keep storing these?) + for _, adr := range ts.Adrs { + // iterate through all our addresses + aPKscript, err := txscript.PayToAddrScript(adr.PkhAdr) + if err != nil { + return hits, err + } + // iterate through all outputs of this tx + for i, out := range tx.TxOut { + if bytes.Equal(out.PkScript, aPKscript) { // new utxo for us + var newu Utxo + newu.AtHeight = height + newu.KeyIdx = adr.KeyIdx + newu.Value = out.Value + var newop wire.OutPoint + newop.Hash = tx.TxSha() + newop.Index = uint32(i) + newu.Op = newop + b, err := newu.ToBytes() + if err != nil { + return hits, err + } + nUtxoBytes = append(nUtxoBytes, b) + hits++ + break // only one match + } + + } + } + + err = ts.StateDB.Update(func(btx *bolt.Tx) error { + // get all 4 buckets + duf := btx.Bucket(BKTUtxos) + // sta := btx.Bucket(BKTState) + old := btx.Bucket(BKTStxos) + txns := btx.Bucket(BKTTxns) + + // first see if we lose utxos + // iterate through duffel bag and look for matches + // this makes us lose money, which is regrettable, but we need to know. + for _, nOP := range spentOPs { + duf.ForEach(func(k, v []byte) error { + if bytes.Equal(k, nOP) { // matched, we lost utxo + // do all this just to figure out value we lost + x := make([]byte, len(k)+len(v)) + copy(x, k) + copy(x[len(k):], v) + lostTxo, err := UtxoFromBytes(x) + if err != nil { + return err + } + hits++ + // then delete the utxo from duf, save to old + err = duf.Delete(k) + if err != nil { + return err + } + // after deletion, save stxo to old bucket + var st Stxo // generate spent txo + st.Utxo = lostTxo // assign outpoint + st.SpendHeight = height // spent at height + st.SpendTxid = tx.TxSha() // spent by txid + stxb, err := st.ToBytes() // serialize + if err != nil { + return err + } + err = old.Put(k, stxb) // write k:v outpoint:stxo bytes + if err != nil { + return err + } + // store this relevant tx + sha := tx.TxSha() + var buf bytes.Buffer + tx.Serialize(&buf) + err = txns.Put(sha.Bytes(), buf.Bytes()) + if err != nil { + return err + } + + return nil // matched utxo k, won't match another + } + return nil // no match + }) + } // done losing utxos, next gain utxos + // next add all new utxos to db, this is quick as the work is above + for _, ub := range nUtxoBytes { + err = duf.Put(ub[:36], ub[36:]) + if err != nil { + return err + } + } + return nil + }) + return hits, err +}