From db66fef6cc0ea4efccc46544004557fb2d17400e Mon Sep 17 00:00:00 2001 From: "Johan T. Halseth" Date: Fri, 19 Feb 2021 12:03:01 +0100 Subject: [PATCH] channeldb+htlcswitch: write wire messages using length prefix In this commit, we modify the way we write wire messages across the entire database. We'll now ensure that we always write wire messages with a length prefix. We update the `codec.go` file to always write a 2 byte length prefix, this affects the way we write the `CommitDiff` and `LogUpdates` struct to disk, and the network results bucket in the switch as it includes a wire message. --- channeldb/channel.go | 14 ++++++++++---- channeldb/codec.go | 21 +++++++++++++++++++-- htlcswitch/payment_result.go | 17 ++--------------- 3 files changed, 31 insertions(+), 21 deletions(-) diff --git a/channeldb/channel.go b/channeldb/channel.go index 3bd014ae..d36ded21 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -1965,12 +1965,12 @@ func deserializeLogUpdates(r io.Reader) ([]LogUpdate, error) { return logUpdates, nil } -func serializeCommitDiff(w io.Writer, diff *CommitDiff) error { +func serializeCommitDiff(w io.Writer, diff *CommitDiff) error { // nolint: dupl if err := serializeChanCommit(w, &diff.Commitment); err != nil { return err } - if err := diff.CommitSig.Encode(w, 0); err != nil { + if err := WriteElements(w, diff.CommitSig); err != nil { return err } @@ -2016,10 +2016,16 @@ func deserializeCommitDiff(r io.Reader) (*CommitDiff, error) { return nil, err } - d.CommitSig = &lnwire.CommitSig{} - if err := d.CommitSig.Decode(r, 0); err != nil { + var msg lnwire.Message + if err := ReadElements(r, &msg); err != nil { return nil, err } + commitSig, ok := msg.(*lnwire.CommitSig) + if !ok { + return nil, fmt.Errorf("expected lnwire.CommitSig, instead "+ + "read: %T", msg) + } + d.CommitSig = commitSig d.LogUpdates, err = deserializeLogUpdates(r) if err != nil { diff --git a/channeldb/codec.go b/channeldb/codec.go index f6903175..424f7c6e 100644 --- a/channeldb/codec.go +++ b/channeldb/codec.go @@ -1,6 +1,7 @@ package channeldb import ( + "bytes" "encoding/binary" "fmt" "io" @@ -178,7 +179,17 @@ func WriteElement(w io.Writer, element interface{}) error { } case lnwire.Message: - if _, err := lnwire.WriteMessage(w, e, 0); err != nil { + var msgBuf bytes.Buffer + if _, err := lnwire.WriteMessage(&msgBuf, e, 0); err != nil { + return err + } + + msgLen := uint16(len(msgBuf.Bytes())) + if err := WriteElements(w, msgLen); err != nil { + return err + } + + if _, err := w.Write(msgBuf.Bytes()); err != nil { return err } @@ -394,7 +405,13 @@ func ReadElement(r io.Reader, element interface{}) error { *e = bytes case *lnwire.Message: - msg, err := lnwire.ReadMessage(r, 0) + var msgLen uint16 + if err := ReadElement(r, &msgLen); err != nil { + return err + } + + msgReader := io.LimitReader(r, int64(msgLen)) + msg, err := lnwire.ReadMessage(msgReader, 0) if err != nil { return err } diff --git a/htlcswitch/payment_result.go b/htlcswitch/payment_result.go index cd5fe0a5..06345aff 100644 --- a/htlcswitch/payment_result.go +++ b/htlcswitch/payment_result.go @@ -61,28 +61,15 @@ type networkResult struct { // serializeNetworkResult serializes the networkResult. func serializeNetworkResult(w io.Writer, n *networkResult) error { - if _, err := lnwire.WriteMessage(w, n.msg, 0); err != nil { - return err - } - - return channeldb.WriteElements(w, n.unencrypted, n.isResolution) + return channeldb.WriteElements(w, n.msg, n.unencrypted, n.isResolution) } // deserializeNetworkResult deserializes the networkResult. func deserializeNetworkResult(r io.Reader) (*networkResult, error) { - var ( - err error - ) - n := &networkResult{} - n.msg, err = lnwire.ReadMessage(r, 0) - if err != nil { - return nil, err - } - if err := channeldb.ReadElements(r, - &n.unencrypted, &n.isResolution, + &n.msg, &n.unencrypted, &n.isResolution, ); err != nil { return nil, err }