lnwire: update tests and message code definitions for new gossip query msgs

This commit is contained in:
Olaoluwa Osuntokun 2018-04-16 18:47:53 -07:00
parent fa9a012ac6
commit 62df3cbbb8
No known key found for this signature in database
GPG Key ID: 964EA263DD637C21
3 changed files with 123 additions and 1 deletions

@ -78,9 +78,14 @@ func (a addressType) AddrLen() uint16 {
// //
// TODO(roasbeef): this should eventually draw from a buffer pool for // TODO(roasbeef): this should eventually draw from a buffer pool for
// serialization. // serialization.
// TODO(roasbeef): switch to var-ints for all?
func writeElement(w io.Writer, element interface{}) error { func writeElement(w io.Writer, element interface{}) error {
switch e := element.(type) { switch e := element.(type) {
case ShortChanIDEncoding:
var b [1]byte
b[0] = uint8(e)
if _, err := w.Write(b[:]); err != nil {
return err
}
case uint8: case uint8:
var b [1]byte var b [1]byte
b[0] = e b[0] = e
@ -390,6 +395,12 @@ func writeElements(w io.Writer, elements ...interface{}) error {
func readElement(r io.Reader, element interface{}) error { func readElement(r io.Reader, element interface{}) error {
var err error var err error
switch e := element.(type) { switch e := element.(type) {
case *ShortChanIDEncoding:
var b [1]uint8
if _, err := r.Read(b[:]); err != nil {
return err
}
*e = ShortChanIDEncoding(b[0])
case *uint8: case *uint8:
var b [1]uint8 var b [1]uint8
if _, err := r.Read(b[:]); err != nil { if _, err := r.Read(b[:]); err != nil {

@ -11,6 +11,7 @@ import (
"reflect" "reflect"
"testing" "testing"
"testing/quick" "testing/quick"
"time"
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
"github.com/roasbeef/btcd/btcec" "github.com/roasbeef/btcd/btcec"
@ -553,6 +554,57 @@ func TestLightningWireProtocol(t *testing.T) {
} }
} }
v[0] = reflect.ValueOf(req)
},
MsgQueryShortChanIDs: func(v []reflect.Value, r *rand.Rand) {
req := QueryShortChanIDs{
// TODO(roasbeef): later alternate encoding types
EncodingType: EncodingSortedPlain,
}
if _, err := rand.Read(req.ChainHash[:]); err != nil {
t.Fatalf("unable to read chain hash: %v", err)
return
}
numChanIDs := rand.Int31n(5000)
req.ShortChanIDs = make([]ShortChannelID, numChanIDs)
for i := int32(0); i < numChanIDs; i++ {
req.ShortChanIDs[i] = NewShortChanIDFromInt(
uint64(r.Int63()),
)
}
v[0] = reflect.ValueOf(req)
},
MsgReplyChannelRange: func(v []reflect.Value, r *rand.Rand) {
req := ReplyChannelRange{
QueryChannelRange: QueryChannelRange{
FirstBlockHeight: uint32(r.Int31()),
NumBlocks: uint32(r.Int31()),
},
}
if _, err := rand.Read(req.ChainHash[:]); err != nil {
t.Fatalf("unable to read chain hash: %v", err)
return
}
req.Complete = uint8(r.Int31n(2))
// TODO(roasbeef): later alternate encoding types
req.EncodingType = EncodingSortedPlain
numChanIDs := rand.Int31n(5000)
req.ShortChanIDs = make([]ShortChannelID, numChanIDs)
for i := int32(0); i < numChanIDs; i++ {
req.ShortChanIDs[i] = NewShortChanIDFromInt(
uint64(r.Int63()),
)
}
v[0] = reflect.ValueOf(req) v[0] = reflect.ValueOf(req)
}, },
} }
@ -705,6 +757,36 @@ func TestLightningWireProtocol(t *testing.T) {
return mainScenario(&m) return mainScenario(&m)
}, },
}, },
{
msgType: MsgGossipTimestampRange,
scenario: func(m GossipTimestampRange) bool {
return mainScenario(&m)
},
},
{
msgType: MsgQueryShortChanIDs,
scenario: func(m QueryShortChanIDs) bool {
return mainScenario(&m)
},
},
{
msgType: MsgReplyShortChanIDsEnd,
scenario: func(m ReplyShortChanIDsEnd) bool {
return mainScenario(&m)
},
},
{
msgType: MsgQueryChannelRange,
scenario: func(m QueryChannelRange) bool {
return mainScenario(&m)
},
},
{
msgType: MsgReplyChannelRange,
scenario: func(m ReplyChannelRange) bool {
return mainScenario(&m)
},
},
} }
for _, test := range tests { for _, test := range tests {
var config *quick.Config var config *quick.Config
@ -726,3 +808,7 @@ func TestLightningWireProtocol(t *testing.T) {
} }
} }
func init() {
rand.Seed(time.Now().Unix())
}

@ -49,6 +49,11 @@ const (
MsgNodeAnnouncement = 257 MsgNodeAnnouncement = 257
MsgChannelUpdate = 258 MsgChannelUpdate = 258
MsgAnnounceSignatures = 259 MsgAnnounceSignatures = 259
MsgQueryShortChanIDs = 261
MsgReplyShortChanIDsEnd = 262
MsgQueryChannelRange = 263
MsgReplyChannelRange = 264
MsgGossipTimestampRange = 265
) )
// String return the string representation of message type. // String return the string representation of message type.
@ -100,6 +105,16 @@ func (t MessageType) String() string {
return "Pong" return "Pong"
case MsgUpdateFee: case MsgUpdateFee:
return "UpdateFee" return "UpdateFee"
case MsgQueryShortChanIDs:
return "QueryShortChanIDs"
case MsgReplyShortChanIDsEnd:
return "ReplyShortChanIDsEnd"
case MsgQueryChannelRange:
return "QueryChannelRange"
case MsgReplyChannelRange:
return "ReplyChannelRange"
case MsgGossipTimestampRange:
return "GossipTimestampRange"
default: default:
return "<unknown>" return "<unknown>"
} }
@ -191,6 +206,16 @@ func makeEmptyMessage(msgType MessageType) (Message, error) {
msg = &AnnounceSignatures{} msg = &AnnounceSignatures{}
case MsgPong: case MsgPong:
msg = &Pong{} msg = &Pong{}
case MsgQueryShortChanIDs:
msg = &QueryShortChanIDs{}
case MsgReplyShortChanIDsEnd:
msg = &ReplyShortChanIDsEnd{}
case MsgQueryChannelRange:
msg = &QueryChannelRange{}
case MsgReplyChannelRange:
msg = &ReplyChannelRange{}
case MsgGossipTimestampRange:
msg = &GossipTimestampRange{}
default: default:
return nil, &UnknownMessage{msgType} return nil, &UnknownMessage{msgType}
} }