diff --git a/peer.go b/peer.go index 8bf81859..927247d1 100644 --- a/peer.go +++ b/peer.go @@ -13,6 +13,7 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/go-errors/errors" "github.com/lightningnetwork/lightning-onion" + "github.com/lightningnetwork/lnd/brontide" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnwallet" @@ -373,10 +374,25 @@ func (p *peer) String() string { // readNextMessage reads, and returns the next message on the wire along with // any additional raw payload. func (p *peer) readNextMessage() (lnwire.Message, error) { - // TODO(roasbeef): should take diff of what was read - // * also switch to message oriented reading - n, nextMsg, err := lnwire.ReadMessage(p.conn, 0) - atomic.AddUint64(&p.bytesReceived, uint64(n)) + noiseConn, ok := p.conn.(*brontide.Conn) + if !ok { + return nil, fmt.Errorf("brontide.Conn required to read messages") + } + + // First we'll read the next _full_ message. We do this rather than + // reading incrementally from the stream as the Lightning wire protocol + // is message oriented and allows nodes to pad on additional data to + // the message stream. + rawMsg, err := noiseConn.ReadNextMessage() + atomic.AddUint64(&p.bytesReceived, uint64(len(rawMsg))) + if err != nil { + return nil, err + } + + // Next, create a new io.Reader implementation from the raw message, + // and use this to decode the message directly from. + msgReader := bytes.NewReader(rawMsg) + nextMsg, err := lnwire.ReadMessage(msgReader, 0) if err != nil { return nil, err } @@ -586,9 +602,19 @@ func (p *peer) writeMessage(msg lnwire.Message) error { // TODO(roasbeef): add message summaries p.logWireMessage(msg, false) - n, err := lnwire.WriteMessage(p.conn, msg, 0) + // As the Lightning wire protocol is fully message oriented, we only + // allows one wire message per outer encapsulated crypto message. So + // we'll create a temporary buffer to write the message directly to. + var msgPayload [lnwire.MaxMessagePayload]byte + b := bytes.NewBuffer(msgPayload[0:0:len(msgPayload)]) + + // With the temp buffer created and sliced properly (length zero, full + // capacity), we'll now encode the message directly into this buffer. + n, err := lnwire.WriteMessage(b, msg, 0) atomic.AddUint64(&p.bytesSent, uint64(n)) + // Finally, write the message itself in a single swoop. + _, err = p.conn.Write(b.Bytes()) return err }