lnwire: modify ReadMessage to no longer return the total bytes read

This commit modifies ReadMessage to no longer return the total bytes
read as this value will now be calculated at a higher level. The
io.Reader that’s passed to ReadMessage is expected to contain the
_entire_ message rather than be a pointer into a stream that contains
the message itself.
This commit is contained in:
Olaoluwa Osuntokun 2017-04-20 15:41:43 -07:00
parent 38d3c72dc8
commit aa2ca81762
No known key found for this signature in database
GPG Key ID: 9CC5B105D03521A2
2 changed files with 9 additions and 17 deletions

@ -111,7 +111,7 @@ func TestLightningWireProtocol(t *testing.T) {
// Finally, we'll deserialize the message from the written
// buffer, and finally assert that the messages are equal.
_, newMsg, err := ReadMessage(&b, 0)
newMsg, err := ReadMessage(&b, 0)
if err != nil {
t.Fatalf("unable to read msg: %v", err)
return false

@ -167,21 +167,14 @@ func WriteMessage(w io.Writer, msg Message, pver uint32) (int, error) {
return totalBytes, err
}
// ReadMessage reads, validates, and parses the next bitcoin Message from r for
// the provided protocol version. It returns the number of bytes read in
// addition to the parsed Message and raw bytes which comprise the message.
func ReadMessage(r io.Reader, pver uint32) (int, Message, error) {
// TODO(roasbeef): need to explicitly enforce max message payload, or
// just allow it to be done by the MaxPayloadLength?
totalBytes := 0
// ReadMessage reads, validates, and parses the next Lightning message from r
// for the provided protocol version.
func ReadMessage(r io.Reader, pver uint32) (Message, error) {
// First, we'll read out the first two bytes of the message so we can
// create the proper empty message.
var mType [2]byte
n, err := io.ReadFull(r, mType[:])
totalBytes += n
if err != nil {
return totalBytes, nil, err
if _, err := io.ReadFull(r, mType[:]); err != nil {
return nil, err
}
msgType := MessageType(binary.BigEndian.Uint16(mType[:]))
@ -190,12 +183,11 @@ func ReadMessage(r io.Reader, pver uint32) (int, Message, error) {
// empty message type and decode the message into it.
msg, err := makeEmptyMessage(msgType)
if err != nil {
return totalBytes, nil, err
return nil, err
}
if err := msg.Decode(r, pver); err != nil {
return totalBytes, nil, err
return nil, err
}
totalBytes += int(msg.MaxPayloadLength(pver))
return totalBytes, msg, nil
return msg, nil
}