htlcswitch.test: add message interceptor handler
Add message interceptor which checks the order and may skip the messages which were denoted to be skipeed.
This commit is contained in:
parent
f29b4f60e4
commit
1eb906bcfb
@ -9,8 +9,6 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"reflect"
|
|
||||||
|
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
"math"
|
"math"
|
||||||
@ -60,23 +58,92 @@ func messageToString(msg lnwire.Message) string {
|
|||||||
return spew.Sdump(msg)
|
return spew.Sdump(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// expectedMessage struct hols the message which travels from one peer to
|
||||||
|
// another, and additional information like, should this message we skipped
|
||||||
|
// for handling.
|
||||||
|
type expectedMessage struct {
|
||||||
|
from string
|
||||||
|
to string
|
||||||
|
message lnwire.Message
|
||||||
|
skip bool
|
||||||
|
}
|
||||||
|
|
||||||
// createLogFunc is a helper function which returns the function which will be
|
// createLogFunc is a helper function which returns the function which will be
|
||||||
// used for logging message are received from another peer.
|
// used for logging message are received from another peer.
|
||||||
func createLogFunc(name string, channelID lnwire.ChannelID) messageInterceptor {
|
func createLogFunc(name string, channelID lnwire.ChannelID) messageInterceptor {
|
||||||
return func(m lnwire.Message) {
|
return func(m lnwire.Message) (bool, error) {
|
||||||
if getChanID(m) == channelID {
|
chanID, err := getChanID(m)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if chanID == channelID {
|
||||||
// Skip logging of extend revocation window messages.
|
// Skip logging of extend revocation window messages.
|
||||||
switch m := m.(type) {
|
switch m := m.(type) {
|
||||||
case *lnwire.RevokeAndAck:
|
case *lnwire.RevokeAndAck:
|
||||||
var zeroHash chainhash.Hash
|
var zeroHash chainhash.Hash
|
||||||
if bytes.Equal(zeroHash[:], m.Revocation[:]) {
|
if bytes.Equal(zeroHash[:], m.Revocation[:]) {
|
||||||
return
|
return false, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf("---------------------- \n %v received: "+
|
fmt.Printf("---------------------- \n %v received: "+
|
||||||
"%v", name, messageToString(m))
|
"%v", name, messageToString(m))
|
||||||
}
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// createInterceptorFunc creates the function by the given set of messages
|
||||||
|
// which, checks the order of the messages and skip the ones which were
|
||||||
|
// indicated to be intercepted.
|
||||||
|
func createInterceptorFunc(peer string, messages []expectedMessage,
|
||||||
|
chanID lnwire.ChannelID, debug bool) messageInterceptor {
|
||||||
|
|
||||||
|
// Filter message which should be received with given peer name.
|
||||||
|
var expectToReceive []expectedMessage
|
||||||
|
for _, message := range messages {
|
||||||
|
if message.to == peer {
|
||||||
|
expectToReceive = append(expectToReceive, message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return function which checks the message order and skip the
|
||||||
|
// messages.
|
||||||
|
return func(m lnwire.Message) (bool, error) {
|
||||||
|
messageChanID, err := getChanID(m)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if messageChanID == chanID {
|
||||||
|
if len(expectToReceive) == 0 {
|
||||||
|
return false, errors.Errorf("received unexpected message out "+
|
||||||
|
"of range: %v", m.MsgType())
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedMessage := expectToReceive[0]
|
||||||
|
expectToReceive = expectToReceive[1:]
|
||||||
|
|
||||||
|
if expectedMessage.message.MsgType() != m.MsgType() {
|
||||||
|
return false, errors.Errorf("%v received wrong message: \n"+
|
||||||
|
"real: %v\nexpected: %v", peer, m.MsgType(),
|
||||||
|
expectedMessage.message.MsgType())
|
||||||
|
}
|
||||||
|
|
||||||
|
if debug {
|
||||||
|
if expectedMessage.skip {
|
||||||
|
fmt.Printf("'%v' skiped the received message: %v \n",
|
||||||
|
peer, m.MsgType())
|
||||||
|
} else {
|
||||||
|
fmt.Printf("'%v' received message: %v \n", peer,
|
||||||
|
m.MsgType())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return expectedMessage.skip, nil
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -101,11 +168,11 @@ func TestChannelLinkSingleHopPayment(t *testing.T) {
|
|||||||
debug := false
|
debug := false
|
||||||
if debug {
|
if debug {
|
||||||
// Log message that alice receives.
|
// Log message that alice receives.
|
||||||
n.aliceServer.record(createLogFunc("alice",
|
n.aliceServer.intersect(createLogFunc("alice",
|
||||||
n.aliceChannelLink.ChanID()))
|
n.aliceChannelLink.ChanID()))
|
||||||
|
|
||||||
// Log message that bob receives.
|
// Log message that bob receives.
|
||||||
n.bobServer.record(createLogFunc("bob",
|
n.bobServer.intersect(createLogFunc("bob",
|
||||||
n.firstBobChannelLink.ChanID()))
|
n.firstBobChannelLink.ChanID()))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -168,11 +235,11 @@ func TestChannelLinkBidirectionalOneHopPayments(t *testing.T) {
|
|||||||
debug := false
|
debug := false
|
||||||
if debug {
|
if debug {
|
||||||
// Log message that alice receives.
|
// Log message that alice receives.
|
||||||
n.aliceServer.record(createLogFunc("alice",
|
n.aliceServer.intersect(createLogFunc("alice",
|
||||||
n.aliceChannelLink.ChanID()))
|
n.aliceChannelLink.ChanID()))
|
||||||
|
|
||||||
// Log message that bob receives.
|
// Log message that bob receives.
|
||||||
n.bobServer.record(createLogFunc("bob",
|
n.bobServer.intersect(createLogFunc("bob",
|
||||||
n.firstBobChannelLink.ChanID()))
|
n.firstBobChannelLink.ChanID()))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -292,19 +359,19 @@ func TestChannelLinkMultiHopPayment(t *testing.T) {
|
|||||||
debug := false
|
debug := false
|
||||||
if debug {
|
if debug {
|
||||||
// Log messages that alice receives from bob.
|
// Log messages that alice receives from bob.
|
||||||
n.aliceServer.record(createLogFunc("[alice]<-bob<-carol: ",
|
n.aliceServer.intersect(createLogFunc("[alice]<-bob<-carol: ",
|
||||||
n.aliceChannelLink.ChanID()))
|
n.aliceChannelLink.ChanID()))
|
||||||
|
|
||||||
// Log messages that bob receives from alice.
|
// Log messages that bob receives from alice.
|
||||||
n.bobServer.record(createLogFunc("alice->[bob]->carol: ",
|
n.bobServer.intersect(createLogFunc("alice->[bob]->carol: ",
|
||||||
n.firstBobChannelLink.ChanID()))
|
n.firstBobChannelLink.ChanID()))
|
||||||
|
|
||||||
// Log messages that bob receives from carol.
|
// Log messages that bob receives from carol.
|
||||||
n.bobServer.record(createLogFunc("alice<-[bob]<-carol: ",
|
n.bobServer.intersect(createLogFunc("alice<-[bob]<-carol: ",
|
||||||
n.secondBobChannelLink.ChanID()))
|
n.secondBobChannelLink.ChanID()))
|
||||||
|
|
||||||
// Log messages that carol receives from bob.
|
// Log messages that carol receives from bob.
|
||||||
n.carolServer.record(createLogFunc("alice->bob->[carol]",
|
n.carolServer.intersect(createLogFunc("alice->bob->[carol]",
|
||||||
n.carolChannelLink.ChanID()))
|
n.carolChannelLink.ChanID()))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1105,70 +1172,40 @@ func TestChannelLinkSingleHopMessageOrdering(t *testing.T) {
|
|||||||
testStartingHeight,
|
testStartingHeight,
|
||||||
)
|
)
|
||||||
|
|
||||||
chanPoint := n.aliceChannelLink.ChanID()
|
chanID := n.aliceChannelLink.ChanID()
|
||||||
|
|
||||||
// The order in which Alice receives wire messages.
|
messages := []expectedMessage{
|
||||||
var aliceOrder []lnwire.Message
|
{"alice", "bob", &lnwire.UpdateAddHTLC{}, false},
|
||||||
aliceOrder = append(aliceOrder, []lnwire.Message{
|
{"alice", "bob", &lnwire.CommitSig{}, false},
|
||||||
&lnwire.RevokeAndAck{},
|
{"bob", "alice", &lnwire.RevokeAndAck{}, false},
|
||||||
&lnwire.CommitSig{},
|
{"bob", "alice", &lnwire.CommitSig{}, false},
|
||||||
&lnwire.UpdateFufillHTLC{},
|
{"alice", "bob", &lnwire.RevokeAndAck{}, false},
|
||||||
&lnwire.CommitSig{},
|
|
||||||
&lnwire.RevokeAndAck{},
|
|
||||||
}...)
|
|
||||||
|
|
||||||
// The order in which Bob receives wire messages.
|
{"bob", "alice", &lnwire.UpdateFufillHTLC{}, false},
|
||||||
var bobOrder []lnwire.Message
|
{"bob", "alice", &lnwire.CommitSig{}, false},
|
||||||
bobOrder = append(bobOrder, []lnwire.Message{
|
{"alice", "bob", &lnwire.RevokeAndAck{}, false},
|
||||||
&lnwire.UpdateAddHTLC{},
|
{"alice", "bob", &lnwire.CommitSig{}, false},
|
||||||
&lnwire.CommitSig{},
|
{"bob", "alice", &lnwire.RevokeAndAck{}, false},
|
||||||
&lnwire.RevokeAndAck{},
|
}
|
||||||
&lnwire.RevokeAndAck{},
|
|
||||||
&lnwire.CommitSig{},
|
|
||||||
}...)
|
|
||||||
|
|
||||||
debug := false
|
debug := false
|
||||||
if debug {
|
if debug {
|
||||||
// Log message that alice receives.
|
// Log message that alice receives.
|
||||||
n.aliceServer.record(createLogFunc("alice",
|
n.aliceServer.intersect(createLogFunc("alice",
|
||||||
n.aliceChannelLink.ChanID()))
|
n.aliceChannelLink.ChanID()))
|
||||||
|
|
||||||
// Log message that bob receives.
|
// Log message that bob receives.
|
||||||
n.bobServer.record(createLogFunc("bob",
|
n.bobServer.intersect(createLogFunc("bob",
|
||||||
n.firstBobChannelLink.ChanID()))
|
n.firstBobChannelLink.ChanID()))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check that alice receives messages in right order.
|
// Check that alice receives messages in right order.
|
||||||
n.aliceServer.record(func(m lnwire.Message) {
|
n.aliceServer.intersect(createInterceptorFunc("alice", messages, chanID,
|
||||||
if getChanID(m) == chanPoint {
|
false))
|
||||||
if len(aliceOrder) == 0 {
|
|
||||||
t.Fatal("redundant messages")
|
|
||||||
}
|
|
||||||
|
|
||||||
if reflect.TypeOf(aliceOrder[0]) != reflect.TypeOf(m) {
|
|
||||||
t.Fatalf("alice received wrong message: \n"+
|
|
||||||
"real: %v\n expected: %v", m.MsgType(),
|
|
||||||
aliceOrder[0].MsgType())
|
|
||||||
}
|
|
||||||
aliceOrder = aliceOrder[1:]
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
// Check that bob receives messages in right order.
|
// Check that bob receives messages in right order.
|
||||||
n.bobServer.record(func(m lnwire.Message) {
|
n.bobServer.intersect(createInterceptorFunc("bob", messages, chanID,
|
||||||
if getChanID(m) == chanPoint {
|
false))
|
||||||
if len(bobOrder) == 0 {
|
|
||||||
t.Fatal("redundant messages")
|
|
||||||
}
|
|
||||||
|
|
||||||
if reflect.TypeOf(bobOrder[0]) != reflect.TypeOf(m) {
|
|
||||||
t.Fatalf("bob received wrong message: \n"+
|
|
||||||
"real: %v\n expected: %v", m.MsgType(),
|
|
||||||
bobOrder[0].MsgType())
|
|
||||||
}
|
|
||||||
bobOrder = bobOrder[1:]
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
if err := n.start(); err != nil {
|
if err := n.start(); err != nil {
|
||||||
t.Fatalf("unable to start three hop network: %v", err)
|
t.Fatalf("unable to start three hop network: %v", err)
|
||||||
|
@ -5,13 +5,14 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
|
||||||
|
|
||||||
"io"
|
"io"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
"bytes"
|
"bytes"
|
||||||
|
|
||||||
|
"testing"
|
||||||
|
|
||||||
"github.com/btcsuite/fastsha256"
|
"github.com/btcsuite/fastsha256"
|
||||||
"github.com/go-errors/errors"
|
"github.com/go-errors/errors"
|
||||||
"github.com/lightningnetwork/lnd/chainntnfs"
|
"github.com/lightningnetwork/lnd/chainntnfs"
|
||||||
@ -39,8 +40,8 @@ type mockServer struct {
|
|||||||
id [33]byte
|
id [33]byte
|
||||||
htlcSwitch *Switch
|
htlcSwitch *Switch
|
||||||
|
|
||||||
registry *mockInvoiceRegistry
|
registry *mockInvoiceRegistry
|
||||||
recordFuncs []func(lnwire.Message)
|
interceptorFuncs []messageInterceptor
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Peer = (*mockServer)(nil)
|
var _ Peer = (*mockServer)(nil)
|
||||||
@ -51,14 +52,14 @@ func newMockServer(t *testing.T, name string) *mockServer {
|
|||||||
copy(id[:], h[:])
|
copy(id[:], h[:])
|
||||||
|
|
||||||
return &mockServer{
|
return &mockServer{
|
||||||
t: t,
|
t: t,
|
||||||
id: id,
|
id: id,
|
||||||
name: name,
|
name: name,
|
||||||
messages: make(chan lnwire.Message, 3000),
|
messages: make(chan lnwire.Message, 3000),
|
||||||
quit: make(chan bool),
|
quit: make(chan bool),
|
||||||
registry: newMockRegistry(),
|
registry: newMockRegistry(),
|
||||||
htlcSwitch: New(Config{}),
|
htlcSwitch: New(Config{}),
|
||||||
recordFuncs: make([]func(lnwire.Message), 0),
|
interceptorFuncs: make([]messageInterceptor, 0),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -76,8 +77,20 @@ func (s *mockServer) Start() error {
|
|||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case msg := <-s.messages:
|
case msg := <-s.messages:
|
||||||
for _, f := range s.recordFuncs {
|
var shouldSkip bool
|
||||||
f(msg)
|
|
||||||
|
for _, interceptor := range s.interceptorFuncs {
|
||||||
|
skip, err := interceptor(msg)
|
||||||
|
if err != nil {
|
||||||
|
s.errChan <- errors.Errorf("%v: error in the "+
|
||||||
|
"interceptor: %v", s.name, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
shouldSkip = shouldSkip || skip
|
||||||
|
}
|
||||||
|
|
||||||
|
if shouldSkip {
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.readHandler(msg); err != nil {
|
if err := s.readHandler(msg); err != nil {
|
||||||
@ -245,13 +258,13 @@ func (f *ForwardingInfo) decode(r io.Reader) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// messageInterceptor is function that handles the incoming peer messages and
|
// messageInterceptor is function that handles the incoming peer messages and
|
||||||
// may decide should we handle it or not.
|
// may decide should the peer skip the message or not.
|
||||||
type messageInterceptor func(m lnwire.Message)
|
type messageInterceptor func(m lnwire.Message) (bool, error)
|
||||||
|
|
||||||
// Record is used to set the function which will be triggered when new
|
// Record is used to set the function which will be triggered when new
|
||||||
// lnwire message was received.
|
// lnwire message was received.
|
||||||
func (s *mockServer) record(f messageInterceptor) {
|
func (s *mockServer) intersect(f messageInterceptor) {
|
||||||
s.recordFuncs = append(s.recordFuncs, f)
|
s.interceptorFuncs = append(s.interceptorFuncs, f)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *mockServer) SendMessage(message lnwire.Message) error {
|
func (s *mockServer) SendMessage(message lnwire.Message) error {
|
||||||
@ -297,11 +310,8 @@ func (s *mockServer) readHandler(message lnwire.Message) error {
|
|||||||
// the server when handler stacked (server unavailable)
|
// the server when handler stacked (server unavailable)
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
defer func() {
|
|
||||||
done <- struct{}{}
|
|
||||||
}()
|
|
||||||
|
|
||||||
link.HandleChannelUpdate(message)
|
link.HandleChannelUpdate(message)
|
||||||
|
done <- struct{}{}
|
||||||
}()
|
}()
|
||||||
select {
|
select {
|
||||||
case <-done:
|
case <-done:
|
||||||
|
@ -253,22 +253,26 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getChanID retrieves the channel point from nwire message.
|
// getChanID retrieves the channel point from nwire message.
|
||||||
func getChanID(msg lnwire.Message) lnwire.ChannelID {
|
func getChanID(msg lnwire.Message) (lnwire.ChannelID, error) {
|
||||||
var point lnwire.ChannelID
|
var chanID lnwire.ChannelID
|
||||||
switch msg := msg.(type) {
|
switch msg := msg.(type) {
|
||||||
case *lnwire.UpdateAddHTLC:
|
case *lnwire.UpdateAddHTLC:
|
||||||
point = msg.ChanID
|
chanID = msg.ChanID
|
||||||
case *lnwire.UpdateFufillHTLC:
|
case *lnwire.UpdateFufillHTLC:
|
||||||
point = msg.ChanID
|
chanID = msg.ChanID
|
||||||
case *lnwire.UpdateFailHTLC:
|
case *lnwire.UpdateFailHTLC:
|
||||||
point = msg.ChanID
|
chanID = msg.ChanID
|
||||||
case *lnwire.RevokeAndAck:
|
case *lnwire.RevokeAndAck:
|
||||||
point = msg.ChanID
|
chanID = msg.ChanID
|
||||||
case *lnwire.CommitSig:
|
case *lnwire.CommitSig:
|
||||||
point = msg.ChanID
|
chanID = msg.ChanID
|
||||||
|
case *lnwire.ChannelReestablish:
|
||||||
|
chanID = msg.ChanID
|
||||||
|
default:
|
||||||
|
return chanID, errors.New("unknown type")
|
||||||
}
|
}
|
||||||
|
|
||||||
return point
|
return chanID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// generatePayment generates the htlc add request by given path blob and
|
// generatePayment generates the htlc add request by given path blob and
|
||||||
|
Loading…
Reference in New Issue
Block a user