diff --git a/lnwire/query_short_chan_ids.go b/lnwire/query_short_chan_ids.go index 4dab6f4c..1f4c1d35 100644 --- a/lnwire/query_short_chan_ids.go +++ b/lnwire/query_short_chan_ids.go @@ -109,14 +109,6 @@ func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, err // as that was just the encoding type. queryBody = queryBody[1:] - // If after extracting the encoding type, then number of remaining - // bytes instead a whole multiple of the size of an encoded short - // channel ID (8 bytes), then we'll return a parsing error. - if len(queryBody)%8 != 0 { - return 0, nil, fmt.Errorf("whole number of short chan ID's "+ - "cannot be encoded in len=%v", len(queryBody)) - } - // Otherwise, depending on the encoding type, we'll decode the encode // short channel ID's in a different manner. switch encodingType { @@ -124,6 +116,16 @@ func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, err // In this encoding, we'll simply read a sort array of encoded short // channel ID's from the buffer. case EncodingSortedPlain: + // If after extracting the encoding type, then number of + // remaining bytes instead a whole multiple of the size of an + // encoded short channel ID (8 bytes), then we'll return a + // parsing error. + if len(queryBody)%8 != 0 { + return 0, nil, fmt.Errorf("whole number of short "+ + "chan ID's cannot be encoded in len=%v", + len(queryBody)) + } + // As each short channel ID is encoded as 8 bytes, we can // compute the number of bytes encoded based on the size of the // query body.