diff --git a/lnwire/query_short_chan_ids.go b/lnwire/query_short_chan_ids.go index a0e28f4d..fd959c43 100644 --- a/lnwire/query_short_chan_ids.go +++ b/lnwire/query_short_chan_ids.go @@ -181,7 +181,10 @@ func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, err return 0, nil, fmt.Errorf("unable to create zlib reader: %v", err) } - var shortChanIDs []ShortChannelID + var ( + shortChanIDs []ShortChannelID + lastChanID ShortChannelID + ) for { // We'll now attempt to read the next short channel ID // encoded in the payload. @@ -208,6 +211,18 @@ func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, err // We successfully read the next ID, so well collect // that in the set of final ID's to return. shortChanIDs = append(shortChanIDs, cid) + + // Finally, we'll ensure that this short chan ID is + // greater than the last one. This is a requirement + // within the encoding, and if violated can aide us in + // detecting malicious payloads. + if cid.ToUint64() <= lastChanID.ToUint64() { + return 0, nil, fmt.Errorf("current sid of %v "+ + "isn't greater than last sid of %v", cid, + lastChanID) + } + + lastChanID = cid } default: