diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index e0f3c07e..ea4ba96b 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -22,6 +22,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/htlcswitch/hodl" + "github.com/lightningnetwork/lnd/lnpeer" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" "github.com/roasbeef/btcd/btcec" @@ -1396,14 +1397,14 @@ type mockPeer struct { quit chan struct{} } -var _ Peer = (*mockPeer)(nil) +var _ lnpeer.Peer = (*mockPeer)(nil) -func (m *mockPeer) SendMessage(msg lnwire.Message, sync bool) error { +func (m *mockPeer) SendMessage(sync bool, msgs ...lnwire.Message) error { if m.disconnected { return fmt.Errorf("disconnected") } select { - case m.sentMsgs <- msg: + case m.sentMsgs <- msgs[0]: case <-m.quit: return fmt.Errorf("mockPeer shutting down") } @@ -1415,8 +1416,11 @@ func (m *mockPeer) WipeChannel(*wire.OutPoint) error { func (m *mockPeer) PubKey() [33]byte { return [33]byte{} } +func (m *mockPeer) IdentityKey() *btcec.PublicKey { + return nil +} -var _ Peer = (*mockPeer)(nil) +var _ lnpeer.Peer = (*mockPeer)(nil) func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) ( ChannelLink, *lnwallet.LightningChannel, chan time.Time, func() error, diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index 0ed96408..db1dcc65 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -20,6 +20,7 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/contractcourt" + "github.com/lightningnetwork/lnd/lnpeer" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" "github.com/roasbeef/btcd/btcec" @@ -119,7 +120,7 @@ type mockServer struct { interceptorFuncs []messageInterceptor } -var _ Peer = (*mockServer)(nil) +var _ lnpeer.Peer = (*mockServer)(nil) func initSwitchWithDB(db *channeldb.DB) (*Switch, error) { if db == nil { @@ -450,12 +451,14 @@ func (s *mockServer) intersect(f messageInterceptor) { s.interceptorFuncs = append(s.interceptorFuncs, f) } -func (s *mockServer) SendMessage(message lnwire.Message, sync bool) error { +func (s *mockServer) SendMessage(sync bool, msgs ...lnwire.Message) error { - select { - case s.messages <- message: - case <-s.quit: - return errors.New("server is stopped") + for _, msg := range msgs { + select { + case s.messages <- msg: + case <-s.quit: + return errors.New("server is stopped") + } } return nil @@ -506,6 +509,11 @@ func (s *mockServer) PubKey() [33]byte { return s.id } +func (s *mockServer) IdentityKey() *btcec.PublicKey { + pubkey, _ := btcec.ParsePubKey(s.id[:], btcec.S256()) + return pubkey +} + func (s *mockServer) WipeChannel(*wire.OutPoint) error { return nil } @@ -532,7 +540,7 @@ type mockChannelLink struct { chanID lnwire.ChannelID - peer Peer + peer lnpeer.Peer startMailBox bool @@ -579,7 +587,7 @@ func (f *mockChannelLink) deleteCircuit(pkt *htlcPacket) error { } func newMockChannelLink(htlcSwitch *Switch, chanID lnwire.ChannelID, - shortChanID lnwire.ShortChannelID, peer Peer, eligible bool, + shortChanID lnwire.ShortChannelID, peer lnpeer.Peer, eligible bool, ) *mockChannelLink { return &mockChannelLink{ @@ -624,7 +632,7 @@ func (f *mockChannelLink) Start() error { func (f *mockChannelLink) ChanID() lnwire.ChannelID { return f.chanID } func (f *mockChannelLink) ShortChanID() lnwire.ShortChannelID { return f.shortChanID } func (f *mockChannelLink) Bandwidth() lnwire.MilliSatoshi { return 99999999 } -func (f *mockChannelLink) Peer() Peer { return f.peer } +func (f *mockChannelLink) Peer() lnpeer.Peer { return f.peer } func (f *mockChannelLink) Stop() {} func (f *mockChannelLink) EligibleToForward() bool { return f.eligible } func (f *mockChannelLink) setLiveShortChanID(sid lnwire.ShortChannelID) { f.shortChanID = sid } diff --git a/htlcswitch/test_utils.go b/htlcswitch/test_utils.go index 5ff5a1d1..da48a412 100644 --- a/htlcswitch/test_utils.go +++ b/htlcswitch/test_utils.go @@ -21,6 +21,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/keychain" + "github.com/lightningnetwork/lnd/lnpeer" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/shachain" @@ -672,7 +673,7 @@ func (r *paymentResponse) Wait(d time.Duration) (chainhash.Hash, error) { // * from Alice to Bob // * from Alice to Carol through the Bob // * from Alice to some another peer through the Bob -func (n *threeHopNetwork) makePayment(sendingPeer, receivingPeer Peer, +func (n *threeHopNetwork) makePayment(sendingPeer, receivingPeer lnpeer.Peer, firstHopPub [33]byte, hops []ForwardingInfo, invoiceAmt, htlcAmt lnwire.MilliSatoshi, timelock uint32) *paymentResponse {