htlcswitch/mock: adds Encode/Decode to mock obfuscator

This commit is contained in:
Conner Fromknecht 2017-12-10 15:38:17 -08:00
parent 3ae7772ecb
commit 3048dfd4be
No known key found for this signature in database
GPG Key ID: 39DE78FBE6ACB0EF

@ -4,6 +4,7 @@ import (
"crypto/sha256"
"encoding/binary"
"fmt"
"io/ioutil"
"sync"
"testing"
"time"
@ -120,25 +121,45 @@ type mockServer struct {
var _ Peer = (*mockServer)(nil)
func newMockServer(t testing.TB, name string) *mockServer {
func initSwitchWithDB(db *channeldb.DB) (*Switch, error) {
if db == nil {
tempPath, err := ioutil.TempDir("", "switchdb")
if err != nil {
return nil, err
}
db, err = channeldb.Open(tempPath)
if err != nil {
return nil, err
}
}
return New(Config{
DB: db,
SwitchPackager: channeldb.NewSwitchPackager(),
})
}
func newMockServer(t testing.TB, name string, db *channeldb.DB) (*mockServer, error) {
var id [33]byte
h := sha256.Sum256([]byte(name))
copy(id[:], h[:])
return &mockServer{
t: t,
id: id,
name: name,
messages: make(chan lnwire.Message, 3000),
quit: make(chan struct{}),
registry: newMockRegistry(),
htlcSwitch: New(Config{
FwdingLog: &mockForwardingLog{
events: make(map[time.Time]channeldb.ForwardingEvent),
},
}),
interceptorFuncs: make([]messageInterceptor, 0),
htlcSwitch, err := initSwitchWithDB(db)
if err != nil {
return nil, err
}
return &mockServer{
t: t,
id: id,
name: name,
messages: make(chan lnwire.Message, 3000),
quit: make(chan struct{}),
registry: newMockRegistry(),
htlcSwitch: htlcSwitch,
interceptorFuncs: make([]messageInterceptor, 0),
}, nil
}
func (s *mockServer) Start() error {
@ -196,10 +217,6 @@ type mockHopIterator struct {
hops []ForwardingInfo
}
func (r *mockHopIterator) OnionPacket() *sphinx.OnionPacket {
return nil
}
func newMockHopIterator(hops ...ForwardingInfo) HopIterator {
return &mockHopIterator{hops: hops}
}
@ -261,7 +278,8 @@ type mockObfuscator struct {
ogPacket *sphinx.OnionPacket
}
func newMockObfuscator() ErrorEncrypter {
// NewMockObfuscator initializes a dummy mockObfuscator used for testing.
func NewMockObfuscator() ErrorEncrypter {
return &mockObfuscator{}
}
@ -512,6 +530,10 @@ type mockChannelLink struct {
peer Peer
startMailBox bool
mailBox MailBox
packets chan *htlcPacket
eligible bool
@ -519,6 +541,39 @@ type mockChannelLink struct {
htlcID uint64
}
// completeCircuit is a helper method for adding the finalized payment circuit
// to the switch's circuit map. In testing, this should be executed after
// receiving an htlc from the downstream packets channel.
func (f *mockChannelLink) completeCircuit(pkt *htlcPacket) error {
switch htlc := pkt.htlc.(type) {
case *lnwire.UpdateAddHTLC:
pkt.outgoingChanID = f.shortChanID
pkt.outgoingHTLCID = f.htlcID
htlc.ID = f.htlcID
keystone := Keystone{pkt.inKey(), pkt.outKey()}
if err := f.htlcSwitch.openCircuits(keystone); err != nil {
return err
}
f.htlcID++
case *lnwire.UpdateFulfillHTLC, *lnwire.UpdateFailHTLC:
err := f.htlcSwitch.teardownCircuit(pkt)
if err != nil {
return err
}
}
f.mailBox.AckPacket(pkt.inKey())
return nil
}
func (f *mockChannelLink) deleteCircuit(pkt *htlcPacket) error {
return f.htlcSwitch.deleteCircuits(pkt.inKey())
}
func newMockChannelLink(htlcSwitch *Switch, chanID lnwire.ChannelID,
shortChanID lnwire.ShortChannelID, peer Peer, eligible bool,
) *mockChannelLink {
@ -527,27 +582,14 @@ func newMockChannelLink(htlcSwitch *Switch, chanID lnwire.ChannelID,
htlcSwitch: htlcSwitch,
chanID: chanID,
shortChanID: shortChanID,
packets: make(chan *htlcPacket, 1),
peer: peer,
eligible: eligible,
}
}
func (f *mockChannelLink) HandleSwitchPacket(packet *htlcPacket) {
switch htlc := packet.htlc.(type) {
case *lnwire.UpdateAddHTLC:
f.htlcSwitch.addCircuit(&PaymentCircuit{
PaymentHash: htlc.PaymentHash,
IncomingChanID: packet.incomingChanID,
IncomingHTLCID: packet.incomingHTLCID,
OutgoingChanID: f.shortChanID,
OutgoingHTLCID: f.htlcID,
ErrorEncrypter: packet.obfuscator,
})
f.htlcID++
}
f.packets <- packet
func (f *mockChannelLink) HandleSwitchPacket(pkt *htlcPacket) error {
f.mailBox.AddPacket(pkt)
return nil
}
func (f *mockChannelLink) HandleChannelUpdate(lnwire.Message) {
@ -560,12 +602,22 @@ func (f *mockChannelLink) Stats() (uint64, lnwire.MilliSatoshi, lnwire.MilliSato
return 0, 0, 0
}
func (f *mockChannelLink) AttachMailBox(mailBox MailBox) {
f.mailBox = mailBox
f.packets = mailBox.PacketOutBox()
}
func (f *mockChannelLink) Start() error {
f.mailBox.ResetMessages()
f.mailBox.ResetPackets()
return nil
}
func (f *mockChannelLink) ChanID() lnwire.ChannelID { return f.chanID }
func (f *mockChannelLink) ShortChanID() lnwire.ShortChannelID { return f.shortChanID }
func (f *mockChannelLink) UpdateShortChanID(sid lnwire.ShortChannelID) { f.shortChanID = sid }
func (f *mockChannelLink) Bandwidth() lnwire.MilliSatoshi { return 99999999 }
func (f *mockChannelLink) Peer() Peer { return f.peer }
func (f *mockChannelLink) Start() error { return nil }
func (f *mockChannelLink) Stop() {}
func (f *mockChannelLink) EligibleToForward() bool { return f.eligible }
@ -603,6 +655,10 @@ func (i *mockInvoiceRegistry) SettleInvoice(rhash chainhash.Hash) error {
return fmt.Errorf("can't find mock invoice: %x", rhash[:])
}
if invoice.Terms.Settled {
return nil
}
invoice.Terms.Settled = true
i.invoices[rhash] = invoice