From 898d5ffccf88a2551209c891c82648e52602d6a1 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Mon, 16 Apr 2018 18:46:15 -0700 Subject: [PATCH] lnwire: add new QueryShortChanIDs gossip query msg --- lnwire/query_short_chan_ids.go | 233 +++++++++++++++++++++++++++++++++ 1 file changed, 233 insertions(+) create mode 100644 lnwire/query_short_chan_ids.go diff --git a/lnwire/query_short_chan_ids.go b/lnwire/query_short_chan_ids.go new file mode 100644 index 00000000..4dab6f4c --- /dev/null +++ b/lnwire/query_short_chan_ids.go @@ -0,0 +1,233 @@ +package lnwire + +import ( + "bytes" + "fmt" + "io" + "sort" + + "github.com/roasbeef/btcd/chaincfg/chainhash" +) + +// ShortChanIDEncoding is an enum-like type that represents exactly how a set +// of short channel ID's is encoded on the wire. The set of encodings allows us +// to take advantage of the structure of a list of short channel ID's to +// achieving a high degree of compression. +type ShortChanIDEncoding uint8 + +const ( + // EncodingSortedPlain signals that the set of short channel ID's is + // encoded using the regular encoding, in a sorted order. + EncodingSortedPlain ShortChanIDEncoding = 0 + + // TODO(roasbeef): list max number of short chan id's that are able to + // use +) + +// ErrUnknownShortChanIDEncoding is a parametrized error that indicates that we +// came across an unknown short channel ID encoding, and therefore were unable +// to continue parsing. +func ErrUnknownShortChanIDEncoding(encoding ShortChanIDEncoding) error { + return fmt.Errorf("unknown short chan id encoding: %v", encoding) +} + +// QueryShortChanIDs is a message that allows the sender to query a set of +// channel announcement and channel update messages that correspond to the set +// of encoded short channel ID's. The encoding of the short channel ID's is +// detailed in the query message ensuring that the receiver knows how to +// properly decode each encode short channel ID which may be encoded using a +// compression format. The receiver should respond with a series of channel +// announcement and channel updates, finally sending a ReplyShortChanIDsEnd +// message. +type QueryShortChanIDs struct { + // ChainHash denotes the target chain that we're querying for the + // channel channel ID's of. + ChainHash chainhash.Hash + + // EncodingType is a signal to the receiver of the message that + // indicates exactly how the set of short channel ID's that follow have + // been encoded. + EncodingType ShortChanIDEncoding + + // ShortChanIDs is a slice of decoded short channel ID's. + ShortChanIDs []ShortChannelID +} + +// NewQueryShortChanIDs creates a new QueryShortChanIDs message. +func NewQueryShortChanIDs(h chainhash.Hash, e ShortChanIDEncoding, + s []ShortChannelID) *QueryShortChanIDs { + + return &QueryShortChanIDs{ + ChainHash: h, + EncodingType: e, + ShortChanIDs: s, + } +} + +// A compile time check to ensure QueryShortChanIDs implements the +// lnwire.Message interface. +var _ Message = (*QueryShortChanIDs)(nil) + +// Decode deserializes a serialized QueryShortChanIDs message stored in the +// passed io.Reader observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (q *QueryShortChanIDs) Decode(r io.Reader, pver uint32) error { + err := readElements(r, q.ChainHash[:]) + if err != nil { + return err + } + + q.EncodingType, q.ShortChanIDs, err = decodeShortChanIDs(r) + + return err +} + +// decodeShortChanIDs decodes a set of short channel ID's that have been +// encoded. The first byte of the body details how the short chan ID's were +// encoded. We'll use this type to govern exactly how we go about encoding the +// set of short channel ID's. +func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, error) { + // First, we'll attempt to read the number of bytes in the body of the + // set of encoded short channel ID's. + var numBytesResp uint16 + err := readElements(r, &numBytesResp) + if err != nil { + return 0, nil, err + } + + queryBody := make([]byte, numBytesResp) + if _, err := io.ReadFull(r, queryBody); err != nil { + return 0, nil, err + } + + // The first byte is the encoding type, so we'll extract that so we can + // continue our parsing. + encodingType := ShortChanIDEncoding(queryBody[0]) + + // Before continuing, we'll snip off the first byte of the query body + // as that was just the encoding type. + queryBody = queryBody[1:] + + // If after extracting the encoding type, then number of remaining + // bytes instead a whole multiple of the size of an encoded short + // channel ID (8 bytes), then we'll return a parsing error. + if len(queryBody)%8 != 0 { + return 0, nil, fmt.Errorf("whole number of short chan ID's "+ + "cannot be encoded in len=%v", len(queryBody)) + } + + // Otherwise, depending on the encoding type, we'll decode the encode + // short channel ID's in a different manner. + switch encodingType { + + // In this encoding, we'll simply read a sort array of encoded short + // channel ID's from the buffer. + case EncodingSortedPlain: + // As each short channel ID is encoded as 8 bytes, we can + // compute the number of bytes encoded based on the size of the + // query body. + numShortChanIDs := len(queryBody) / 8 + shortChanIDs := make([]ShortChannelID, numShortChanIDs) + + // Finally, we'll read out the exact number of short channel + // ID's to conclude our parsing. + bodyReader := bytes.NewReader(queryBody) + for i := 0; i < numShortChanIDs; i++ { + if err := readElements(bodyReader, &shortChanIDs[i]); err != nil { + return 0, nil, fmt.Errorf("unable to parse "+ + "short chan ID: %v", err) + } + } + + return encodingType, shortChanIDs, nil + + default: + // If we've been sent an encoding type that we don't know of, + // then we'll return a parsing error as we can't continue if + // we're unable to encode them. + return 0, nil, ErrUnknownShortChanIDEncoding(encodingType) + } +} + +// Encode serializes the target QueryShortChanIDs into the passed io.Writer +// observing the protocol version specified. +// +// This is part of the lnwire.Message interface. +func (q *QueryShortChanIDs) Encode(w io.Writer, pver uint32) error { + // First, we'll write out the chain hash. + err := writeElements(w, q.ChainHash[:]) + if err != nil { + return err + } + + // Base on our encoding type, we'll write out the set of short channel + // ID's. + return encodeShortChanIDs(w, q.EncodingType, q.ShortChanIDs) +} + +// encodeShortChanIDs encodes the passed short channel ID's into the passed +// io.Writer, respecting the specified encoding type. +func encodeShortChanIDs(w io.Writer, encodingType ShortChanIDEncoding, + shortChanIDs []ShortChannelID) error { + + switch encodingType { + + // In this encoding, we'll simply write a sorted array of encoded short + // channel ID's from the buffer. + case EncodingSortedPlain: + // First, we'll write out the number of bytes of the query + // body. We add 1 as the response will have the encoding type + // prepended to it. + numBytesBody := uint16(len(shortChanIDs)*8) + 1 + if err := writeElements(w, numBytesBody); err != nil { + return err + } + + // We'll then write out the encoding that that follows the + // actual encoded short channel ID's. + if err := writeElements(w, encodingType); err != nil { + return err + } + + // Next, we'll ensure that the set of short channel ID's is + // properly sorted in place. + sort.Slice(shortChanIDs, func(i, j int) bool { + return shortChanIDs[i].ToUint64() < + shortChanIDs[j].ToUint64() + }) + + // Now that we know they're sorted, we can write out each short + // channel ID to the buffer. + for _, chanID := range shortChanIDs { + if err := writeElements(w, chanID); err != nil { + return fmt.Errorf("unable to write short chan "+ + "ID: %v", err) + } + } + + return nil + + default: + // If we're trying to encode with an encoding type that we + // don't know of, then we'll return a parsing error as we can't + // continue if we're unable to encode them. + return ErrUnknownShortChanIDEncoding(encodingType) + } +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (q *QueryShortChanIDs) MsgType() MessageType { + return MsgQueryShortChanIDs +} + +// MaxPayloadLength returns the maximum allowed payload size for a +// QueryShortChanIDs complete message observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (q *QueryShortChanIDs) MaxPayloadLength(uint32) uint32 { + return MaxMessagePayload +}