diff --git a/lnwire/lnwire.go b/lnwire/lnwire.go index 212ead8d..39209a5c 100644 --- a/lnwire/lnwire.go +++ b/lnwire/lnwire.go @@ -78,9 +78,14 @@ func (a addressType) AddrLen() uint16 { // // TODO(roasbeef): this should eventually draw from a buffer pool for // serialization. -// TODO(roasbeef): switch to var-ints for all? func writeElement(w io.Writer, element interface{}) error { switch e := element.(type) { + case ShortChanIDEncoding: + var b [1]byte + b[0] = uint8(e) + if _, err := w.Write(b[:]); err != nil { + return err + } case uint8: var b [1]byte b[0] = e @@ -390,6 +395,12 @@ func writeElements(w io.Writer, elements ...interface{}) error { func readElement(r io.Reader, element interface{}) error { var err error switch e := element.(type) { + case *ShortChanIDEncoding: + var b [1]uint8 + if _, err := r.Read(b[:]); err != nil { + return err + } + *e = ShortChanIDEncoding(b[0]) case *uint8: var b [1]uint8 if _, err := r.Read(b[:]); err != nil { diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index e8cd720c..fd214308 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -11,6 +11,7 @@ import ( "reflect" "testing" "testing/quick" + "time" "github.com/davecgh/go-spew/spew" "github.com/roasbeef/btcd/btcec" @@ -553,6 +554,57 @@ func TestLightningWireProtocol(t *testing.T) { } } + v[0] = reflect.ValueOf(req) + }, + MsgQueryShortChanIDs: func(v []reflect.Value, r *rand.Rand) { + req := QueryShortChanIDs{ + // TODO(roasbeef): later alternate encoding types + EncodingType: EncodingSortedPlain, + } + + if _, err := rand.Read(req.ChainHash[:]); err != nil { + t.Fatalf("unable to read chain hash: %v", err) + return + } + + numChanIDs := rand.Int31n(5000) + + req.ShortChanIDs = make([]ShortChannelID, numChanIDs) + for i := int32(0); i < numChanIDs; i++ { + req.ShortChanIDs[i] = NewShortChanIDFromInt( + uint64(r.Int63()), + ) + } + + v[0] = reflect.ValueOf(req) + }, + MsgReplyChannelRange: func(v []reflect.Value, r *rand.Rand) { + req := ReplyChannelRange{ + QueryChannelRange: QueryChannelRange{ + FirstBlockHeight: uint32(r.Int31()), + NumBlocks: uint32(r.Int31()), + }, + } + + if _, err := rand.Read(req.ChainHash[:]); err != nil { + t.Fatalf("unable to read chain hash: %v", err) + return + } + + req.Complete = uint8(r.Int31n(2)) + + // TODO(roasbeef): later alternate encoding types + req.EncodingType = EncodingSortedPlain + + numChanIDs := rand.Int31n(5000) + + req.ShortChanIDs = make([]ShortChannelID, numChanIDs) + for i := int32(0); i < numChanIDs; i++ { + req.ShortChanIDs[i] = NewShortChanIDFromInt( + uint64(r.Int63()), + ) + } + v[0] = reflect.ValueOf(req) }, } @@ -705,6 +757,36 @@ func TestLightningWireProtocol(t *testing.T) { return mainScenario(&m) }, }, + { + msgType: MsgGossipTimestampRange, + scenario: func(m GossipTimestampRange) bool { + return mainScenario(&m) + }, + }, + { + msgType: MsgQueryShortChanIDs, + scenario: func(m QueryShortChanIDs) bool { + return mainScenario(&m) + }, + }, + { + msgType: MsgReplyShortChanIDsEnd, + scenario: func(m ReplyShortChanIDsEnd) bool { + return mainScenario(&m) + }, + }, + { + msgType: MsgQueryChannelRange, + scenario: func(m QueryChannelRange) bool { + return mainScenario(&m) + }, + }, + { + msgType: MsgReplyChannelRange, + scenario: func(m ReplyChannelRange) bool { + return mainScenario(&m) + }, + }, } for _, test := range tests { var config *quick.Config @@ -726,3 +808,7 @@ func TestLightningWireProtocol(t *testing.T) { } } + +func init() { + rand.Seed(time.Now().Unix()) +} diff --git a/lnwire/message.go b/lnwire/message.go index a10bba29..b5c27339 100644 --- a/lnwire/message.go +++ b/lnwire/message.go @@ -49,6 +49,11 @@ const ( MsgNodeAnnouncement = 257 MsgChannelUpdate = 258 MsgAnnounceSignatures = 259 + MsgQueryShortChanIDs = 261 + MsgReplyShortChanIDsEnd = 262 + MsgQueryChannelRange = 263 + MsgReplyChannelRange = 264 + MsgGossipTimestampRange = 265 ) // String return the string representation of message type. @@ -100,6 +105,16 @@ func (t MessageType) String() string { return "Pong" case MsgUpdateFee: return "UpdateFee" + case MsgQueryShortChanIDs: + return "QueryShortChanIDs" + case MsgReplyShortChanIDsEnd: + return "ReplyShortChanIDsEnd" + case MsgQueryChannelRange: + return "QueryChannelRange" + case MsgReplyChannelRange: + return "ReplyChannelRange" + case MsgGossipTimestampRange: + return "GossipTimestampRange" default: return "" } @@ -191,6 +206,16 @@ func makeEmptyMessage(msgType MessageType) (Message, error) { msg = &AnnounceSignatures{} case MsgPong: msg = &Pong{} + case MsgQueryShortChanIDs: + msg = &QueryShortChanIDs{} + case MsgReplyShortChanIDsEnd: + msg = &ReplyShortChanIDsEnd{} + case MsgQueryChannelRange: + msg = &QueryChannelRange{} + case MsgReplyChannelRange: + msg = &ReplyChannelRange{} + case MsgGossipTimestampRange: + msg = &GossipTimestampRange{} default: return nil, &UnknownMessage{msgType} }