diff --git a/channeldb/addr.go b/channeldb/addr.go index 4d2f578f..0d40099e 100644 --- a/channeldb/addr.go +++ b/channeldb/addr.go @@ -192,7 +192,7 @@ func serializeAddr(w io.Writer, address net.Addr) error { return encodeTCPAddr(w, addr) case *tor.OnionAddr: return encodeOnionAddr(w, addr) + default: + return ErrUnknownAddressType } - - return nil } diff --git a/channeldb/addr_test.go b/channeldb/addr_test.go index 2de179d6..d093bb3d 100644 --- a/channeldb/addr_test.go +++ b/channeldb/addr_test.go @@ -8,34 +8,59 @@ import ( "github.com/lightningnetwork/lnd/tor" ) +type unknownAddrType struct{} + +func (_ unknownAddrType) Network() string { return "unknown" } +func (_ unknownAddrType) String() string { return "unknown" } + +var addrTests = []struct { + expAddr net.Addr + serErr error +}{ + { + expAddr: &net.TCPAddr{ + IP: net.ParseIP("192.168.1.1"), + Port: 12345, + }, + }, + { + expAddr: &net.TCPAddr{ + IP: net.ParseIP("2001:0db8:0000:0000:0000:ff00:0042:8329"), + Port: 65535, + }, + }, + { + expAddr: &tor.OnionAddr{ + OnionService: "3g2upl4pq6kufc4m.onion", + Port: 9735, + }, + }, + { + expAddr: &tor.OnionAddr{ + OnionService: "vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyyd.onion", + Port: 80, + }, + }, + { + expAddr: unknownAddrType{}, + serErr: ErrUnknownAddressType, + }, +} + // TestAddrSerialization tests that the serialization method used by channeldb // for net.Addr's works as intended. func TestAddrSerialization(t *testing.T) { t.Parallel() - testAddrs := []net.Addr{ - &net.TCPAddr{ - IP: net.ParseIP("192.168.1.1"), - Port: 12345, - }, - &net.TCPAddr{ - IP: net.ParseIP("2001:0db8:0000:0000:0000:ff00:0042:8329"), - Port: 65535, - }, - &tor.OnionAddr{ - OnionService: "3g2upl4pq6kufc4m.onion", - Port: 9735, - }, - &tor.OnionAddr{ - OnionService: "vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyyd.onion", - Port: 80, - }, - } - var b bytes.Buffer - for _, expectedAddr := range testAddrs { - if err := serializeAddr(&b, expectedAddr); err != nil { - t.Fatalf("unable to serialize address: %v", err) + for _, test := range addrTests { + err := serializeAddr(&b, test.expAddr) + if err != test.serErr { + t.Fatalf("unexpected serialization err for addr %v, "+ + "want: %v, got %v", + test.expAddr, test.serErr, err) + } else if test.serErr != nil { + continue } addr, err := deserializeAddr(&b) @@ -43,9 +68,9 @@ func TestAddrSerialization(t *testing.T) { t.Fatalf("unable to deserialize address: %v", err) } - if addr.String() != expectedAddr.String() { + if addr.String() != test.expAddr.String() { t.Fatalf("expected address %v after serialization, "+ - "got %v", addr, expectedAddr) + "got %v", addr, test.expAddr) } } } diff --git a/channeldb/channel.go b/channeldb/channel.go index 9d7d01ba..0c6245d2 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -916,12 +916,12 @@ type HTLC struct { // future. func SerializeHtlcs(b io.Writer, htlcs ...HTLC) error { numHtlcs := uint16(len(htlcs)) - if err := writeElement(b, numHtlcs); err != nil { + if err := WriteElement(b, numHtlcs); err != nil { return err } for _, htlc := range htlcs { - if err := writeElements(b, + if err := WriteElements(b, htlc.Signature, htlc.RHash, htlc.Amt, htlc.RefundTimeout, htlc.OutputIndex, htlc.Incoming, htlc.OnionBlob[:], htlc.HtlcIndex, htlc.LogIndex, @@ -941,7 +941,7 @@ func SerializeHtlcs(b io.Writer, htlcs ...HTLC) error { // future. func DeserializeHtlcs(r io.Reader) ([]HTLC, error) { var numHtlcs uint16 - if err := readElement(r, &numHtlcs); err != nil { + if err := ReadElement(r, &numHtlcs); err != nil { return nil, err } @@ -952,7 +952,7 @@ func DeserializeHtlcs(r io.Reader) ([]HTLC, error) { htlcs = make([]HTLC, numHtlcs) for i := uint16(0); i < numHtlcs; i++ { - if err := readElements(r, + if err := ReadElements(r, &htlcs[i].Signature, &htlcs[i].RHash, &htlcs[i].Amt, &htlcs[i].RefundTimeout, &htlcs[i].OutputIndex, &htlcs[i].Incoming, &htlcs[i].OnionBlob, @@ -996,12 +996,12 @@ type LogUpdate struct { // Encode writes a log update to the provided io.Writer. func (l *LogUpdate) Encode(w io.Writer) error { - return writeElements(w, l.LogIndex, l.UpdateMsg) + return WriteElements(w, l.LogIndex, l.UpdateMsg) } // Decode reads a log update from the provided io.Reader. func (l *LogUpdate) Decode(r io.Reader) error { - return readElements(r, &l.LogIndex, &l.UpdateMsg) + return ReadElements(r, &l.LogIndex, &l.UpdateMsg) } // CircuitKey is used by a channel to uniquely identify the HTLCs it receives @@ -1142,7 +1142,7 @@ func serializeCommitDiff(w io.Writer, diff *CommitDiff) error { } for _, diff := range diff.LogUpdates { - err := writeElements(w, diff.LogIndex, diff.UpdateMsg) + err := WriteElements(w, diff.LogIndex, diff.UpdateMsg) if err != nil { return err } @@ -1154,7 +1154,7 @@ func serializeCommitDiff(w io.Writer, diff *CommitDiff) error { } for _, openRef := range diff.OpenedCircuitKeys { - err := writeElements(w, openRef.ChanID, openRef.HtlcID) + err := WriteElements(w, openRef.ChanID, openRef.HtlcID) if err != nil { return err } @@ -1166,7 +1166,7 @@ func serializeCommitDiff(w io.Writer, diff *CommitDiff) error { } for _, closedRef := range diff.ClosedCircuitKeys { - err := writeElements(w, closedRef.ChanID, closedRef.HtlcID) + err := WriteElements(w, closedRef.ChanID, closedRef.HtlcID) if err != nil { return err } @@ -1198,7 +1198,7 @@ func deserializeCommitDiff(r io.Reader) (*CommitDiff, error) { d.LogUpdates = make([]LogUpdate, numUpdates) for i := 0; i < int(numUpdates); i++ { - err := readElements(r, + err := ReadElements(r, &d.LogUpdates[i].LogIndex, &d.LogUpdates[i].UpdateMsg, ) if err != nil { @@ -1213,7 +1213,7 @@ func deserializeCommitDiff(r io.Reader) (*CommitDiff, error) { d.OpenedCircuitKeys = make([]CircuitKey, numOpenRefs) for i := 0; i < int(numOpenRefs); i++ { - err := readElements(r, + err := ReadElements(r, &d.OpenedCircuitKeys[i].ChanID, &d.OpenedCircuitKeys[i].HtlcID) if err != nil { @@ -1228,7 +1228,7 @@ func deserializeCommitDiff(r io.Reader) (*CommitDiff, error) { d.ClosedCircuitKeys = make([]CircuitKey, numClosedRefs) for i := 0; i < int(numClosedRefs); i++ { - err := readElements(r, + err := ReadElements(r, &d.ClosedCircuitKeys[i].ChanID, &d.ClosedCircuitKeys[i].HtlcID) if err != nil { @@ -1957,7 +1957,7 @@ func putChannelCloseSummary(tx *bolt.Tx, chanID []byte, } func serializeChannelCloseSummary(w io.Writer, cs *ChannelCloseSummary) error { - err := writeElements(w, + err := WriteElements(w, cs.ChanPoint, cs.ShortChanID, cs.ChainHash, cs.ClosingTXID, cs.CloseHeight, cs.RemotePub, cs.Capacity, cs.SettledBalance, cs.TimeLockedBalance, cs.CloseType, cs.IsPending, @@ -1972,7 +1972,7 @@ func serializeChannelCloseSummary(w io.Writer, cs *ChannelCloseSummary) error { return nil } - if err := writeElements(w, cs.RemoteCurrentRevocation); err != nil { + if err := WriteElements(w, cs.RemoteCurrentRevocation); err != nil { return err } @@ -1987,7 +1987,7 @@ func serializeChannelCloseSummary(w io.Writer, cs *ChannelCloseSummary) error { return nil } - return writeElements(w, cs.RemoteNextRevocation) + return WriteElements(w, cs.RemoteNextRevocation) } func fetchChannelCloseSummary(tx *bolt.Tx, @@ -2010,7 +2010,7 @@ func fetchChannelCloseSummary(tx *bolt.Tx, func deserializeCloseChannelSummary(r io.Reader) (*ChannelCloseSummary, error) { c := &ChannelCloseSummary{} - err := readElements(r, + err := ReadElements(r, &c.ChanPoint, &c.ShortChanID, &c.ChainHash, &c.ClosingTXID, &c.CloseHeight, &c.RemotePub, &c.Capacity, &c.SettledBalance, &c.TimeLockedBalance, &c.CloseType, &c.IsPending, @@ -2021,7 +2021,7 @@ func deserializeCloseChannelSummary(r io.Reader) (*ChannelCloseSummary, error) { // We'll now check to see if the channel close summary was encoded with // any of the additional optional fields. - err = readElements(r, &c.RemoteCurrentRevocation) + err = ReadElements(r, &c.RemoteCurrentRevocation) switch { case err == io.EOF: return c, nil @@ -2042,7 +2042,7 @@ func deserializeCloseChannelSummary(r io.Reader) (*ChannelCloseSummary, error) { // funding locked message, then this can be nil. As a result, we'll use // the same technique to read the field, only if there's still data // left in the buffer. - err = readElements(r, &c.RemoteNextRevocation) + err = ReadElements(r, &c.RemoteNextRevocation) if err != nil && err != io.EOF { // If we got a non-eof error, then we know there's an actually // issue. Otherwise, it may have been the case that this @@ -2054,7 +2054,7 @@ func deserializeCloseChannelSummary(r io.Reader) (*ChannelCloseSummary, error) { } func writeChanConfig(b io.Writer, c *ChannelConfig) error { - return writeElements(b, + return WriteElements(b, c.DustLimit, c.MaxPendingAmount, c.ChanReserve, c.MinHTLC, c.MaxAcceptedHtlcs, c.CsvDelay, c.MultiSigKey, c.RevocationBasePoint, c.PaymentBasePoint, c.DelayBasePoint, @@ -2064,7 +2064,7 @@ func writeChanConfig(b io.Writer, c *ChannelConfig) error { func putChanInfo(chanBucket *bolt.Bucket, channel *OpenChannel) error { var w bytes.Buffer - if err := writeElements(&w, + if err := WriteElements(&w, channel.ChanType, channel.ChainHash, channel.FundingOutpoint, channel.ShortChannelID, channel.IsPending, channel.IsInitiator, channel.ChanStatus, channel.FundingBroadcastHeight, @@ -2077,7 +2077,7 @@ func putChanInfo(chanBucket *bolt.Bucket, channel *OpenChannel) error { // For single funder channels that we initiated, write the funding txn. if channel.ChanType == SingleFunder && channel.IsInitiator { - if err := writeElement(&w, channel.FundingTxn); err != nil { + if err := WriteElement(&w, channel.FundingTxn); err != nil { return err } } @@ -2093,7 +2093,7 @@ func putChanInfo(chanBucket *bolt.Bucket, channel *OpenChannel) error { } func serializeChanCommit(w io.Writer, c *ChannelCommitment) error { - if err := writeElements(w, + if err := WriteElements(w, c.CommitHeight, c.LocalLogIndex, c.LocalHtlcIndex, c.RemoteLogIndex, c.RemoteHtlcIndex, c.LocalBalance, c.RemoteBalance, c.CommitFee, c.FeePerKw, c.CommitTx, @@ -2135,7 +2135,7 @@ func putChanCommitments(chanBucket *bolt.Bucket, channel *OpenChannel) error { func putChanRevocationState(chanBucket *bolt.Bucket, channel *OpenChannel) error { var b bytes.Buffer - err := writeElements( + err := WriteElements( &b, channel.RemoteCurrentRevocation, channel.RevocationProducer, channel.RevocationStore, ) @@ -2148,7 +2148,7 @@ func putChanRevocationState(chanBucket *bolt.Bucket, channel *OpenChannel) error // If the next revocation is present, which is only the case after the // FundingLocked message has been sent, then we'll write it to disk. if channel.RemoteNextRevocation != nil { - err = writeElements(&b, channel.RemoteNextRevocation) + err = WriteElements(&b, channel.RemoteNextRevocation) if err != nil { return err } @@ -2158,7 +2158,7 @@ func putChanRevocationState(chanBucket *bolt.Bucket, channel *OpenChannel) error } func readChanConfig(b io.Reader, c *ChannelConfig) error { - return readElements(b, + return ReadElements(b, &c.DustLimit, &c.MaxPendingAmount, &c.ChanReserve, &c.MinHTLC, &c.MaxAcceptedHtlcs, &c.CsvDelay, &c.MultiSigKey, &c.RevocationBasePoint, @@ -2174,7 +2174,7 @@ func fetchChanInfo(chanBucket *bolt.Bucket, channel *OpenChannel) error { } r := bytes.NewReader(infoBytes) - if err := readElements(r, + if err := ReadElements(r, &channel.ChanType, &channel.ChainHash, &channel.FundingOutpoint, &channel.ShortChannelID, &channel.IsPending, &channel.IsInitiator, &channel.ChanStatus, &channel.FundingBroadcastHeight, @@ -2187,7 +2187,7 @@ func fetchChanInfo(chanBucket *bolt.Bucket, channel *OpenChannel) error { // For single funder channels that we initiated, read the funding txn. if channel.ChanType == SingleFunder && channel.IsInitiator { - if err := readElement(r, &channel.FundingTxn); err != nil { + if err := ReadElement(r, &channel.FundingTxn); err != nil { return err } } @@ -2207,7 +2207,7 @@ func fetchChanInfo(chanBucket *bolt.Bucket, channel *OpenChannel) error { func deserializeChanCommit(r io.Reader) (ChannelCommitment, error) { var c ChannelCommitment - err := readElements(r, + err := ReadElements(r, &c.CommitHeight, &c.LocalLogIndex, &c.LocalHtlcIndex, &c.RemoteLogIndex, &c.RemoteHtlcIndex, &c.LocalBalance, &c.RemoteBalance, &c.CommitFee, &c.FeePerKw, &c.CommitTx, &c.CommitSig, @@ -2263,7 +2263,7 @@ func fetchChanRevocationState(chanBucket *bolt.Bucket, channel *OpenChannel) err } r := bytes.NewReader(revBytes) - err := readElements( + err := ReadElements( r, &channel.RemoteCurrentRevocation, &channel.RevocationProducer, &channel.RevocationStore, ) @@ -2279,7 +2279,7 @@ func fetchChanRevocationState(chanBucket *bolt.Bucket, channel *OpenChannel) err // Otherwise we'll read the next revocation for the remote party which // is always the last item within the buffer. - return readElements(r, &channel.RemoteNextRevocation) + return ReadElements(r, &channel.RemoteNextRevocation) } func deleteOpenChannel(chanBucket *bolt.Bucket, chanPointBytes []byte) error { diff --git a/channeldb/codec.go b/channeldb/codec.go index 086c2533..149be180 100644 --- a/channeldb/codec.go +++ b/channeldb/codec.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "fmt" "io" + "net" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwire" @@ -43,11 +44,24 @@ func readOutpoint(r io.Reader, o *wire.OutPoint) error { return nil } -// writeElement is a one-stop shop to write the big endian representation of +// UnknownElementType is an error returned when the codec is unable to encode or +// decode a particular type. +type UnknownElementType struct { + method string + element interface{} +} + +// Error returns the name of the method that encountered the error, as well as +// the type that was unsupported. +func (e UnknownElementType) Error() string { + return fmt.Sprintf("Unknown type in %s: %T", e.method, e.element) +} + +// WriteElement is a one-stop shop to write the big endian representation of // any element which is to be serialized for storage on disk. The passed // io.Writer should be backed by an appropriately sized byte slice, or be able // to dynamically expand to accommodate additional data. -func writeElement(w io.Writer, element interface{}) error { +func WriteElement(w io.Writer, element interface{}) error { switch e := element.(type) { case keychain.KeyDescriptor: if err := binary.Write(w, byteOrder, e.Family); err != nil { @@ -61,7 +75,7 @@ func writeElement(w io.Writer, element interface{}) error { if err := binary.Write(w, byteOrder, true); err != nil { } - return writeElement(w, e.PubKey) + return WriteElement(w, e.PubKey) } return binary.Write(w, byteOrder, false) @@ -163,18 +177,34 @@ func writeElement(w io.Writer, element interface{}) error { return err } + case net.Addr: + if err := serializeAddr(w, e); err != nil { + return err + } + + case []net.Addr: + if err := WriteElement(w, uint32(len(e))); err != nil { + return err + } + + for _, addr := range e { + if err := serializeAddr(w, addr); err != nil { + return err + } + } + default: - return fmt.Errorf("Unknown type in writeElement: %T", e) + return UnknownElementType{"WriteElement", e} } return nil } -// writeElements is writes each element in the elements slice to the passed -// io.Writer using writeElement. -func writeElements(w io.Writer, elements ...interface{}) error { +// WriteElements is writes each element in the elements slice to the passed +// io.Writer using WriteElement. +func WriteElements(w io.Writer, elements ...interface{}) error { for _, element := range elements { - err := writeElement(w, element) + err := WriteElement(w, element) if err != nil { return err } @@ -182,9 +212,9 @@ func writeElements(w io.Writer, elements ...interface{}) error { return nil } -// readElement is a one-stop utility function to deserialize any datastructure +// ReadElement is a one-stop utility function to deserialize any datastructure // encoded using the serialization format of the database. -func readElement(r io.Reader, element interface{}) error { +func ReadElement(r io.Reader, element interface{}) error { switch e := element.(type) { case *keychain.KeyDescriptor: if err := binary.Read(r, byteOrder, &e.Family); err != nil { @@ -200,7 +230,7 @@ func readElement(r io.Reader, element interface{}) error { } if hasPubKey { - return readElement(r, &e.PubKey) + return ReadElement(r, &e.PubKey) } case *ChannelType: @@ -342,19 +372,41 @@ func readElement(r io.Reader, element interface{}) error { return err } + case *net.Addr: + addr, err := deserializeAddr(r) + if err != nil { + return err + } + *e = addr + + case *[]net.Addr: + var numAddrs uint32 + if err := ReadElement(r, &numAddrs); err != nil { + return err + } + + *e = make([]net.Addr, numAddrs) + for i := uint32(0); i < numAddrs; i++ { + addr, err := deserializeAddr(r) + if err != nil { + return err + } + (*e)[i] = addr + } + default: - return fmt.Errorf("Unknown type in readElement: %T", e) + return UnknownElementType{"ReadElement", e} } return nil } -// readElements deserializes a variable number of elements into the passed -// io.Reader, with each element being deserialized according to the readElement +// ReadElements deserializes a variable number of elements into the passed +// io.Reader, with each element being deserialized according to the ReadElement // function. -func readElements(r io.Reader, elements ...interface{}) error { +func ReadElements(r io.Reader, elements ...interface{}) error { for _, element := range elements { - err := readElement(r, element) + err := ReadElement(r, element) if err != nil { return err } diff --git a/channeldb/forwarding_log.go b/channeldb/forwarding_log.go index 3f230fe4..b444e32c 100644 --- a/channeldb/forwarding_log.go +++ b/channeldb/forwarding_log.go @@ -83,7 +83,7 @@ type ForwardingEvent struct { // io.Writer, using the expected DB format. Note that the timestamp isn't // serialized as this will be the key value within the bucket. func encodeForwardingEvent(w io.Writer, f *ForwardingEvent) error { - return writeElements( + return WriteElements( w, f.IncomingChanID, f.OutgoingChanID, f.AmtIn, f.AmtOut, ) } @@ -93,7 +93,7 @@ func encodeForwardingEvent(w io.Writer, f *ForwardingEvent) error { // won't be decoded, as the caller is expected to set this due to the bucket // structure of the forwarding log. func decodeForwardingEvent(r io.Reader, f *ForwardingEvent) error { - return readElements( + return ReadElements( r, &f.IncomingChanID, &f.OutgoingChanID, &f.AmtIn, &f.AmtOut, ) }