From 2a904cb69f7ac7803a0834602301eec2bc765aa5 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Thu, 23 May 2019 20:48:08 -0700 Subject: [PATCH] watchtower/wtdb: add Encode/Decode methods to wtclient structs --- channeldb/codec.go | 10 ++ watchtower/wtdb/client_session.go | 73 ++++++++++++ watchtower/wtdb/codec_test.go | 181 +++++++++++++++++++++++++++++- watchtower/wtdb/tower.go | 32 ++++-- 4 files changed, 284 insertions(+), 12 deletions(-) diff --git a/channeldb/codec.go b/channeldb/codec.go index 1da362dd..ec6e165b 100644 --- a/channeldb/codec.go +++ b/channeldb/codec.go @@ -103,6 +103,11 @@ func WriteElement(w io.Writer, element interface{}) error { return err } + case lnwire.ChannelID: + if _, err := w.Write(e[:]); err != nil { + return err + } + case uint64: if err := binary.Write(w, byteOrder, e); err != nil { return err @@ -259,6 +264,11 @@ func ReadElement(r io.Reader, element interface{}) error { } *e = lnwire.NewShortChanIDFromInt(a) + case *lnwire.ChannelID: + if _, err := io.ReadFull(r, e[:]); err != nil { + return err + } + case *uint64: if err := binary.Read(r, byteOrder, e); err != nil { return err diff --git a/watchtower/wtdb/client_session.go b/watchtower/wtdb/client_session.go index 5b2d39d7..ab068683 100644 --- a/watchtower/wtdb/client_session.go +++ b/watchtower/wtdb/client_session.go @@ -2,6 +2,7 @@ package wtdb import ( "errors" + "io" "github.com/btcsuite/btcd/btcec" "github.com/lightningnetwork/lnd/lnwire" @@ -112,8 +113,38 @@ type ClientSessionBody struct { // deposited to if a sweep transaction confirms and the sessions // specifies a reward output. RewardPkScript []byte +} +// Encode writes a ClientSessionBody to the passed io.Writer. +func (s *ClientSessionBody) Encode(w io.Writer) error { + return WriteElements(w, + s.SeqNum, + s.TowerLastApplied, + uint64(s.TowerID), + s.KeyIndex, + s.Policy, + s.RewardPkScript, + ) +} +// Decode reads a ClientSessionBody from the passed io.Reader. +func (s *ClientSessionBody) Decode(r io.Reader) error { + var towerID uint64 + err := ReadElements(r, + &s.SeqNum, + &s.TowerLastApplied, + &towerID, + &s.KeyIndex, + &s.Policy, + &s.RewardPkScript, + ) + if err != nil { + return err + } + + s.TowerID = TowerID(towerID) + + return nil } // BackupID identifies a particular revoked, remote commitment by channel id and @@ -126,6 +157,22 @@ type BackupID struct { CommitHeight uint64 } +// Encode writes the BackupID from the passed io.Writer. +func (b *BackupID) Encode(w io.Writer) error { + return WriteElements(w, + b.ChanID, + b.CommitHeight, + ) +} + +// Decode reads a BackupID from the passed io.Reader. +func (b *BackupID) Decode(r io.Reader) error { + return ReadElements(r, + &b.ChanID, + &b.CommitHeight, + ) +} + // CommittedUpdate holds a state update sent by a client along with its // allocated sequence number and the exact remote commitment the encrypted // justice transaction can rectify. @@ -152,3 +199,29 @@ type CommittedUpdateBody struct { // hint is broadcast. EncryptedBlob []byte } + +// Encode writes the CommittedUpdateBody to the passed io.Writer. +func (u *CommittedUpdateBody) Encode(w io.Writer) error { + err := u.BackupID.Encode(w) + if err != nil { + return err + } + + return WriteElements(w, + u.Hint, + u.EncryptedBlob, + ) +} + +// Decode reads a CommittedUpdateBody from the passed io.Reader. +func (u *CommittedUpdateBody) Decode(r io.Reader) error { + err := u.BackupID.Decode(r) + if err != nil { + return err + } + + return ReadElements(r, + &u.Hint, + &u.EncryptedBlob, + ) +} diff --git a/watchtower/wtdb/codec_test.go b/watchtower/wtdb/codec_test.go index 948ec4ee..21e11c6f 100644 --- a/watchtower/wtdb/codec_test.go +++ b/watchtower/wtdb/codec_test.go @@ -2,14 +2,122 @@ package wtdb_test import ( "bytes" + "encoding/binary" "io" + "math/rand" + "net" "reflect" "testing" "testing/quick" + "github.com/btcsuite/btcd/btcec" + "github.com/lightningnetwork/lnd/tor" "github.com/lightningnetwork/lnd/watchtower/wtdb" ) +func randPubKey() (*btcec.PublicKey, error) { + priv, err := btcec.NewPrivateKey(btcec.S256()) + if err != nil { + return nil, err + } + + return priv.PubKey(), nil +} + +func randTCP4Addr(r *rand.Rand) (*net.TCPAddr, error) { + var ip [4]byte + if _, err := r.Read(ip[:]); err != nil { + return nil, err + } + + var port [2]byte + if _, err := r.Read(port[:]); err != nil { + return nil, err + } + + addrIP := net.IP(ip[:]) + addrPort := int(binary.BigEndian.Uint16(port[:])) + + return &net.TCPAddr{IP: addrIP, Port: addrPort}, nil +} + +func randTCP6Addr(r *rand.Rand) (*net.TCPAddr, error) { + var ip [16]byte + if _, err := r.Read(ip[:]); err != nil { + return nil, err + } + + var port [2]byte + if _, err := r.Read(port[:]); err != nil { + return nil, err + } + + addrIP := net.IP(ip[:]) + addrPort := int(binary.BigEndian.Uint16(port[:])) + + return &net.TCPAddr{IP: addrIP, Port: addrPort}, nil +} + +func randV2OnionAddr(r *rand.Rand) (*tor.OnionAddr, error) { + var serviceID [tor.V2DecodedLen]byte + if _, err := r.Read(serviceID[:]); err != nil { + return nil, err + } + + var port [2]byte + if _, err := r.Read(port[:]); err != nil { + return nil, err + } + + onionService := tor.Base32Encoding.EncodeToString(serviceID[:]) + onionService += tor.OnionSuffix + addrPort := int(binary.BigEndian.Uint16(port[:])) + + return &tor.OnionAddr{OnionService: onionService, Port: addrPort}, nil +} + +func randV3OnionAddr(r *rand.Rand) (*tor.OnionAddr, error) { + var serviceID [tor.V3DecodedLen]byte + if _, err := r.Read(serviceID[:]); err != nil { + return nil, err + } + + var port [2]byte + if _, err := r.Read(port[:]); err != nil { + return nil, err + } + + onionService := tor.Base32Encoding.EncodeToString(serviceID[:]) + onionService += tor.OnionSuffix + addrPort := int(binary.BigEndian.Uint16(port[:])) + + return &tor.OnionAddr{OnionService: onionService, Port: addrPort}, nil +} + +func randAddrs(r *rand.Rand) ([]net.Addr, error) { + tcp4Addr, err := randTCP4Addr(r) + if err != nil { + return nil, err + } + + tcp6Addr, err := randTCP6Addr(r) + if err != nil { + return nil, err + } + + v2OnionAddr, err := randV2OnionAddr(r) + if err != nil { + return nil, err + } + + v3OnionAddr, err := randV3OnionAddr(r) + if err != nil { + return nil, err + } + + return []net.Addr{tcp4Addr, tcp6Addr, v2OnionAddr, v3OnionAddr}, nil +} + // dbObject is abstract object support encoding and decoding. type dbObject interface { Encode(io.Writer) error @@ -19,7 +127,9 @@ type dbObject interface { // TestCodec serializes and deserializes wtdb objects in order to test that that // the codec understands all of the required field types. The test also asserts // that decoding an object into another results in an equivalent object. -func TestCodec(t *testing.T) { +func TestCodec(tt *testing.T) { + + var t *testing.T mainScenario := func(obj dbObject) bool { // Ensure encoding the object succeeds. var b bytes.Buffer @@ -35,6 +145,14 @@ func TestCodec(t *testing.T) { obj2 = &wtdb.SessionInfo{} case *wtdb.SessionStateUpdate: obj2 = &wtdb.SessionStateUpdate{} + case *wtdb.ClientSessionBody: + obj2 = &wtdb.ClientSessionBody{} + case *wtdb.CommittedUpdateBody: + obj2 = &wtdb.CommittedUpdateBody{} + case *wtdb.BackupID: + obj2 = &wtdb.BackupID{} + case *wtdb.Tower: + obj2 = &wtdb.Tower{} default: t.Fatalf("unknown type: %T", obj) return false @@ -57,6 +175,29 @@ func TestCodec(t *testing.T) { return true } + customTypeGen := map[string]func([]reflect.Value, *rand.Rand){ + "Tower": func(v []reflect.Value, r *rand.Rand) { + pk, err := randPubKey() + if err != nil { + t.Fatalf("unable to generate pubkey: %v", err) + return + } + + addrs, err := randAddrs(r) + if err != nil { + t.Fatalf("unable to generate addrs: %v", err) + return + } + + obj := wtdb.Tower{ + IdentityKey: pk, + Addresses: addrs, + } + + v[0] = reflect.ValueOf(obj) + }, + } + tests := []struct { name string scenario interface{} @@ -73,11 +214,45 @@ func TestCodec(t *testing.T) { return mainScenario(&obj) }, }, + { + name: "ClientSessionBody", + scenario: func(obj wtdb.ClientSessionBody) bool { + return mainScenario(&obj) + }, + }, + { + name: "CommittedUpdateBody", + scenario: func(obj wtdb.CommittedUpdateBody) bool { + return mainScenario(&obj) + }, + }, + { + name: "BackupID", + scenario: func(obj wtdb.BackupID) bool { + return mainScenario(&obj) + }, + }, + { + name: "Tower", + scenario: func(obj wtdb.Tower) bool { + return mainScenario(&obj) + }, + }, } for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if err := quick.Check(test.scenario, nil); err != nil { + tt.Run(test.name, func(h *testing.T) { + t = h + + var config *quick.Config + if valueGen, ok := customTypeGen[test.name]; ok { + config = &quick.Config{ + Values: valueGen, + } + } + + err := quick.Check(test.scenario, config) + if err != nil { t.Fatalf("fuzz checks for msg=%s failed: %v", test.name, err) } diff --git a/watchtower/wtdb/tower.go b/watchtower/wtdb/tower.go index e4f28781..518da750 100644 --- a/watchtower/wtdb/tower.go +++ b/watchtower/wtdb/tower.go @@ -2,8 +2,8 @@ package wtdb import ( "errors" + "io" "net" - "sync" "github.com/btcsuite/btcd/btcec" "github.com/lightningnetwork/lnd/lnwire" @@ -47,18 +47,15 @@ type Tower struct { // Addresses is a list of possible addresses to reach the tower. Addresses []net.Addr - - mu sync.RWMutex } // AddAddress adds the given address to the tower's in-memory list of addresses. // If the address's string is already present, the Tower will be left // unmodified. Otherwise, the adddress is prepended to the beginning of the // Tower's addresses, on the assumption that it is fresher than the others. +// +// NOTE: This method is NOT safe for concurrent use. func (t *Tower) AddAddress(addr net.Addr) { - t.mu.Lock() - defer t.mu.Unlock() - // Ensure we don't add a duplicate address. addrStr := addr.String() for _, existingAddr := range t.Addresses { @@ -75,10 +72,9 @@ func (t *Tower) AddAddress(addr net.Addr) { // LNAddrs generates a list of lnwire.NetAddress from a Tower instance's // addresses. This can be used to have a client try multiple addresses for the // same Tower. +// +// NOTE: This method is NOT safe for concurrent use. func (t *Tower) LNAddrs() []*lnwire.NetAddress { - t.mu.RLock() - defer t.mu.RUnlock() - addrs := make([]*lnwire.NetAddress, 0, len(t.Addresses)) for _, addr := range t.Addresses { addrs = append(addrs, &lnwire.NetAddress{ @@ -89,3 +85,21 @@ func (t *Tower) LNAddrs() []*lnwire.NetAddress { return addrs } + +// Encode writes the Tower to the passed io.Writer. The TowerID is not +// serialized, since it acts as the key. +func (t *Tower) Encode(w io.Writer) error { + return WriteElements(w, + t.IdentityKey, + t.Addresses, + ) +} + +// Decode reads a Tower from the passed io.Reader. The TowerID is meant to be +// decoded from the key. +func (t *Tower) Decode(r io.Reader) error { + return ReadElements(r, + &t.IdentityKey, + &t.Addresses, + ) +}