channeldb+htlcswitch: write wire messages using length prefix
In this commit, we modify the way we write wire messages across the entire database. We'll now ensure that we always write wire messages with a length prefix. We update the `codec.go` file to always write a 2 byte length prefix, this affects the way we write the `CommitDiff` and `LogUpdates` struct to disk, and the network results bucket in the switch as it includes a wire message.
This commit is contained in:
parent
4133b4d04e
commit
db66fef6cc
@ -1965,12 +1965,12 @@ func deserializeLogUpdates(r io.Reader) ([]LogUpdate, error) {
|
|||||||
return logUpdates, nil
|
return logUpdates, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func serializeCommitDiff(w io.Writer, diff *CommitDiff) error {
|
func serializeCommitDiff(w io.Writer, diff *CommitDiff) error { // nolint: dupl
|
||||||
if err := serializeChanCommit(w, &diff.Commitment); err != nil {
|
if err := serializeChanCommit(w, &diff.Commitment); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := diff.CommitSig.Encode(w, 0); err != nil {
|
if err := WriteElements(w, diff.CommitSig); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2016,10 +2016,16 @@ func deserializeCommitDiff(r io.Reader) (*CommitDiff, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
d.CommitSig = &lnwire.CommitSig{}
|
var msg lnwire.Message
|
||||||
if err := d.CommitSig.Decode(r, 0); err != nil {
|
if err := ReadElements(r, &msg); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
commitSig, ok := msg.(*lnwire.CommitSig)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("expected lnwire.CommitSig, instead "+
|
||||||
|
"read: %T", msg)
|
||||||
|
}
|
||||||
|
d.CommitSig = commitSig
|
||||||
|
|
||||||
d.LogUpdates, err = deserializeLogUpdates(r)
|
d.LogUpdates, err = deserializeLogUpdates(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package channeldb
|
package channeldb
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@ -178,7 +179,17 @@ func WriteElement(w io.Writer, element interface{}) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
case lnwire.Message:
|
case lnwire.Message:
|
||||||
if _, err := lnwire.WriteMessage(w, e, 0); err != nil {
|
var msgBuf bytes.Buffer
|
||||||
|
if _, err := lnwire.WriteMessage(&msgBuf, e, 0); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
msgLen := uint16(len(msgBuf.Bytes()))
|
||||||
|
if err := WriteElements(w, msgLen); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := w.Write(msgBuf.Bytes()); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -394,7 +405,13 @@ func ReadElement(r io.Reader, element interface{}) error {
|
|||||||
*e = bytes
|
*e = bytes
|
||||||
|
|
||||||
case *lnwire.Message:
|
case *lnwire.Message:
|
||||||
msg, err := lnwire.ReadMessage(r, 0)
|
var msgLen uint16
|
||||||
|
if err := ReadElement(r, &msgLen); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
msgReader := io.LimitReader(r, int64(msgLen))
|
||||||
|
msg, err := lnwire.ReadMessage(msgReader, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -61,28 +61,15 @@ type networkResult struct {
|
|||||||
|
|
||||||
// serializeNetworkResult serializes the networkResult.
|
// serializeNetworkResult serializes the networkResult.
|
||||||
func serializeNetworkResult(w io.Writer, n *networkResult) error {
|
func serializeNetworkResult(w io.Writer, n *networkResult) error {
|
||||||
if _, err := lnwire.WriteMessage(w, n.msg, 0); err != nil {
|
return channeldb.WriteElements(w, n.msg, n.unencrypted, n.isResolution)
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return channeldb.WriteElements(w, n.unencrypted, n.isResolution)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// deserializeNetworkResult deserializes the networkResult.
|
// deserializeNetworkResult deserializes the networkResult.
|
||||||
func deserializeNetworkResult(r io.Reader) (*networkResult, error) {
|
func deserializeNetworkResult(r io.Reader) (*networkResult, error) {
|
||||||
var (
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
|
|
||||||
n := &networkResult{}
|
n := &networkResult{}
|
||||||
|
|
||||||
n.msg, err = lnwire.ReadMessage(r, 0)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := channeldb.ReadElements(r,
|
if err := channeldb.ReadElements(r,
|
||||||
&n.unencrypted, &n.isResolution,
|
&n.msg, &n.unencrypted, &n.isResolution,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user