Merge pull request #1458 from cfromknecht/add-addrs-to-codec

channeldb: expand codec to include net.Addr types
This commit is contained in:
Olaoluwa Osuntokun 2018-07-03 22:50:22 -05:00 committed by GitHub
commit a0b2fadea3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 151 additions and 74 deletions

View File

@ -192,7 +192,7 @@ func serializeAddr(w io.Writer, address net.Addr) error {
return encodeTCPAddr(w, addr)
case *tor.OnionAddr:
return encodeOnionAddr(w, addr)
default:
return ErrUnknownAddressType
}
return nil
}

View File

@ -8,34 +8,59 @@ import (
"github.com/lightningnetwork/lnd/tor"
)
type unknownAddrType struct{}
func (_ unknownAddrType) Network() string { return "unknown" }
func (_ unknownAddrType) String() string { return "unknown" }
var addrTests = []struct {
expAddr net.Addr
serErr error
}{
{
expAddr: &net.TCPAddr{
IP: net.ParseIP("192.168.1.1"),
Port: 12345,
},
},
{
expAddr: &net.TCPAddr{
IP: net.ParseIP("2001:0db8:0000:0000:0000:ff00:0042:8329"),
Port: 65535,
},
},
{
expAddr: &tor.OnionAddr{
OnionService: "3g2upl4pq6kufc4m.onion",
Port: 9735,
},
},
{
expAddr: &tor.OnionAddr{
OnionService: "vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyyd.onion",
Port: 80,
},
},
{
expAddr: unknownAddrType{},
serErr: ErrUnknownAddressType,
},
}
// TestAddrSerialization tests that the serialization method used by channeldb
// for net.Addr's works as intended.
func TestAddrSerialization(t *testing.T) {
t.Parallel()
testAddrs := []net.Addr{
&net.TCPAddr{
IP: net.ParseIP("192.168.1.1"),
Port: 12345,
},
&net.TCPAddr{
IP: net.ParseIP("2001:0db8:0000:0000:0000:ff00:0042:8329"),
Port: 65535,
},
&tor.OnionAddr{
OnionService: "3g2upl4pq6kufc4m.onion",
Port: 9735,
},
&tor.OnionAddr{
OnionService: "vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyyd.onion",
Port: 80,
},
}
var b bytes.Buffer
for _, expectedAddr := range testAddrs {
if err := serializeAddr(&b, expectedAddr); err != nil {
t.Fatalf("unable to serialize address: %v", err)
for _, test := range addrTests {
err := serializeAddr(&b, test.expAddr)
if err != test.serErr {
t.Fatalf("unexpected serialization err for addr %v, "+
"want: %v, got %v",
test.expAddr, test.serErr, err)
} else if test.serErr != nil {
continue
}
addr, err := deserializeAddr(&b)
@ -43,9 +68,9 @@ func TestAddrSerialization(t *testing.T) {
t.Fatalf("unable to deserialize address: %v", err)
}
if addr.String() != expectedAddr.String() {
if addr.String() != test.expAddr.String() {
t.Fatalf("expected address %v after serialization, "+
"got %v", addr, expectedAddr)
"got %v", addr, test.expAddr)
}
}
}

View File

@ -916,12 +916,12 @@ type HTLC struct {
// future.
func SerializeHtlcs(b io.Writer, htlcs ...HTLC) error {
numHtlcs := uint16(len(htlcs))
if err := writeElement(b, numHtlcs); err != nil {
if err := WriteElement(b, numHtlcs); err != nil {
return err
}
for _, htlc := range htlcs {
if err := writeElements(b,
if err := WriteElements(b,
htlc.Signature, htlc.RHash, htlc.Amt, htlc.RefundTimeout,
htlc.OutputIndex, htlc.Incoming, htlc.OnionBlob[:],
htlc.HtlcIndex, htlc.LogIndex,
@ -941,7 +941,7 @@ func SerializeHtlcs(b io.Writer, htlcs ...HTLC) error {
// future.
func DeserializeHtlcs(r io.Reader) ([]HTLC, error) {
var numHtlcs uint16
if err := readElement(r, &numHtlcs); err != nil {
if err := ReadElement(r, &numHtlcs); err != nil {
return nil, err
}
@ -952,7 +952,7 @@ func DeserializeHtlcs(r io.Reader) ([]HTLC, error) {
htlcs = make([]HTLC, numHtlcs)
for i := uint16(0); i < numHtlcs; i++ {
if err := readElements(r,
if err := ReadElements(r,
&htlcs[i].Signature, &htlcs[i].RHash, &htlcs[i].Amt,
&htlcs[i].RefundTimeout, &htlcs[i].OutputIndex,
&htlcs[i].Incoming, &htlcs[i].OnionBlob,
@ -996,12 +996,12 @@ type LogUpdate struct {
// Encode writes a log update to the provided io.Writer.
func (l *LogUpdate) Encode(w io.Writer) error {
return writeElements(w, l.LogIndex, l.UpdateMsg)
return WriteElements(w, l.LogIndex, l.UpdateMsg)
}
// Decode reads a log update from the provided io.Reader.
func (l *LogUpdate) Decode(r io.Reader) error {
return readElements(r, &l.LogIndex, &l.UpdateMsg)
return ReadElements(r, &l.LogIndex, &l.UpdateMsg)
}
// CircuitKey is used by a channel to uniquely identify the HTLCs it receives
@ -1142,7 +1142,7 @@ func serializeCommitDiff(w io.Writer, diff *CommitDiff) error {
}
for _, diff := range diff.LogUpdates {
err := writeElements(w, diff.LogIndex, diff.UpdateMsg)
err := WriteElements(w, diff.LogIndex, diff.UpdateMsg)
if err != nil {
return err
}
@ -1154,7 +1154,7 @@ func serializeCommitDiff(w io.Writer, diff *CommitDiff) error {
}
for _, openRef := range diff.OpenedCircuitKeys {
err := writeElements(w, openRef.ChanID, openRef.HtlcID)
err := WriteElements(w, openRef.ChanID, openRef.HtlcID)
if err != nil {
return err
}
@ -1166,7 +1166,7 @@ func serializeCommitDiff(w io.Writer, diff *CommitDiff) error {
}
for _, closedRef := range diff.ClosedCircuitKeys {
err := writeElements(w, closedRef.ChanID, closedRef.HtlcID)
err := WriteElements(w, closedRef.ChanID, closedRef.HtlcID)
if err != nil {
return err
}
@ -1198,7 +1198,7 @@ func deserializeCommitDiff(r io.Reader) (*CommitDiff, error) {
d.LogUpdates = make([]LogUpdate, numUpdates)
for i := 0; i < int(numUpdates); i++ {
err := readElements(r,
err := ReadElements(r,
&d.LogUpdates[i].LogIndex, &d.LogUpdates[i].UpdateMsg,
)
if err != nil {
@ -1213,7 +1213,7 @@ func deserializeCommitDiff(r io.Reader) (*CommitDiff, error) {
d.OpenedCircuitKeys = make([]CircuitKey, numOpenRefs)
for i := 0; i < int(numOpenRefs); i++ {
err := readElements(r,
err := ReadElements(r,
&d.OpenedCircuitKeys[i].ChanID,
&d.OpenedCircuitKeys[i].HtlcID)
if err != nil {
@ -1228,7 +1228,7 @@ func deserializeCommitDiff(r io.Reader) (*CommitDiff, error) {
d.ClosedCircuitKeys = make([]CircuitKey, numClosedRefs)
for i := 0; i < int(numClosedRefs); i++ {
err := readElements(r,
err := ReadElements(r,
&d.ClosedCircuitKeys[i].ChanID,
&d.ClosedCircuitKeys[i].HtlcID)
if err != nil {
@ -1957,7 +1957,7 @@ func putChannelCloseSummary(tx *bolt.Tx, chanID []byte,
}
func serializeChannelCloseSummary(w io.Writer, cs *ChannelCloseSummary) error {
err := writeElements(w,
err := WriteElements(w,
cs.ChanPoint, cs.ShortChanID, cs.ChainHash, cs.ClosingTXID,
cs.CloseHeight, cs.RemotePub, cs.Capacity, cs.SettledBalance,
cs.TimeLockedBalance, cs.CloseType, cs.IsPending,
@ -1972,7 +1972,7 @@ func serializeChannelCloseSummary(w io.Writer, cs *ChannelCloseSummary) error {
return nil
}
if err := writeElements(w, cs.RemoteCurrentRevocation); err != nil {
if err := WriteElements(w, cs.RemoteCurrentRevocation); err != nil {
return err
}
@ -1987,7 +1987,7 @@ func serializeChannelCloseSummary(w io.Writer, cs *ChannelCloseSummary) error {
return nil
}
return writeElements(w, cs.RemoteNextRevocation)
return WriteElements(w, cs.RemoteNextRevocation)
}
func fetchChannelCloseSummary(tx *bolt.Tx,
@ -2010,7 +2010,7 @@ func fetchChannelCloseSummary(tx *bolt.Tx,
func deserializeCloseChannelSummary(r io.Reader) (*ChannelCloseSummary, error) {
c := &ChannelCloseSummary{}
err := readElements(r,
err := ReadElements(r,
&c.ChanPoint, &c.ShortChanID, &c.ChainHash, &c.ClosingTXID,
&c.CloseHeight, &c.RemotePub, &c.Capacity, &c.SettledBalance,
&c.TimeLockedBalance, &c.CloseType, &c.IsPending,
@ -2021,7 +2021,7 @@ func deserializeCloseChannelSummary(r io.Reader) (*ChannelCloseSummary, error) {
// We'll now check to see if the channel close summary was encoded with
// any of the additional optional fields.
err = readElements(r, &c.RemoteCurrentRevocation)
err = ReadElements(r, &c.RemoteCurrentRevocation)
switch {
case err == io.EOF:
return c, nil
@ -2042,7 +2042,7 @@ func deserializeCloseChannelSummary(r io.Reader) (*ChannelCloseSummary, error) {
// funding locked message, then this can be nil. As a result, we'll use
// the same technique to read the field, only if there's still data
// left in the buffer.
err = readElements(r, &c.RemoteNextRevocation)
err = ReadElements(r, &c.RemoteNextRevocation)
if err != nil && err != io.EOF {
// If we got a non-eof error, then we know there's an actually
// issue. Otherwise, it may have been the case that this
@ -2054,7 +2054,7 @@ func deserializeCloseChannelSummary(r io.Reader) (*ChannelCloseSummary, error) {
}
func writeChanConfig(b io.Writer, c *ChannelConfig) error {
return writeElements(b,
return WriteElements(b,
c.DustLimit, c.MaxPendingAmount, c.ChanReserve, c.MinHTLC,
c.MaxAcceptedHtlcs, c.CsvDelay, c.MultiSigKey,
c.RevocationBasePoint, c.PaymentBasePoint, c.DelayBasePoint,
@ -2064,7 +2064,7 @@ func writeChanConfig(b io.Writer, c *ChannelConfig) error {
func putChanInfo(chanBucket *bolt.Bucket, channel *OpenChannel) error {
var w bytes.Buffer
if err := writeElements(&w,
if err := WriteElements(&w,
channel.ChanType, channel.ChainHash, channel.FundingOutpoint,
channel.ShortChannelID, channel.IsPending, channel.IsInitiator,
channel.ChanStatus, channel.FundingBroadcastHeight,
@ -2077,7 +2077,7 @@ func putChanInfo(chanBucket *bolt.Bucket, channel *OpenChannel) error {
// For single funder channels that we initiated, write the funding txn.
if channel.ChanType == SingleFunder && channel.IsInitiator {
if err := writeElement(&w, channel.FundingTxn); err != nil {
if err := WriteElement(&w, channel.FundingTxn); err != nil {
return err
}
}
@ -2093,7 +2093,7 @@ func putChanInfo(chanBucket *bolt.Bucket, channel *OpenChannel) error {
}
func serializeChanCommit(w io.Writer, c *ChannelCommitment) error {
if err := writeElements(w,
if err := WriteElements(w,
c.CommitHeight, c.LocalLogIndex, c.LocalHtlcIndex,
c.RemoteLogIndex, c.RemoteHtlcIndex, c.LocalBalance,
c.RemoteBalance, c.CommitFee, c.FeePerKw, c.CommitTx,
@ -2135,7 +2135,7 @@ func putChanCommitments(chanBucket *bolt.Bucket, channel *OpenChannel) error {
func putChanRevocationState(chanBucket *bolt.Bucket, channel *OpenChannel) error {
var b bytes.Buffer
err := writeElements(
err := WriteElements(
&b, channel.RemoteCurrentRevocation, channel.RevocationProducer,
channel.RevocationStore,
)
@ -2148,7 +2148,7 @@ func putChanRevocationState(chanBucket *bolt.Bucket, channel *OpenChannel) error
// If the next revocation is present, which is only the case after the
// FundingLocked message has been sent, then we'll write it to disk.
if channel.RemoteNextRevocation != nil {
err = writeElements(&b, channel.RemoteNextRevocation)
err = WriteElements(&b, channel.RemoteNextRevocation)
if err != nil {
return err
}
@ -2158,7 +2158,7 @@ func putChanRevocationState(chanBucket *bolt.Bucket, channel *OpenChannel) error
}
func readChanConfig(b io.Reader, c *ChannelConfig) error {
return readElements(b,
return ReadElements(b,
&c.DustLimit, &c.MaxPendingAmount, &c.ChanReserve,
&c.MinHTLC, &c.MaxAcceptedHtlcs, &c.CsvDelay,
&c.MultiSigKey, &c.RevocationBasePoint,
@ -2174,7 +2174,7 @@ func fetchChanInfo(chanBucket *bolt.Bucket, channel *OpenChannel) error {
}
r := bytes.NewReader(infoBytes)
if err := readElements(r,
if err := ReadElements(r,
&channel.ChanType, &channel.ChainHash, &channel.FundingOutpoint,
&channel.ShortChannelID, &channel.IsPending, &channel.IsInitiator,
&channel.ChanStatus, &channel.FundingBroadcastHeight,
@ -2187,7 +2187,7 @@ func fetchChanInfo(chanBucket *bolt.Bucket, channel *OpenChannel) error {
// For single funder channels that we initiated, read the funding txn.
if channel.ChanType == SingleFunder && channel.IsInitiator {
if err := readElement(r, &channel.FundingTxn); err != nil {
if err := ReadElement(r, &channel.FundingTxn); err != nil {
return err
}
}
@ -2207,7 +2207,7 @@ func fetchChanInfo(chanBucket *bolt.Bucket, channel *OpenChannel) error {
func deserializeChanCommit(r io.Reader) (ChannelCommitment, error) {
var c ChannelCommitment
err := readElements(r,
err := ReadElements(r,
&c.CommitHeight, &c.LocalLogIndex, &c.LocalHtlcIndex, &c.RemoteLogIndex,
&c.RemoteHtlcIndex, &c.LocalBalance, &c.RemoteBalance,
&c.CommitFee, &c.FeePerKw, &c.CommitTx, &c.CommitSig,
@ -2263,7 +2263,7 @@ func fetchChanRevocationState(chanBucket *bolt.Bucket, channel *OpenChannel) err
}
r := bytes.NewReader(revBytes)
err := readElements(
err := ReadElements(
r, &channel.RemoteCurrentRevocation, &channel.RevocationProducer,
&channel.RevocationStore,
)
@ -2279,7 +2279,7 @@ func fetchChanRevocationState(chanBucket *bolt.Bucket, channel *OpenChannel) err
// Otherwise we'll read the next revocation for the remote party which
// is always the last item within the buffer.
return readElements(r, &channel.RemoteNextRevocation)
return ReadElements(r, &channel.RemoteNextRevocation)
}
func deleteOpenChannel(chanBucket *bolt.Bucket, chanPointBytes []byte) error {

View File

@ -4,6 +4,7 @@ import (
"encoding/binary"
"fmt"
"io"
"net"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnwire"
@ -43,11 +44,24 @@ func readOutpoint(r io.Reader, o *wire.OutPoint) error {
return nil
}
// writeElement is a one-stop shop to write the big endian representation of
// UnknownElementType is an error returned when the codec is unable to encode or
// decode a particular type.
type UnknownElementType struct {
method string
element interface{}
}
// Error returns the name of the method that encountered the error, as well as
// the type that was unsupported.
func (e UnknownElementType) Error() string {
return fmt.Sprintf("Unknown type in %s: %T", e.method, e.element)
}
// WriteElement is a one-stop shop to write the big endian representation of
// any element which is to be serialized for storage on disk. The passed
// io.Writer should be backed by an appropriately sized byte slice, or be able
// to dynamically expand to accommodate additional data.
func writeElement(w io.Writer, element interface{}) error {
func WriteElement(w io.Writer, element interface{}) error {
switch e := element.(type) {
case keychain.KeyDescriptor:
if err := binary.Write(w, byteOrder, e.Family); err != nil {
@ -61,7 +75,7 @@ func writeElement(w io.Writer, element interface{}) error {
if err := binary.Write(w, byteOrder, true); err != nil {
}
return writeElement(w, e.PubKey)
return WriteElement(w, e.PubKey)
}
return binary.Write(w, byteOrder, false)
@ -163,18 +177,34 @@ func writeElement(w io.Writer, element interface{}) error {
return err
}
case net.Addr:
if err := serializeAddr(w, e); err != nil {
return err
}
case []net.Addr:
if err := WriteElement(w, uint32(len(e))); err != nil {
return err
}
for _, addr := range e {
if err := serializeAddr(w, addr); err != nil {
return err
}
}
default:
return fmt.Errorf("Unknown type in writeElement: %T", e)
return UnknownElementType{"WriteElement", e}
}
return nil
}
// writeElements is writes each element in the elements slice to the passed
// io.Writer using writeElement.
func writeElements(w io.Writer, elements ...interface{}) error {
// WriteElements is writes each element in the elements slice to the passed
// io.Writer using WriteElement.
func WriteElements(w io.Writer, elements ...interface{}) error {
for _, element := range elements {
err := writeElement(w, element)
err := WriteElement(w, element)
if err != nil {
return err
}
@ -182,9 +212,9 @@ func writeElements(w io.Writer, elements ...interface{}) error {
return nil
}
// readElement is a one-stop utility function to deserialize any datastructure
// ReadElement is a one-stop utility function to deserialize any datastructure
// encoded using the serialization format of the database.
func readElement(r io.Reader, element interface{}) error {
func ReadElement(r io.Reader, element interface{}) error {
switch e := element.(type) {
case *keychain.KeyDescriptor:
if err := binary.Read(r, byteOrder, &e.Family); err != nil {
@ -200,7 +230,7 @@ func readElement(r io.Reader, element interface{}) error {
}
if hasPubKey {
return readElement(r, &e.PubKey)
return ReadElement(r, &e.PubKey)
}
case *ChannelType:
@ -342,19 +372,41 @@ func readElement(r io.Reader, element interface{}) error {
return err
}
case *net.Addr:
addr, err := deserializeAddr(r)
if err != nil {
return err
}
*e = addr
case *[]net.Addr:
var numAddrs uint32
if err := ReadElement(r, &numAddrs); err != nil {
return err
}
*e = make([]net.Addr, numAddrs)
for i := uint32(0); i < numAddrs; i++ {
addr, err := deserializeAddr(r)
if err != nil {
return err
}
(*e)[i] = addr
}
default:
return fmt.Errorf("Unknown type in readElement: %T", e)
return UnknownElementType{"ReadElement", e}
}
return nil
}
// readElements deserializes a variable number of elements into the passed
// io.Reader, with each element being deserialized according to the readElement
// ReadElements deserializes a variable number of elements into the passed
// io.Reader, with each element being deserialized according to the ReadElement
// function.
func readElements(r io.Reader, elements ...interface{}) error {
func ReadElements(r io.Reader, elements ...interface{}) error {
for _, element := range elements {
err := readElement(r, element)
err := ReadElement(r, element)
if err != nil {
return err
}

View File

@ -83,7 +83,7 @@ type ForwardingEvent struct {
// io.Writer, using the expected DB format. Note that the timestamp isn't
// serialized as this will be the key value within the bucket.
func encodeForwardingEvent(w io.Writer, f *ForwardingEvent) error {
return writeElements(
return WriteElements(
w, f.IncomingChanID, f.OutgoingChanID, f.AmtIn, f.AmtOut,
)
}
@ -93,7 +93,7 @@ func encodeForwardingEvent(w io.Writer, f *ForwardingEvent) error {
// won't be decoded, as the caller is expected to set this due to the bucket
// structure of the forwarding log.
func decodeForwardingEvent(r io.Reader, f *ForwardingEvent) error {
return readElements(
return ReadElements(
r, &f.IncomingChanID, &f.OutgoingChanID, &f.AmtIn, &f.AmtOut,
)
}