diff --git a/channeldb/channel.go b/channeldb/channel.go index 3bd014ae..d36ded21 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -1965,12 +1965,12 @@ func deserializeLogUpdates(r io.Reader) ([]LogUpdate, error) { return logUpdates, nil } -func serializeCommitDiff(w io.Writer, diff *CommitDiff) error { +func serializeCommitDiff(w io.Writer, diff *CommitDiff) error { // nolint: dupl if err := serializeChanCommit(w, &diff.Commitment); err != nil { return err } - if err := diff.CommitSig.Encode(w, 0); err != nil { + if err := WriteElements(w, diff.CommitSig); err != nil { return err } @@ -2016,10 +2016,16 @@ func deserializeCommitDiff(r io.Reader) (*CommitDiff, error) { return nil, err } - d.CommitSig = &lnwire.CommitSig{} - if err := d.CommitSig.Decode(r, 0); err != nil { + var msg lnwire.Message + if err := ReadElements(r, &msg); err != nil { return nil, err } + commitSig, ok := msg.(*lnwire.CommitSig) + if !ok { + return nil, fmt.Errorf("expected lnwire.CommitSig, instead "+ + "read: %T", msg) + } + d.CommitSig = commitSig d.LogUpdates, err = deserializeLogUpdates(r) if err != nil { diff --git a/channeldb/codec.go b/channeldb/codec.go index f6903175..424f7c6e 100644 --- a/channeldb/codec.go +++ b/channeldb/codec.go @@ -1,6 +1,7 @@ package channeldb import ( + "bytes" "encoding/binary" "fmt" "io" @@ -178,7 +179,17 @@ func WriteElement(w io.Writer, element interface{}) error { } case lnwire.Message: - if _, err := lnwire.WriteMessage(w, e, 0); err != nil { + var msgBuf bytes.Buffer + if _, err := lnwire.WriteMessage(&msgBuf, e, 0); err != nil { + return err + } + + msgLen := uint16(len(msgBuf.Bytes())) + if err := WriteElements(w, msgLen); err != nil { + return err + } + + if _, err := w.Write(msgBuf.Bytes()); err != nil { return err } @@ -394,7 +405,13 @@ func ReadElement(r io.Reader, element interface{}) error { *e = bytes case *lnwire.Message: - msg, err := lnwire.ReadMessage(r, 0) + var msgLen uint16 + if err := ReadElement(r, &msgLen); err != nil { + return err + } + + msgReader := io.LimitReader(r, int64(msgLen)) + msg, err := lnwire.ReadMessage(msgReader, 0) if err != nil { return err } diff --git a/htlcswitch/payment_result.go b/htlcswitch/payment_result.go index cd5fe0a5..06345aff 100644 --- a/htlcswitch/payment_result.go +++ b/htlcswitch/payment_result.go @@ -61,28 +61,15 @@ type networkResult struct { // serializeNetworkResult serializes the networkResult. func serializeNetworkResult(w io.Writer, n *networkResult) error { - if _, err := lnwire.WriteMessage(w, n.msg, 0); err != nil { - return err - } - - return channeldb.WriteElements(w, n.unencrypted, n.isResolution) + return channeldb.WriteElements(w, n.msg, n.unencrypted, n.isResolution) } // deserializeNetworkResult deserializes the networkResult. func deserializeNetworkResult(r io.Reader) (*networkResult, error) { - var ( - err error - ) - n := &networkResult{} - n.msg, err = lnwire.ReadMessage(r, 0) - if err != nil { - return nil, err - } - if err := channeldb.ReadElements(r, - &n.unencrypted, &n.isResolution, + &n.msg, &n.unencrypted, &n.isResolution, ); err != nil { return nil, err }