lnwire: modify ReadMessage to no longer return the total bytes read
This commit modifies ReadMessage to no longer return the total bytes read as this value will now be calculated at a higher level. The io.Reader that’s passed to ReadMessage is expected to contain the _entire_ message rather than be a pointer into a stream that contains the message itself.
This commit is contained in:
parent
38d3c72dc8
commit
aa2ca81762
@ -111,7 +111,7 @@ func TestLightningWireProtocol(t *testing.T) {
|
||||
|
||||
// Finally, we'll deserialize the message from the written
|
||||
// buffer, and finally assert that the messages are equal.
|
||||
_, newMsg, err := ReadMessage(&b, 0)
|
||||
newMsg, err := ReadMessage(&b, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to read msg: %v", err)
|
||||
return false
|
||||
|
@ -167,21 +167,14 @@ func WriteMessage(w io.Writer, msg Message, pver uint32) (int, error) {
|
||||
return totalBytes, err
|
||||
}
|
||||
|
||||
// ReadMessage reads, validates, and parses the next bitcoin Message from r for
|
||||
// the provided protocol version. It returns the number of bytes read in
|
||||
// addition to the parsed Message and raw bytes which comprise the message.
|
||||
func ReadMessage(r io.Reader, pver uint32) (int, Message, error) {
|
||||
// TODO(roasbeef): need to explicitly enforce max message payload, or
|
||||
// just allow it to be done by the MaxPayloadLength?
|
||||
totalBytes := 0
|
||||
|
||||
// ReadMessage reads, validates, and parses the next Lightning message from r
|
||||
// for the provided protocol version.
|
||||
func ReadMessage(r io.Reader, pver uint32) (Message, error) {
|
||||
// First, we'll read out the first two bytes of the message so we can
|
||||
// create the proper empty message.
|
||||
var mType [2]byte
|
||||
n, err := io.ReadFull(r, mType[:])
|
||||
totalBytes += n
|
||||
if err != nil {
|
||||
return totalBytes, nil, err
|
||||
if _, err := io.ReadFull(r, mType[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msgType := MessageType(binary.BigEndian.Uint16(mType[:]))
|
||||
@ -190,12 +183,11 @@ func ReadMessage(r io.Reader, pver uint32) (int, Message, error) {
|
||||
// empty message type and decode the message into it.
|
||||
msg, err := makeEmptyMessage(msgType)
|
||||
if err != nil {
|
||||
return totalBytes, nil, err
|
||||
return nil, err
|
||||
}
|
||||
if err := msg.Decode(r, pver); err != nil {
|
||||
return totalBytes, nil, err
|
||||
return nil, err
|
||||
}
|
||||
totalBytes += int(msg.MaxPayloadLength(pver))
|
||||
|
||||
return totalBytes, msg, nil
|
||||
return msg, nil
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user