channeldb+htlcswitch: write wire messages using length prefix

In this commit, we modify the way we write wire messages across the
entire database. We'll now ensure that we always write wire messages
with a length prefix. We update the `codec.go` file to always write a 2
byte length prefix, this affects the way we write the `CommitDiff` and
`LogUpdates` struct to disk, and the network results bucket in the
switch as it includes a wire message.
This commit is contained in:
Johan T. Halseth 2021-02-19 12:03:01 +01:00
parent 4133b4d04e
commit db66fef6cc
No known key found for this signature in database
GPG Key ID: 15BAADA29DA20D26
3 changed files with 31 additions and 21 deletions

@ -1965,12 +1965,12 @@ func deserializeLogUpdates(r io.Reader) ([]LogUpdate, error) {
return logUpdates, nil return logUpdates, nil
} }
func serializeCommitDiff(w io.Writer, diff *CommitDiff) error { func serializeCommitDiff(w io.Writer, diff *CommitDiff) error { // nolint: dupl
if err := serializeChanCommit(w, &diff.Commitment); err != nil { if err := serializeChanCommit(w, &diff.Commitment); err != nil {
return err return err
} }
if err := diff.CommitSig.Encode(w, 0); err != nil { if err := WriteElements(w, diff.CommitSig); err != nil {
return err return err
} }
@ -2016,10 +2016,16 @@ func deserializeCommitDiff(r io.Reader) (*CommitDiff, error) {
return nil, err return nil, err
} }
d.CommitSig = &lnwire.CommitSig{} var msg lnwire.Message
if err := d.CommitSig.Decode(r, 0); err != nil { if err := ReadElements(r, &msg); err != nil {
return nil, err return nil, err
} }
commitSig, ok := msg.(*lnwire.CommitSig)
if !ok {
return nil, fmt.Errorf("expected lnwire.CommitSig, instead "+
"read: %T", msg)
}
d.CommitSig = commitSig
d.LogUpdates, err = deserializeLogUpdates(r) d.LogUpdates, err = deserializeLogUpdates(r)
if err != nil { if err != nil {

@ -1,6 +1,7 @@
package channeldb package channeldb
import ( import (
"bytes"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"io" "io"
@ -178,7 +179,17 @@ func WriteElement(w io.Writer, element interface{}) error {
} }
case lnwire.Message: case lnwire.Message:
if _, err := lnwire.WriteMessage(w, e, 0); err != nil { var msgBuf bytes.Buffer
if _, err := lnwire.WriteMessage(&msgBuf, e, 0); err != nil {
return err
}
msgLen := uint16(len(msgBuf.Bytes()))
if err := WriteElements(w, msgLen); err != nil {
return err
}
if _, err := w.Write(msgBuf.Bytes()); err != nil {
return err return err
} }
@ -394,7 +405,13 @@ func ReadElement(r io.Reader, element interface{}) error {
*e = bytes *e = bytes
case *lnwire.Message: case *lnwire.Message:
msg, err := lnwire.ReadMessage(r, 0) var msgLen uint16
if err := ReadElement(r, &msgLen); err != nil {
return err
}
msgReader := io.LimitReader(r, int64(msgLen))
msg, err := lnwire.ReadMessage(msgReader, 0)
if err != nil { if err != nil {
return err return err
} }

@ -61,28 +61,15 @@ type networkResult struct {
// serializeNetworkResult serializes the networkResult. // serializeNetworkResult serializes the networkResult.
func serializeNetworkResult(w io.Writer, n *networkResult) error { func serializeNetworkResult(w io.Writer, n *networkResult) error {
if _, err := lnwire.WriteMessage(w, n.msg, 0); err != nil { return channeldb.WriteElements(w, n.msg, n.unencrypted, n.isResolution)
return err
}
return channeldb.WriteElements(w, n.unencrypted, n.isResolution)
} }
// deserializeNetworkResult deserializes the networkResult. // deserializeNetworkResult deserializes the networkResult.
func deserializeNetworkResult(r io.Reader) (*networkResult, error) { func deserializeNetworkResult(r io.Reader) (*networkResult, error) {
var (
err error
)
n := &networkResult{} n := &networkResult{}
n.msg, err = lnwire.ReadMessage(r, 0)
if err != nil {
return nil, err
}
if err := channeldb.ReadElements(r, if err := channeldb.ReadElements(r,
&n.unencrypted, &n.isResolution, &n.msg, &n.unencrypted, &n.isResolution,
); err != nil { ); err != nil {
return nil, err return nil, err
} }