Merge pull request #1458 from cfromknecht/add-addrs-to-codec

channeldb: expand codec to include net.Addr types
This commit is contained in:
Olaoluwa Osuntokun 2018-07-03 22:50:22 -05:00 committed by GitHub
commit a0b2fadea3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 151 additions and 74 deletions

@ -192,7 +192,7 @@ func serializeAddr(w io.Writer, address net.Addr) error {
return encodeTCPAddr(w, addr) return encodeTCPAddr(w, addr)
case *tor.OnionAddr: case *tor.OnionAddr:
return encodeOnionAddr(w, addr) return encodeOnionAddr(w, addr)
default:
return ErrUnknownAddressType
} }
return nil
} }

@ -8,34 +8,59 @@ import (
"github.com/lightningnetwork/lnd/tor" "github.com/lightningnetwork/lnd/tor"
) )
type unknownAddrType struct{}
func (_ unknownAddrType) Network() string { return "unknown" }
func (_ unknownAddrType) String() string { return "unknown" }
var addrTests = []struct {
expAddr net.Addr
serErr error
}{
{
expAddr: &net.TCPAddr{
IP: net.ParseIP("192.168.1.1"),
Port: 12345,
},
},
{
expAddr: &net.TCPAddr{
IP: net.ParseIP("2001:0db8:0000:0000:0000:ff00:0042:8329"),
Port: 65535,
},
},
{
expAddr: &tor.OnionAddr{
OnionService: "3g2upl4pq6kufc4m.onion",
Port: 9735,
},
},
{
expAddr: &tor.OnionAddr{
OnionService: "vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyyd.onion",
Port: 80,
},
},
{
expAddr: unknownAddrType{},
serErr: ErrUnknownAddressType,
},
}
// TestAddrSerialization tests that the serialization method used by channeldb // TestAddrSerialization tests that the serialization method used by channeldb
// for net.Addr's works as intended. // for net.Addr's works as intended.
func TestAddrSerialization(t *testing.T) { func TestAddrSerialization(t *testing.T) {
t.Parallel() t.Parallel()
testAddrs := []net.Addr{
&net.TCPAddr{
IP: net.ParseIP("192.168.1.1"),
Port: 12345,
},
&net.TCPAddr{
IP: net.ParseIP("2001:0db8:0000:0000:0000:ff00:0042:8329"),
Port: 65535,
},
&tor.OnionAddr{
OnionService: "3g2upl4pq6kufc4m.onion",
Port: 9735,
},
&tor.OnionAddr{
OnionService: "vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyyd.onion",
Port: 80,
},
}
var b bytes.Buffer var b bytes.Buffer
for _, expectedAddr := range testAddrs { for _, test := range addrTests {
if err := serializeAddr(&b, expectedAddr); err != nil { err := serializeAddr(&b, test.expAddr)
t.Fatalf("unable to serialize address: %v", err) if err != test.serErr {
t.Fatalf("unexpected serialization err for addr %v, "+
"want: %v, got %v",
test.expAddr, test.serErr, err)
} else if test.serErr != nil {
continue
} }
addr, err := deserializeAddr(&b) addr, err := deserializeAddr(&b)
@ -43,9 +68,9 @@ func TestAddrSerialization(t *testing.T) {
t.Fatalf("unable to deserialize address: %v", err) t.Fatalf("unable to deserialize address: %v", err)
} }
if addr.String() != expectedAddr.String() { if addr.String() != test.expAddr.String() {
t.Fatalf("expected address %v after serialization, "+ t.Fatalf("expected address %v after serialization, "+
"got %v", addr, expectedAddr) "got %v", addr, test.expAddr)
} }
} }
} }

