From aa2ca81762da3a383df95cd0c63f633bcc45d43f Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Thu, 20 Apr 2017 15:41:43 -0700 Subject: [PATCH] lnwire: modify ReadMessage to no longer return the total bytes read MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit modifies ReadMessage to no longer return the total bytes read as this value will now be calculated at a higher level. The io.Reader that’s passed to ReadMessage is expected to contain the _entire_ message rather than be a pointer into a stream that contains the message itself. --- lnwire/lnwire_test.go | 2 +- lnwire/message.go | 24 ++++++++---------------- 2 files changed, 9 insertions(+), 17 deletions(-) diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index b9e2415a..67af3696 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -111,7 +111,7 @@ func TestLightningWireProtocol(t *testing.T) { // Finally, we'll deserialize the message from the written // buffer, and finally assert that the messages are equal. - _, newMsg, err := ReadMessage(&b, 0) + newMsg, err := ReadMessage(&b, 0) if err != nil { t.Fatalf("unable to read msg: %v", err) return false diff --git a/lnwire/message.go b/lnwire/message.go index 59e230ed..3b604b66 100644 --- a/lnwire/message.go +++ b/lnwire/message.go @@ -167,21 +167,14 @@ func WriteMessage(w io.Writer, msg Message, pver uint32) (int, error) { return totalBytes, err } -// ReadMessage reads, validates, and parses the next bitcoin Message from r for -// the provided protocol version. It returns the number of bytes read in -// addition to the parsed Message and raw bytes which comprise the message. -func ReadMessage(r io.Reader, pver uint32) (int, Message, error) { - // TODO(roasbeef): need to explicitly enforce max message payload, or - // just allow it to be done by the MaxPayloadLength? - totalBytes := 0 - +// ReadMessage reads, validates, and parses the next Lightning message from r +// for the provided protocol version. +func ReadMessage(r io.Reader, pver uint32) (Message, error) { // First, we'll read out the first two bytes of the message so we can // create the proper empty message. var mType [2]byte - n, err := io.ReadFull(r, mType[:]) - totalBytes += n - if err != nil { - return totalBytes, nil, err + if _, err := io.ReadFull(r, mType[:]); err != nil { + return nil, err } msgType := MessageType(binary.BigEndian.Uint16(mType[:])) @@ -190,12 +183,11 @@ func ReadMessage(r io.Reader, pver uint32) (int, Message, error) { // empty message type and decode the message into it. msg, err := makeEmptyMessage(msgType) if err != nil { - return totalBytes, nil, err + return nil, err } if err := msg.Decode(r, pver); err != nil { - return totalBytes, nil, err + return nil, err } - totalBytes += int(msg.MaxPayloadLength(pver)) - return totalBytes, msg, nil + return msg, nil }