diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index b9e2415a..67af3696 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -111,7 +111,7 @@ func TestLightningWireProtocol(t *testing.T) { // Finally, we'll deserialize the message from the written // buffer, and finally assert that the messages are equal. - _, newMsg, err := ReadMessage(&b, 0) + newMsg, err := ReadMessage(&b, 0) if err != nil { t.Fatalf("unable to read msg: %v", err) return false diff --git a/lnwire/message.go b/lnwire/message.go index 59e230ed..3b604b66 100644 --- a/lnwire/message.go +++ b/lnwire/message.go @@ -167,21 +167,14 @@ func WriteMessage(w io.Writer, msg Message, pver uint32) (int, error) { return totalBytes, err } -// ReadMessage reads, validates, and parses the next bitcoin Message from r for -// the provided protocol version. It returns the number of bytes read in -// addition to the parsed Message and raw bytes which comprise the message. -func ReadMessage(r io.Reader, pver uint32) (int, Message, error) { - // TODO(roasbeef): need to explicitly enforce max message payload, or - // just allow it to be done by the MaxPayloadLength? - totalBytes := 0 - +// ReadMessage reads, validates, and parses the next Lightning message from r +// for the provided protocol version. +func ReadMessage(r io.Reader, pver uint32) (Message, error) { // First, we'll read out the first two bytes of the message so we can // create the proper empty message. var mType [2]byte - n, err := io.ReadFull(r, mType[:]) - totalBytes += n - if err != nil { - return totalBytes, nil, err + if _, err := io.ReadFull(r, mType[:]); err != nil { + return nil, err } msgType := MessageType(binary.BigEndian.Uint16(mType[:])) @@ -190,12 +183,11 @@ func ReadMessage(r io.Reader, pver uint32) (int, Message, error) { // empty message type and decode the message into it. msg, err := makeEmptyMessage(msgType) if err != nil { - return totalBytes, nil, err + return nil, err } if err := msg.Decode(r, pver); err != nil { - return totalBytes, nil, err + return nil, err } - totalBytes += int(msg.MaxPayloadLength(pver)) - return totalBytes, msg, nil + return msg, nil }