@ -916,12 +916,12 @@ type HTLC struct {
// future. // future.
func SerializeHtlcs(b io.Writer, htlcs ...HTLC) error { func SerializeHtlcs(b io.Writer, htlcs ...HTLC) error {
numHtlcs := uint16(len(htlcs)) numHtlcs := uint16(len(htlcs))
if err := writeElement(b, numHtlcs); err != nil { if err := WriteElement(b, numHtlcs); err != nil {
return err return err
} }
for _, htlc := range htlcs { for _, htlc := range htlcs {
if err := writeElements(b, if err := WriteElements(b,
htlc.Signature, htlc.RHash, htlc.Amt, htlc.RefundTimeout, htlc.Signature, htlc.RHash, htlc.Amt, htlc.RefundTimeout,
htlc.OutputIndex, htlc.Incoming, htlc.OnionBlob[:], htlc.OutputIndex, htlc.Incoming, htlc.OnionBlob[:],
htlc.HtlcIndex, htlc.LogIndex, htlc.HtlcIndex, htlc.LogIndex,
@ -941,7 +941,7 @@ func SerializeHtlcs(b io.Writer, htlcs ...HTLC) error {
// future. // future.
func DeserializeHtlcs(r io.Reader) ([]HTLC, error) { func DeserializeHtlcs(r io.Reader) ([]HTLC, error) {
var numHtlcs uint16 var numHtlcs uint16
if err := readElement(r, &numHtlcs); err != nil { if err := ReadElement(r, &numHtlcs); err != nil {
return nil, err return nil, err
} }
@ -952,7 +952,7 @@ func DeserializeHtlcs(r io.Reader) ([]HTLC, error) {
htlcs = make([]HTLC, numHtlcs) htlcs = make([]HTLC, numHtlcs)
for i := uint16(0); i < numHtlcs; i++ { for i := uint16(0); i < numHtlcs; i++ {
if err := readElements(r, if err := ReadElements(r,
&htlcs[i].Signature, &htlcs[i].RHash, &htlcs[i].Amt, &htlcs[i].Signature, &htlcs[i].RHash, &htlcs[i].Amt,
&htlcs[i].RefundTimeout, &htlcs[i].OutputIndex, &htlcs[i].RefundTimeout, &htlcs[i].OutputIndex,
&htlcs[i].Incoming, &htlcs[i].OnionBlob, &htlcs[i].Incoming, &htlcs[i].OnionBlob,
@ -996,12 +996,12 @@ type LogUpdate struct {
// Encode writes a log update to the provided io.Writer. // Encode writes a log update to the provided io.Writer.
func (l *LogUpdate) Encode(w io.Writer) error { func (l *LogUpdate) Encode(w io.Writer) error {
return writeElements(w, l.LogIndex, l.UpdateMsg) return WriteElements(w, l.LogIndex, l.UpdateMsg)
} }
// Decode reads a log update from the provided io.Reader. // Decode reads a log update from the provided io.Reader.
func (l *LogUpdate) Decode(r io.Reader) error { func (l *LogUpdate) Decode(r io.Reader) error {
return readElements(r, &l.LogIndex, &l.UpdateMsg) return ReadElements(r, &l.LogIndex, &l.UpdateMsg)
} }
// CircuitKey is used by a channel to uniquely identify the HTLCs it receives // CircuitKey is used by a channel to uniquely identify the HTLCs it receives
@ -1142,7 +1142,7 @@ func serializeCommitDiff(w io.Writer, diff *CommitDiff) error {
} }
for _, diff := range diff.LogUpdates { for _, diff := range diff.LogUpdates {
err := writeElements(w, diff.LogIndex, diff.UpdateMsg) err := WriteElements(w, diff.LogIndex, diff.UpdateMsg)
if err != nil { if err != nil {
return err return err
} }
@ -1154,7 +1154,7 @@ func serializeCommitDiff(w io.Writer, diff *CommitDiff) error {
} }
for _, openRef := range diff.OpenedCircuitKeys { for _, openRef := range diff.OpenedCircuitKeys {
err := writeElements(w, openRef.ChanID, openRef.HtlcID) err := WriteElements(w, openRef.ChanID, openRef.HtlcID)
if err != nil { if err != nil {
return err return err
} }
@ -1166,7 +1166,7 @@ func serializeCommitDiff(w io.Writer, diff *CommitDiff) error {
} }
for _, closedRef := range diff.ClosedCircuitKeys { for _, closedRef := range diff.ClosedCircuitKeys {
err := writeElements(w, closedRef.ChanID, closedRef.HtlcID) err := WriteElements(w, closedRef.ChanID, closedRef.HtlcID)
if err != nil { if err != nil {
return err return err
} }
@ -1198,7 +1198,7 @@ func deserializeCommitDiff(r io.Reader) (*CommitDiff, error) {
d.LogUpdates = make([]LogUpdate, numUpdates) d.LogUpdates = make([]LogUpdate, numUpdates)
for i := 0; i < int(numUpdates); i++ { for i := 0; i < int(numUpdates); i++ {
err := readElements(r, err := ReadElements(r,
&d.LogUpdates[i].LogIndex, &d.LogUpdates[i].UpdateMsg, &d.LogUpdates[i].LogIndex, &d.LogUpdates[i].UpdateMsg,
) )
if err != nil { if err != nil {
@ -1213,7 +1213,7 @@ func deserializeCommitDiff(r io.Reader) (*CommitDiff, error) {
d.OpenedCircuitKeys = make([]CircuitKey, numOpenRefs) d.OpenedCircuitKeys = make([]CircuitKey, numOpenRefs)
for i := 0; i < int(numOpenRefs); i++ { for i := 0; i < int(numOpenRefs); i++ {
err := readElements(r, err := ReadElements(r,
&d.OpenedCircuitKeys[i].ChanID, &d.OpenedCircuitKeys[i].ChanID,
&d.OpenedCircuitKeys[i].HtlcID) &d.OpenedCircuitKeys[i].HtlcID)
if err != nil { if err != nil {
@ -1228,7 +1228,7 @@ func deserializeCommitDiff(r io.Reader) (*CommitDiff, error) {
d.ClosedCircuitKeys = make([]CircuitKey, numClosedRefs) d.ClosedCircuitKeys = make([]CircuitKey, numClosedRefs)
for i := 0; i < int(numClosedRefs); i++ { for i := 0; i < int(numClosedRefs); i++ {
err := readElements(r, err := ReadElements(r,
&d.ClosedCircuitKeys[i].ChanID, &d.ClosedCircuitKeys[i].ChanID,
&d.ClosedCircuitKeys[i].HtlcID) &d.ClosedCircuitKeys[i].HtlcID)
if err != nil { if err != nil {
@ -1957,7 +1957,7 @@ func putChannelCloseSummary(tx *bolt.Tx, chanID []byte,
} }
func serializeChannelCloseSummary(w io.Writer, cs *ChannelCloseSummary) error { func serializeChannelCloseSummary(w io.Writer, cs *ChannelCloseSummary) error {
err := writeElements(w, err := WriteElements(w,
cs.ChanPoint, cs.ShortChanID, cs.ChainHash, cs.ClosingTXID, cs.ChanPoint, cs.ShortChanID, cs.ChainHash, cs.ClosingTXID,
cs.CloseHeight, cs.RemotePub, cs.Capacity, cs.SettledBalance, cs.CloseHeight, cs.RemotePub, cs.Capacity, cs.SettledBalance,
cs.TimeLockedBalance, cs.CloseType, cs.IsPending, cs.TimeLockedBalance, cs.CloseType, cs.IsPending,
@ -1972,7 +1972,7 @@ func serializeChannelCloseSummary(w io.Writer, cs *ChannelCloseSummary) error {
return nil return nil
} }
if err := writeElements(w, cs.RemoteCurrentRevocation); err != nil { if err := WriteElements(w, cs.RemoteCurrentRevocation); err != nil {
return err return err
} }
@ -1987,7 +1987,7 @@ func serializeChannelCloseSummary(w io.Writer, cs *ChannelCloseSummary) error {
return nil return nil
} }
return writeElements(w, cs.RemoteNextRevocation) return WriteElements(w, cs.RemoteNextRevocation)
} }
func fetchChannelCloseSummary(tx *bolt.Tx, func fetchChannelCloseSummary(tx *bolt.Tx,
@ -2010,7 +2010,7 @@ func fetchChannelCloseSummary(tx *bolt.Tx,
func deserializeCloseChannelSummary(r io.Reader) (*ChannelCloseSummary, error) { func deserializeCloseChannelSummary(r io.Reader) (*ChannelCloseSummary, error) {
c := &ChannelCloseSummary{} c := &ChannelCloseSummary{}
err := readElements(r, err := ReadElements(r,
&c.ChanPoint, &c.ShortChanID, &c.ChainHash, &c.ClosingTXID, &c.ChanPoint, &c.ShortChanID, &c.ChainHash, &c.ClosingTXID,
&c.CloseHeight, &c.RemotePub, &c.Capacity, &c.SettledBalance, &c.CloseHeight, &c.RemotePub, &c.Capacity, &c.SettledBalance,
&c.TimeLockedBalance, &c.CloseType, &c.IsPending, &c.TimeLockedBalance, &c.CloseType, &c.IsPending,
@ -2021,7 +2021,7 @@ func deserializeCloseChannelSummary(r io.Reader) (*ChannelCloseSummary, error) {
// We'll now check to see if the channel close summary was encoded with // We'll now check to see if the channel close summary was encoded with
// any of the additional optional fields. // any of the additional optional fields.
err = readElements(r, &c.RemoteCurrentRevocation) err = ReadElements(r, &c.RemoteCurrentRevocation)
switch { switch {
case err == io.EOF: case err == io.EOF:
return c, nil return c, nil
@ -2042,7 +2042,7 @@ func deserializeCloseChannelSummary(r io.Reader) (*ChannelCloseSummary, error) {
// funding locked message, then this can be nil. As a result, we'll use // funding locked message, then this can be nil. As a result, we'll use
// the same technique to read the field, only if there's still data // the same technique to read the field, only if there's still data
// left in the buffer. // left in the buffer.
err = readElements(r, &c.RemoteNextRevocation) err = ReadElements(r, &c.RemoteNextRevocation)
if err != nil && err != io.EOF { if err != nil && err != io.EOF {
// If we got a non-eof error, then we know there's an actually // If we got a non-eof error, then we know there's an actually
// issue. Otherwise, it may have been the case that this // issue. Otherwise, it may have been the case that this
@ -2054,7 +2054,7 @@ func deserializeCloseChannelSummary(r io.Reader) (*ChannelCloseSummary, error) {
} }
func writeChanConfig(b io.Writer, c *ChannelConfig) error { func writeChanConfig(b io.Writer, c *ChannelConfig) error {
return writeElements(b, return WriteElements(b,
c.DustLimit, c.MaxPendingAmount, c.ChanReserve, c.MinHTLC, c.DustLimit, c.MaxPendingAmount, c.ChanReserve, c.MinHTLC,
c.MaxAcceptedHtlcs, c.CsvDelay, c.MultiSigKey, c.MaxAcceptedHtlcs, c.CsvDelay, c.MultiSigKey,
c.RevocationBasePoint, c.PaymentBasePoint, c.DelayBasePoint, c.RevocationBasePoint, c.PaymentBasePoint, c.DelayBasePoint,
@ -2064,7 +2064,7 @@ func writeChanConfig(b io.Writer, c *ChannelConfig) error {
func putChanInfo(chanBucket *bolt.Bucket, channel *OpenChannel) error { func putChanInfo(chanBucket *bolt.Bucket, channel *OpenChannel) error {
var w bytes.Buffer var w bytes.Buffer
if err := writeElements(&w, if err := WriteElements(&w,
channel.ChanType, channel.ChainHash, channel.FundingOutpoint, channel.ChanType, channel.ChainHash, channel.FundingOutpoint,
channel.ShortChannelID, channel.IsPending, channel.IsInitiator, channel.ShortChannelID, channel.IsPending, channel.IsInitiator,
channel.ChanStatus, channel.FundingBroadcastHeight, channel.ChanStatus, channel.FundingBroadcastHeight,
@ -2077,7 +2077,7 @@ func putChanInfo(chanBucket *bolt.Bucket, channel *OpenChannel) error {
// For single funder channels that we initiated, write the funding txn. // For single funder channels that we initiated, write the funding txn.
if channel.ChanType == SingleFunder && channel.IsInitiator { if channel.ChanType == SingleFunder && channel.IsInitiator {
if err := writeElement(&w, channel.FundingTxn); err != nil { if err := WriteElement(&w, channel.FundingTxn); err != nil {
return err return err
} }
} }
@ -2093,7 +2093,7 @@ func putChanInfo(chanBucket *bolt.Bucket, channel *OpenChannel) error {
} }
func serializeChanCommit(w io.Writer, c *ChannelCommitment) error { func serializeChanCommit(w io.Writer, c *ChannelCommitment) error {
if err := writeElements(w, if err := WriteElements(w,
c.CommitHeight, c.LocalLogIndex, c.LocalHtlcIndex, c.CommitHeight, c.LocalLogIndex, c.LocalHtlcIndex,
c.RemoteLogIndex, c.RemoteHtlcIndex, c.LocalBalance, c.RemoteLogIndex, c.RemoteHtlcIndex, c.LocalBalance,
c.RemoteBalance, c.CommitFee, c.FeePerKw, c.CommitTx, c.RemoteBalance, c.CommitFee, c.FeePerKw, c.CommitTx,
@ -2135,7 +2135,7 @@ func putChanCommitments(chanBucket *bolt.Bucket, channel *OpenChannel) error {
func putChanRevocationState(chanBucket *bolt.Bucket, channel *OpenChannel) error { func putChanRevocationState(chanBucket *bolt.Bucket, channel *OpenChannel) error {
var b bytes.Buffer var b bytes.Buffer
err := writeElements( err := WriteElements(
&b, channel.RemoteCurrentRevocation, channel.RevocationProducer, &b, channel.RemoteCurrentRevocation, channel.RevocationProducer,
channel.RevocationStore, channel.RevocationStore,
) )
@ -2148,7 +2148,7 @@ func putChanRevocationState(chanBucket *bolt.Bucket, channel *OpenChannel) error
// If the next revocation is present, which is only the case after the // If the next revocation is present, which is only the case after the
// FundingLocked message has been sent, then we'll write it to disk. // FundingLocked message has been sent, then we'll write it to disk.
if channel.RemoteNextRevocation != nil { if channel.RemoteNextRevocation != nil {
err = writeElements(&b, channel.RemoteNextRevocation) err = WriteElements(&b, channel.RemoteNextRevocation)
if err != nil { if err != nil {
return err return err
} }
@ -2158,7 +2158,7 @@ func putChanRevocationState(chanBucket *bolt.Bucket, channel *OpenChannel) error
} }
func readChanConfig(b io.Reader, c *ChannelConfig) error { func readChanConfig(b io.Reader, c *ChannelConfig) error {
return readElements(b, return ReadElements(b,
&c.DustLimit, &c.MaxPendingAmount, &c.ChanReserve, &c.DustLimit, &c.MaxPendingAmount, &c.ChanReserve,
&c.MinHTLC, &c.MaxAcceptedHtlcs, &c.CsvDelay, &c.MinHTLC, &c.MaxAcceptedHtlcs, &c.CsvDelay,
&c.MultiSigKey, &c.RevocationBasePoint, &c.MultiSigKey, &c.RevocationBasePoint,
@ -2174,7 +2174,7 @@ func fetchChanInfo(chanBucket *bolt.Bucket, channel *OpenChannel) error {
} }
r := bytes.NewReader(infoBytes) r := bytes.NewReader(infoBytes)
if err := readElements(r, if err := ReadElements(r,
&channel.ChanType, &channel.ChainHash, &channel.FundingOutpoint, &channel.ChanType, &channel.ChainHash, &channel.FundingOutpoint,
&channel.ShortChannelID, &channel.IsPending, &channel.IsInitiator, &channel.ShortChannelID, &channel.IsPending, &channel.IsInitiator,
&channel.ChanStatus, &channel.FundingBroadcastHeight, &channel.ChanStatus, &channel.FundingBroadcastHeight,
@ -2187,7 +2187,7 @@ func fetchChanInfo(chanBucket *bolt.Bucket, channel *OpenChannel) error {
// For single funder channels that we initiated, read the funding txn. // For single funder channels that we initiated, read the funding txn.
if channel.ChanType == SingleFunder && channel.IsInitiator { if channel.ChanType == SingleFunder && channel.IsInitiator {
if err := readElement(r, &channel.FundingTxn); err != nil { if err := ReadElement(r, &channel.FundingTxn); err != nil {
return err return err
} }
} }
@ -2207,7 +2207,7 @@ func fetchChanInfo(chanBucket *bolt.Bucket, channel *OpenChannel) error {
func deserializeChanCommit(r io.Reader) (ChannelCommitment, error) { func deserializeChanCommit(r io.Reader) (ChannelCommitment, error) {
var c ChannelCommitment var c ChannelCommitment
err := readElements(r, err := ReadElements(r,
&c.CommitHeight, &c.LocalLogIndex, &c.LocalHtlcIndex, &c.RemoteLogIndex, &c.CommitHeight, &c.LocalLogIndex, &c.LocalHtlcIndex, &c.RemoteLogIndex,
&c.RemoteHtlcIndex, &c.LocalBalance, &c.RemoteBalance, &c.RemoteHtlcIndex, &c.LocalBalance, &c.RemoteBalance,
&c.CommitFee, &c.FeePerKw, &c.CommitTx, &c.CommitSig, &c.CommitFee, &c.FeePerKw, &c.CommitTx, &c.CommitSig,
@ -2263,7 +2263,7 @@ func fetchChanRevocationState(chanBucket *bolt.Bucket, channel *OpenChannel) err
} }
r := bytes.NewReader(revBytes) r := bytes.NewReader(revBytes)
err := readElements( err := ReadElements(
r, &channel.RemoteCurrentRevocation, &channel.RevocationProducer, r, &channel.RemoteCurrentRevocation, &channel.RevocationProducer,
&channel.RevocationStore, &channel.RevocationStore,
) )
@ -2279,7 +2279,7 @@ func fetchChanRevocationState(chanBucket *bolt.Bucket, channel *OpenChannel) err
// Otherwise we'll read the next revocation for the remote party which // Otherwise we'll read the next revocation for the remote party which
// is always the last item within the buffer. // is always the last item within the buffer.
return readElements(r, &channel.RemoteNextRevocation) return ReadElements(r, &channel.RemoteNextRevocation)
} }
func deleteOpenChannel(chanBucket *bolt.Bucket, chanPointBytes []byte) error { func deleteOpenChannel(chanBucket *bolt.Bucket, chanPointBytes []byte) error {

@ -4,6 +4,7 @@ import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"io" "io"
"net"
"github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
@ -43,11 +44,24 @@ func readOutpoint(r io.Reader, o *wire.OutPoint) error {
return nil return nil
} }
// writeElement is a one-stop shop to write the big endian representation of // UnknownElementType is an error returned when the codec is unable to encode or
// decode a particular type.
type UnknownElementType struct {
method string
element interface{}
}
// Error returns the name of the method that encountered the error, as well as
// the type that was unsupported.
func (e UnknownElementType) Error() string {
return fmt.Sprintf("Unknown type in %s: %T", e.method, e.element)
}
// WriteElement is a one-stop shop to write the big endian representation of
// any element which is to be serialized for storage on disk. The passed // any element which is to be serialized for storage on disk. The passed
// io.Writer should be backed by an appropriately sized byte slice, or be able // io.Writer should be backed by an appropriately sized byte slice, or be able
// to dynamically expand to accommodate additional data. // to dynamically expand to accommodate additional data.
func writeElement(w io.Writer, element interface{}) error { func WriteElement(w io.Writer, element interface{}) error {
switch e := element.(type) { switch e := element.(type) {
case keychain.KeyDescriptor: case keychain.KeyDescriptor:
if err := binary.Write(w, byteOrder, e.Family); err != nil { if err := binary.Write(w, byteOrder, e.Family); err != nil {
@ -61,7 +75,7 @@ func writeElement(w io.Writer, element interface{}) error {
if err := binary.Write(w, byteOrder, true); err != nil { if err := binary.Write(w, byteOrder, true); err != nil {
} }
return writeElement(w, e.PubKey) return WriteElement(w, e.PubKey)
} }
return binary.Write(w, byteOrder, false) return binary.Write(w, byteOrder, false)
@ -163,18 +177,34 @@ func writeElement(w io.Writer, element interface{}) error {
return err return err
} }
case net.Addr:
if err := serializeAddr(w, e); err != nil {
return err
}
case []net.Addr:
if err := WriteElement(w, uint32(len(e))); err != nil {
return err
}
for _, addr := range e {
if err := serializeAddr(w, addr); err != nil {
return err
}
}
default: default:
return fmt.Errorf("Unknown type in writeElement: %T", e) return UnknownElementType{"WriteElement", e}
} }
return nil return nil
} }
// writeElements is writes each element in the elements slice to the passed // WriteElements is writes each element in the elements slice to the passed
// io.Writer using writeElement. // io.Writer using WriteElement.
func writeElements(w io.Writer, elements ...interface{}) error { func WriteElements(w io.Writer, elements ...interface{}) error {
for _, element := range elements { for _, element := range elements {
err := writeElement(w, element) err := WriteElement(w, element)
if err != nil { if err != nil {
return err return err
} }
@ -182,9 +212,9 @@ func writeElements(w io.Writer, elements ...interface{}) error {
return nil return nil
} }
// readElement is a one-stop utility function to deserialize any datastructure // ReadElement is a one-stop utility function to deserialize any datastructure
// encoded using the serialization format of the database. // encoded using the serialization format of the database.
func readElement(r io.Reader, element interface{}) error { func ReadElement(r io.Reader, element interface{}) error {
switch e := element.(type) { switch e := element.(type) {
case *keychain.KeyDescriptor: case *keychain.KeyDescriptor:
if err := binary.Read(r, byteOrder, &e.Family); err != nil { if err := binary.Read(r, byteOrder, &e.Family); err != nil {
@ -200,7 +230,7 @@ func readElement(r io.Reader, element interface{}) error {
} }
if hasPubKey { if hasPubKey {
return readElement(r, &e.PubKey) return ReadElement(r, &e.PubKey)
} }
case *ChannelType: case *ChannelType:
@ -342,19 +372,41 @@ func readElement(r io.Reader, element interface{}) error {
return err return err
} }
case *net.Addr:
addr, err := deserializeAddr(r)
if err != nil {
return err
}
*e = addr
case *[]net.Addr:
var numAddrs uint32
if err := ReadElement(r, &numAddrs); err != nil {
return err
}
*e = make([]net.Addr, numAddrs)
for i := uint32(0); i < numAddrs; i++ {
addr, err := deserializeAddr(r)
if err != nil {
return err
}
(*e)[i] = addr
}
default: default:
return fmt.Errorf("Unknown type in readElement: %T", e) return UnknownElementType{"ReadElement", e}
} }
return nil return nil
} }
// readElements deserializes a variable number of elements into the passed // ReadElements deserializes a variable number of elements into the passed
// io.Reader, with each element being deserialized according to the readElement // io.Reader, with each element being deserialized according to the ReadElement
// function. // function.
func readElements(r io.Reader, elements ...interface{}) error { func ReadElements(r io.Reader, elements ...interface{}) error {
for _, element := range elements { for _, element := range elements {
err := readElement(r, element) err := ReadElement(r, element)
if err != nil { if err != nil {
return err return err
} }

@ -83,7 +83,7 @@ type ForwardingEvent struct {
// io.Writer, using the expected DB format. Note that the timestamp isn't // io.Writer, using the expected DB format. Note that the timestamp isn't
// serialized as this will be the key value within the bucket. // serialized as this will be the key value within the bucket.
func encodeForwardingEvent(w io.Writer, f *ForwardingEvent) error { func encodeForwardingEvent(w io.Writer, f *ForwardingEvent) error {
return writeElements( return WriteElements(
w, f.IncomingChanID, f.OutgoingChanID, f.AmtIn, f.AmtOut, w, f.IncomingChanID, f.OutgoingChanID, f.AmtIn, f.AmtOut,
) )
} }
@ -93,7 +93,7 @@ func encodeForwardingEvent(w io.Writer, f *ForwardingEvent) error {
// won't be decoded, as the caller is expected to set this due to the bucket // won't be decoded, as the caller is expected to set this due to the bucket
// structure of the forwarding log. // structure of the forwarding log.
func decodeForwardingEvent(r io.Reader, f *ForwardingEvent) error { func decodeForwardingEvent(r io.Reader, f *ForwardingEvent) error {
return readElements( return ReadElements(
r, &f.IncomingChanID, &f.OutgoingChanID, &f.AmtIn, &f.AmtOut, r, &f.IncomingChanID, &f.OutgoingChanID, &f.AmtIn, &f.AmtOut,
) )
} }