htlcswitch.test: add message interceptor handler

Add message interceptor which checks the order and may skip the
messages which were denoted to be skipeed.
This commit is contained in:
Andrey Samokhvalov 2017-07-09 01:46:27 +03:00 committed by Olaoluwa Osuntokun
parent f29b4f60e4
commit 1eb906bcfb
3 changed files with 142 additions and 91 deletions

@ -9,8 +9,6 @@ import (
"testing" "testing"
"time" "time"
"reflect"
"io" "io"
"math" "math"
@ -60,23 +58,92 @@ func messageToString(msg lnwire.Message) string {
return spew.Sdump(msg) return spew.Sdump(msg)
} }
// expectedMessage struct hols the message which travels from one peer to
// another, and additional information like, should this message we skipped
// for handling.
type expectedMessage struct {
from string
to string
message lnwire.Message
skip bool
}
// createLogFunc is a helper function which returns the function which will be // createLogFunc is a helper function which returns the function which will be
// used for logging message are received from another peer. // used for logging message are received from another peer.
func createLogFunc(name string, channelID lnwire.ChannelID) messageInterceptor { func createLogFunc(name string, channelID lnwire.ChannelID) messageInterceptor {
return func(m lnwire.Message) { return func(m lnwire.Message) (bool, error) {
if getChanID(m) == channelID { chanID, err := getChanID(m)
if err != nil {
return false, err
}
if chanID == channelID {
// Skip logging of extend revocation window messages. // Skip logging of extend revocation window messages.
switch m := m.(type) { switch m := m.(type) {
case *lnwire.RevokeAndAck: case *lnwire.RevokeAndAck:
var zeroHash chainhash.Hash var zeroHash chainhash.Hash
if bytes.Equal(zeroHash[:], m.Revocation[:]) { if bytes.Equal(zeroHash[:], m.Revocation[:]) {
return return false, nil
} }
} }
fmt.Printf("---------------------- \n %v received: "+ fmt.Printf("---------------------- \n %v received: "+
"%v", name, messageToString(m)) "%v", name, messageToString(m))
} }
return false, nil
}
}
// createInterceptorFunc creates the function by the given set of messages
// which, checks the order of the messages and skip the ones which were
// indicated to be intercepted.
func createInterceptorFunc(peer string, messages []expectedMessage,
chanID lnwire.ChannelID, debug bool) messageInterceptor {
// Filter message which should be received with given peer name.
var expectToReceive []expectedMessage
for _, message := range messages {
if message.to == peer {
expectToReceive = append(expectToReceive, message)
}
}
// Return function which checks the message order and skip the
// messages.
return func(m lnwire.Message) (bool, error) {
messageChanID, err := getChanID(m)
if err != nil {
return false, err
}
if messageChanID == chanID {
if len(expectToReceive) == 0 {
return false, errors.Errorf("received unexpected message out "+
"of range: %v", m.MsgType())
}
expectedMessage := expectToReceive[0]
expectToReceive = expectToReceive[1:]
if expectedMessage.message.MsgType() != m.MsgType() {
return false, errors.Errorf("%v received wrong message: \n"+
"real: %v\nexpected: %v", peer, m.MsgType(),
expectedMessage.message.MsgType())
}
if debug {
if expectedMessage.skip {
fmt.Printf("'%v' skiped the received message: %v \n",
peer, m.MsgType())
} else {
fmt.Printf("'%v' received message: %v \n", peer,
m.MsgType())
}
}
return expectedMessage.skip, nil
}
return false, nil
} }
} }
@ -101,11 +168,11 @@ func TestChannelLinkSingleHopPayment(t *testing.T) {
debug := false debug := false
if debug { if debug {
// Log message that alice receives. // Log message that alice receives.
n.aliceServer.record(createLogFunc("alice", n.aliceServer.intersect(createLogFunc("alice",
n.aliceChannelLink.ChanID())) n.aliceChannelLink.ChanID()))
// Log message that bob receives. // Log message that bob receives.
n.bobServer.record(createLogFunc("bob", n.bobServer.intersect(createLogFunc("bob",
n.firstBobChannelLink.ChanID())) n.firstBobChannelLink.ChanID()))
} }
@ -168,11 +235,11 @@ func TestChannelLinkBidirectionalOneHopPayments(t *testing.T) {
debug := false debug := false
if debug { if debug {
// Log message that alice receives. // Log message that alice receives.
n.aliceServer.record(createLogFunc("alice", n.aliceServer.intersect(createLogFunc("alice",
n.aliceChannelLink.ChanID())) n.aliceChannelLink.ChanID()))
// Log message that bob receives. // Log message that bob receives.
n.bobServer.record(createLogFunc("bob", n.bobServer.intersect(createLogFunc("bob",
n.firstBobChannelLink.ChanID())) n.firstBobChannelLink.ChanID()))
} }
@ -292,19 +359,19 @@ func TestChannelLinkMultiHopPayment(t *testing.T) {
debug := false debug := false
if debug { if debug {
// Log messages that alice receives from bob. // Log messages that alice receives from bob.
n.aliceServer.record(createLogFunc("[alice]<-bob<-carol: ", n.aliceServer.intersect(createLogFunc("[alice]<-bob<-carol: ",
n.aliceChannelLink.ChanID())) n.aliceChannelLink.ChanID()))
// Log messages that bob receives from alice. // Log messages that bob receives from alice.
n.bobServer.record(createLogFunc("alice->[bob]->carol: ", n.bobServer.intersect(createLogFunc("alice->[bob]->carol: ",
n.firstBobChannelLink.ChanID())) n.firstBobChannelLink.ChanID()))
// Log messages that bob receives from carol. // Log messages that bob receives from carol.
n.bobServer.record(createLogFunc("alice<-[bob]<-carol: ", n.bobServer.intersect(createLogFunc("alice<-[bob]<-carol: ",
n.secondBobChannelLink.ChanID())) n.secondBobChannelLink.ChanID()))
// Log messages that carol receives from bob. // Log messages that carol receives from bob.
n.carolServer.record(createLogFunc("alice->bob->[carol]", n.carolServer.intersect(createLogFunc("alice->bob->[carol]",
n.carolChannelLink.ChanID())) n.carolChannelLink.ChanID()))
} }
@ -1105,70 +1172,40 @@ func TestChannelLinkSingleHopMessageOrdering(t *testing.T) {
testStartingHeight, testStartingHeight,
) )
chanPoint := n.aliceChannelLink.ChanID() chanID := n.aliceChannelLink.ChanID()
// The order in which Alice receives wire messages. messages := []expectedMessage{
var aliceOrder []lnwire.Message {"alice", "bob", &lnwire.UpdateAddHTLC{}, false},
aliceOrder = append(aliceOrder, []lnwire.Message{ {"alice", "bob", &lnwire.CommitSig{}, false},
&lnwire.RevokeAndAck{}, {"bob", "alice", &lnwire.RevokeAndAck{}, false},
&lnwire.CommitSig{}, {"bob", "alice", &lnwire.CommitSig{}, false},
&lnwire.UpdateFufillHTLC{}, {"alice", "bob", &lnwire.RevokeAndAck{}, false},
&lnwire.CommitSig{},
&lnwire.RevokeAndAck{},
}...)
// The order in which Bob receives wire messages. {"bob", "alice", &lnwire.UpdateFufillHTLC{}, false},
var bobOrder []lnwire.Message {"bob", "alice", &lnwire.CommitSig{}, false},
bobOrder = append(bobOrder, []lnwire.Message{ {"alice", "bob", &lnwire.RevokeAndAck{}, false},
&lnwire.UpdateAddHTLC{}, {"alice", "bob", &lnwire.CommitSig{}, false},
&lnwire.CommitSig{}, {"bob", "alice", &lnwire.RevokeAndAck{}, false},
&lnwire.RevokeAndAck{}, }
&lnwire.RevokeAndAck{},
&lnwire.CommitSig{},
}...)
debug := false debug := false
if debug { if debug {
// Log message that alice receives. // Log message that alice receives.
n.aliceServer.record(createLogFunc("alice", n.aliceServer.intersect(createLogFunc("alice",
n.aliceChannelLink.ChanID())) n.aliceChannelLink.ChanID()))
// Log message that bob receives. // Log message that bob receives.
n.bobServer.record(createLogFunc("bob", n.bobServer.intersect(createLogFunc("bob",
n.firstBobChannelLink.ChanID())) n.firstBobChannelLink.ChanID()))
} }
// Check that alice receives messages in right order. // Check that alice receives messages in right order.
n.aliceServer.record(func(m lnwire.Message) { n.aliceServer.intersect(createInterceptorFunc("alice", messages, chanID,
if getChanID(m) == chanPoint { false))
if len(aliceOrder) == 0 {
t.Fatal("redundant messages")
}
if reflect.TypeOf(aliceOrder[0]) != reflect.TypeOf(m) {
t.Fatalf("alice received wrong message: \n"+
"real: %v\n expected: %v", m.MsgType(),
aliceOrder[0].MsgType())
}
aliceOrder = aliceOrder[1:]
}
})
// Check that bob receives messages in right order. // Check that bob receives messages in right order.
n.bobServer.record(func(m lnwire.Message) { n.bobServer.intersect(createInterceptorFunc("bob", messages, chanID,
if getChanID(m) == chanPoint { false))
if len(bobOrder) == 0 {
t.Fatal("redundant messages")
}
if reflect.TypeOf(bobOrder[0]) != reflect.TypeOf(m) {
t.Fatalf("bob received wrong message: \n"+
"real: %v\n expected: %v", m.MsgType(),
bobOrder[0].MsgType())
}
bobOrder = bobOrder[1:]
}
})
if err := n.start(); err != nil { if err := n.start(); err != nil {
t.Fatalf("unable to start three hop network: %v", err) t.Fatalf("unable to start three hop network: %v", err)

@ -5,13 +5,14 @@ import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"sync" "sync"
"testing"
"io" "io"
"sync/atomic" "sync/atomic"
"bytes" "bytes"
"testing"
"github.com/btcsuite/fastsha256" "github.com/btcsuite/fastsha256"
"github.com/go-errors/errors" "github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/chainntnfs"
@ -40,7 +41,7 @@ type mockServer struct {
htlcSwitch *Switch htlcSwitch *Switch
registry *mockInvoiceRegistry registry *mockInvoiceRegistry
recordFuncs []func(lnwire.Message) interceptorFuncs []messageInterceptor
} }
var _ Peer = (*mockServer)(nil) var _ Peer = (*mockServer)(nil)
@ -58,7 +59,7 @@ func newMockServer(t *testing.T, name string) *mockServer {
quit: make(chan bool), quit: make(chan bool),
registry: newMockRegistry(), registry: newMockRegistry(),
htlcSwitch: New(Config{}), htlcSwitch: New(Config{}),
recordFuncs: make([]func(lnwire.Message), 0), interceptorFuncs: make([]messageInterceptor, 0),
} }
} }
@ -76,8 +77,20 @@ func (s *mockServer) Start() error {
for { for {
select { select {
case msg := <-s.messages: case msg := <-s.messages:
for _, f := range s.recordFuncs { var shouldSkip bool
f(msg)
for _, interceptor := range s.interceptorFuncs {
skip, err := interceptor(msg)
if err != nil {
s.errChan <- errors.Errorf("%v: error in the "+
"interceptor: %v", s.name, err)
return
}
shouldSkip = shouldSkip || skip
}
if shouldSkip {
continue
} }
if err := s.readHandler(msg); err != nil { if err := s.readHandler(msg); err != nil {
@ -245,13 +258,13 @@ func (f *ForwardingInfo) decode(r io.Reader) error {
} }
// messageInterceptor is function that handles the incoming peer messages and // messageInterceptor is function that handles the incoming peer messages and
// may decide should we handle it or not. // may decide should the peer skip the message or not.
type messageInterceptor func(m lnwire.Message) type messageInterceptor func(m lnwire.Message) (bool, error)
// Record is used to set the function which will be triggered when new // Record is used to set the function which will be triggered when new
// lnwire message was received. // lnwire message was received.
func (s *mockServer) record(f messageInterceptor) { func (s *mockServer) intersect(f messageInterceptor) {
s.recordFuncs = append(s.recordFuncs, f) s.interceptorFuncs = append(s.interceptorFuncs, f)
} }
func (s *mockServer) SendMessage(message lnwire.Message) error { func (s *mockServer) SendMessage(message lnwire.Message) error {
@ -297,11 +310,8 @@ func (s *mockServer) readHandler(message lnwire.Message) error {
// the server when handler stacked (server unavailable) // the server when handler stacked (server unavailable)
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer func() {
done <- struct{}{}
}()
link.HandleChannelUpdate(message) link.HandleChannelUpdate(message)
done <- struct{}{}
}() }()
select { select {
case <-done: case <-done:

@ -253,22 +253,26 @@ func createTestChannel(alicePrivKey, bobPrivKey []byte,
} }
// getChanID retrieves the channel point from nwire message. // getChanID retrieves the channel point from nwire message.
func getChanID(msg lnwire.Message) lnwire.ChannelID { func getChanID(msg lnwire.Message) (lnwire.ChannelID, error) {
var point lnwire.ChannelID var chanID lnwire.ChannelID
switch msg := msg.(type) { switch msg := msg.(type) {
case *lnwire.UpdateAddHTLC: case *lnwire.UpdateAddHTLC:
point = msg.ChanID chanID = msg.ChanID
case *lnwire.UpdateFufillHTLC: case *lnwire.UpdateFufillHTLC:
point = msg.ChanID chanID = msg.ChanID
case *lnwire.UpdateFailHTLC: case *lnwire.UpdateFailHTLC:
point = msg.ChanID chanID = msg.ChanID
case *lnwire.RevokeAndAck: case *lnwire.RevokeAndAck:
point = msg.ChanID chanID = msg.ChanID
case *lnwire.CommitSig: case *lnwire.CommitSig:
point = msg.ChanID chanID = msg.ChanID
case *lnwire.ChannelReestablish:
chanID = msg.ChanID
default:
return chanID, errors.New("unknown type")
} }
return point return chanID, nil
} }
// generatePayment generates the htlc add request by given path blob and // generatePayment generates the htlc add request by given path blob and