diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index a8f0b0fe..ca16c02f 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -4,6 +4,7 @@ import ( "crypto/sha256" "encoding/binary" "fmt" + "io/ioutil" "sync" "testing" "time" @@ -120,25 +121,45 @@ type mockServer struct { var _ Peer = (*mockServer)(nil) -func newMockServer(t testing.TB, name string) *mockServer { +func initSwitchWithDB(db *channeldb.DB) (*Switch, error) { + if db == nil { + tempPath, err := ioutil.TempDir("", "switchdb") + if err != nil { + return nil, err + } + + db, err = channeldb.Open(tempPath) + if err != nil { + return nil, err + } + } + + return New(Config{ + DB: db, + SwitchPackager: channeldb.NewSwitchPackager(), + }) +} + +func newMockServer(t testing.TB, name string, db *channeldb.DB) (*mockServer, error) { var id [33]byte h := sha256.Sum256([]byte(name)) copy(id[:], h[:]) - return &mockServer{ - t: t, - id: id, - name: name, - messages: make(chan lnwire.Message, 3000), - quit: make(chan struct{}), - registry: newMockRegistry(), - htlcSwitch: New(Config{ - FwdingLog: &mockForwardingLog{ - events: make(map[time.Time]channeldb.ForwardingEvent), - }, - }), - interceptorFuncs: make([]messageInterceptor, 0), + htlcSwitch, err := initSwitchWithDB(db) + if err != nil { + return nil, err } + + return &mockServer{ + t: t, + id: id, + name: name, + messages: make(chan lnwire.Message, 3000), + quit: make(chan struct{}), + registry: newMockRegistry(), + htlcSwitch: htlcSwitch, + interceptorFuncs: make([]messageInterceptor, 0), + }, nil } func (s *mockServer) Start() error { @@ -196,10 +217,6 @@ type mockHopIterator struct { hops []ForwardingInfo } -func (r *mockHopIterator) OnionPacket() *sphinx.OnionPacket { - return nil -} - func newMockHopIterator(hops ...ForwardingInfo) HopIterator { return &mockHopIterator{hops: hops} } @@ -261,7 +278,8 @@ type mockObfuscator struct { ogPacket *sphinx.OnionPacket } -func newMockObfuscator() ErrorEncrypter { +// NewMockObfuscator initializes a dummy mockObfuscator used for testing. +func NewMockObfuscator() ErrorEncrypter { return &mockObfuscator{} } @@ -512,6 +530,10 @@ type mockChannelLink struct { peer Peer + startMailBox bool + + mailBox MailBox + packets chan *htlcPacket eligible bool @@ -519,6 +541,39 @@ type mockChannelLink struct { htlcID uint64 } +// completeCircuit is a helper method for adding the finalized payment circuit +// to the switch's circuit map. In testing, this should be executed after +// receiving an htlc from the downstream packets channel. +func (f *mockChannelLink) completeCircuit(pkt *htlcPacket) error { + switch htlc := pkt.htlc.(type) { + case *lnwire.UpdateAddHTLC: + pkt.outgoingChanID = f.shortChanID + pkt.outgoingHTLCID = f.htlcID + htlc.ID = f.htlcID + + keystone := Keystone{pkt.inKey(), pkt.outKey()} + if err := f.htlcSwitch.openCircuits(keystone); err != nil { + return err + } + + f.htlcID++ + + case *lnwire.UpdateFulfillHTLC, *lnwire.UpdateFailHTLC: + err := f.htlcSwitch.teardownCircuit(pkt) + if err != nil { + return err + } + } + + f.mailBox.AckPacket(pkt.inKey()) + + return nil +} + +func (f *mockChannelLink) deleteCircuit(pkt *htlcPacket) error { + return f.htlcSwitch.deleteCircuits(pkt.inKey()) +} + func newMockChannelLink(htlcSwitch *Switch, chanID lnwire.ChannelID, shortChanID lnwire.ShortChannelID, peer Peer, eligible bool, ) *mockChannelLink { @@ -527,27 +582,14 @@ func newMockChannelLink(htlcSwitch *Switch, chanID lnwire.ChannelID, htlcSwitch: htlcSwitch, chanID: chanID, shortChanID: shortChanID, - packets: make(chan *htlcPacket, 1), peer: peer, eligible: eligible, } } -func (f *mockChannelLink) HandleSwitchPacket(packet *htlcPacket) { - switch htlc := packet.htlc.(type) { - case *lnwire.UpdateAddHTLC: - f.htlcSwitch.addCircuit(&PaymentCircuit{ - PaymentHash: htlc.PaymentHash, - IncomingChanID: packet.incomingChanID, - IncomingHTLCID: packet.incomingHTLCID, - OutgoingChanID: f.shortChanID, - OutgoingHTLCID: f.htlcID, - ErrorEncrypter: packet.obfuscator, - }) - f.htlcID++ - } - - f.packets <- packet +func (f *mockChannelLink) HandleSwitchPacket(pkt *htlcPacket) error { + f.mailBox.AddPacket(pkt) + return nil } func (f *mockChannelLink) HandleChannelUpdate(lnwire.Message) { @@ -560,12 +602,22 @@ func (f *mockChannelLink) Stats() (uint64, lnwire.MilliSatoshi, lnwire.MilliSato return 0, 0, 0 } +func (f *mockChannelLink) AttachMailBox(mailBox MailBox) { + f.mailBox = mailBox + f.packets = mailBox.PacketOutBox() +} + +func (f *mockChannelLink) Start() error { + f.mailBox.ResetMessages() + f.mailBox.ResetPackets() + return nil +} + func (f *mockChannelLink) ChanID() lnwire.ChannelID { return f.chanID } func (f *mockChannelLink) ShortChanID() lnwire.ShortChannelID { return f.shortChanID } func (f *mockChannelLink) UpdateShortChanID(sid lnwire.ShortChannelID) { f.shortChanID = sid } func (f *mockChannelLink) Bandwidth() lnwire.MilliSatoshi { return 99999999 } func (f *mockChannelLink) Peer() Peer { return f.peer } -func (f *mockChannelLink) Start() error { return nil } func (f *mockChannelLink) Stop() {} func (f *mockChannelLink) EligibleToForward() bool { return f.eligible } @@ -603,6 +655,10 @@ func (i *mockInvoiceRegistry) SettleInvoice(rhash chainhash.Hash) error { return fmt.Errorf("can't find mock invoice: %x", rhash[:]) } + if invoice.Terms.Settled { + return nil + } + invoice.Terms.Settled = true i.invoices[rhash] = invoice