lnd.xprv/routing/mock_test.go
Johan T. Halseth bee2380441
channeldb: rename PaymentAttemptInfo to HTLCAttemptInfo
To better distinguish payments from HTLCs, we rename the attempt info
struct to HTLCAttemptInfo. We also embed it into the HTLCAttempt struct,
to avoid having to duplicate this information.

The paymentID term is renamed to attemptID.
2020-03-09 11:43:26 +01:00

313 lines
6.2 KiB
Go

package routing
import (
"fmt"
"sync"
"github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/htlcswitch"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
"github.com/lightningnetwork/lnd/zpay32"
)
// mockPaymentAttemptDispatcher is a PaymentAttemptDispatcher whose HTLC
// outcomes are produced by the onPayment callback and kept in results.
type mockPaymentAttemptDispatcher struct {
// onPayment, if set, is called by SendHTLC with the first hop and yields
// either the preimage or the error for that attempt.
onPayment func(firstHop lnwire.ShortChannelID) ([32]byte, error)
// results holds the outcome SendHTLC recorded, keyed by the ID passed to
// SendHTLC. SendHTLC allocates the map lazily.
results map[uint64]*htlcswitch.PaymentResult
}
// Compile-time assertion that the mock implements PaymentAttemptDispatcher.
var _ PaymentAttemptDispatcher = (*mockPaymentAttemptDispatcher)(nil)
// SendHTLC records the outcome of an HTLC attempt under pid.
//
// The outcome comes from the onPayment callback:
//   - a nil error stores a result carrying the preimage;
//   - an htlcswitch.ClearTextError stores a result carrying that error;
//   - any other error is returned directly and nothing is stored.
//
// When no callback is installed the call is a no-op.
func (m *mockPaymentAttemptDispatcher) SendHTLC(firstHop lnwire.ShortChannelID,
	pid uint64,
	_ *lnwire.UpdateAddHTLC) error {

	if m.onPayment == nil {
		return nil
	}

	// The results map is created lazily so a zero-value mock is usable.
	if m.results == nil {
		m.results = make(map[uint64]*htlcswitch.PaymentResult)
	}

	preimage, err := m.onPayment(firstHop)
	if err == nil {
		m.results[pid] = &htlcswitch.PaymentResult{Preimage: preimage}
		return nil
	}

	clearErr, isClearText := err.(htlcswitch.ClearTextError)
	if !isClearText {
		return err
	}

	m.results[pid] = &htlcswitch.PaymentResult{Error: clearErr}
	return nil
}
// GetPaymentResult returns a channel pre-loaded with the result SendHTLC
// recorded for paymentID. If no result was recorded under that ID, it
// returns htlcswitch.ErrPaymentIDNotFound instead.
func (m *mockPaymentAttemptDispatcher) GetPaymentResult(paymentID uint64,
	_ lntypes.Hash, _ htlcswitch.ErrorDecrypter) (
	<-chan *htlcswitch.PaymentResult, error) {

	recorded, found := m.results[paymentID]
	if !found {
		return nil, htlcswitch.ErrPaymentIDNotFound
	}

	// A single-slot buffer lets the result be queued before anyone reads.
	resultChan := make(chan *htlcswitch.PaymentResult, 1)
	resultChan <- recorded
	return resultChan, nil
}
// setPaymentResult installs fn as the callback SendHTLC consults to decide
// the outcome of each attempt.
func (m *mockPaymentAttemptDispatcher) setPaymentResult(
	fn func(firstHop lnwire.ShortChannelID) ([32]byte, error)) {

	m.onPayment = fn
}
// mockPaymentSessionSource is a PaymentSessionSource that hands a fixed list
// of routes to the sessions it creates.
type mockPaymentSessionSource struct {
// routes is shared (slice header copied) with every session returned by
// NewPaymentSession.
routes []*route.Route
}
// Compile-time assertion that the mock implements PaymentSessionSource.
var _ PaymentSessionSource = (*mockPaymentSessionSource)(nil)
// NewPaymentSession ignores the route hints and target. It returns a
// mockPaymentSession seeded with the source's configured routes.
func (m *mockPaymentSessionSource) NewPaymentSession(_ [][]zpay32.HopHint,
	_ route.Vertex) (PaymentSession, error) {

	session := &mockPaymentSession{routes: m.routes}
	return session, nil
}
// NewPaymentSessionForRoute ignores the supplied route and always returns a
// nil PaymentSession.
func (m *mockPaymentSessionSource) NewPaymentSessionForRoute(
	_ *route.Route) PaymentSession {

	return nil
}
// NewPaymentSessionEmpty returns a mockPaymentSession with no queued routes.
// Every RequestRoute call on it therefore fails.
func (m *mockPaymentSessionSource) NewPaymentSessionEmpty() PaymentSession {
	return new(mockPaymentSession)
}
// mockMissionControl is a stateless MissionController whose methods ignore
// their inputs and return zero values.
type mockMissionControl struct {
}
// Compile-time assertion that the mock implements MissionController.
var _ MissionController = (*mockMissionControl)(nil)
// ReportPaymentFail discards the failure report. It always returns a nil
// failure reason and a nil error.
func (m *mockMissionControl) ReportPaymentFail(_ uint64, _ *route.Route,
	_ *int, _ lnwire.FailureMessage) (*channeldb.FailureReason, error) {

	var reason *channeldb.FailureReason
	return reason, nil
}
// ReportPaymentSuccess discards the success report and never fails.
func (m *mockMissionControl) ReportPaymentSuccess(_ uint64,
	_ *route.Route) error {

	return nil
}
// GetProbability returns 0 for every node pair and amount.
func (m *mockMissionControl) GetProbability(_, _ route.Vertex,
	_ lnwire.MilliSatoshi) float64 {

	const zeroProbability = 0.0
	return zeroProbability
}
// mockPaymentSession is a PaymentSession whose RequestRoute returns its
// queued routes one at a time.
type mockPaymentSession struct {
// routes is the queue of routes still to be handed out; RequestRoute
// removes from the front.
routes []*route.Route
}
// Compile-time assertion that the mock implements PaymentSession.
var _ PaymentSession = (*mockPaymentSession)(nil)
// RequestRoute ignores its arguments and pops the first queued route. Once
// the queue is exhausted it returns a "no routes" error.
func (m *mockPaymentSession) RequestRoute(_ *LightningPayment,
	_ uint32, _ uint16) (*route.Route, error) {

	if len(m.routes) > 0 {
		next := m.routes[0]
		m.routes = m.routes[1:]
		return next, nil
	}

	return nil, fmt.Errorf("no routes")
}
// mockPayer is a PaymentAttemptDispatcher driven by the test. Each call
// blocks until the test feeds the matching channel or quit is receivable.
type mockPayer struct {
// sendResult supplies the error returned by SendHTLC.
sendResult chan error
// paymentResultErr supplies an error returned by GetPaymentResult.
paymentResultErr chan error
// paymentResult supplies a result delivered by GetPaymentResult.
paymentResult chan *htlcswitch.PaymentResult
// quit unblocks pending calls with a "test quitting" error.
quit chan struct{}
}
// Compile-time assertion that mockPayer implements PaymentAttemptDispatcher.
var _ PaymentAttemptDispatcher = (*mockPayer)(nil)
// SendHTLC blocks until the test provides a value on sendResult, which is
// returned as-is. If a receive from quit happens first, it returns a
// "test quitting" error instead.
func (m *mockPayer) SendHTLC(_ lnwire.ShortChannelID,
	_ uint64,
	_ *lnwire.UpdateAddHTLC) error {

	var err error
	select {
	case err = <-m.sendResult:
	case <-m.quit:
		err = fmt.Errorf("test quitting")
	}

	return err
}
// GetPaymentResult blocks until the test provides one of three things:
//   - a result on paymentResult, delivered through a fresh single-slot
//     channel;
//   - an error on paymentResultErr, returned directly;
//   - a receive from quit, which aborts with a "test quitting" error.
func (m *mockPayer) GetPaymentResult(paymentID uint64, _ lntypes.Hash,
	_ htlcswitch.ErrorDecrypter) (<-chan *htlcswitch.PaymentResult, error) {

	var result *htlcswitch.PaymentResult
	select {
	case result = <-m.paymentResult:
	case err := <-m.paymentResultErr:
		return nil, err
	case <-m.quit:
		return nil, fmt.Errorf("test quitting")
	}

	out := make(chan *htlcswitch.PaymentResult, 1)
	out <- result
	return out, nil
}
// Argument records that mockControlTower sends on its optional notification
// channels, one per observed method.
type (
	// initArgs carries the creation info passed to InitPayment.
	initArgs struct {
		c *channeldb.PaymentCreationInfo
	}

	// registerArgs carries the attempt info passed to RegisterAttempt.
	registerArgs struct {
		a *channeldb.HTLCAttemptInfo
	}

	// successArgs carries the preimage passed to Success.
	successArgs struct {
		preimg lntypes.Preimage
	}

	// failArgs carries the reason passed to Fail.
	failArgs struct {
		reason channeldb.FailureReason
	}
)
// mockControlTower is an in-memory ControlTower. Optional channels let a
// test observe calls; a nil channel means that notification is skipped.
type mockControlTower struct {
// inflights tracks payments initiated via InitPayment and not yet passed
// to Success or Fail, keyed by payment hash. Values are stored by value.
inflights map[lntypes.Hash]channeldb.InFlightPayment
// successful records the hashes passed to Success.
successful map[lntypes.Hash]struct{}
// init, register, success, fail and fetchInFlight, when non-nil, receive
// a notification from the matching method. The send happens while the
// mutex is held, so it blocks the method until received (unless the
// channel has spare buffer).
init chan initArgs
register chan registerArgs
success chan successArgs
fail chan failArgs
fetchInFlight chan struct{}
// Mutex guards the state above; the tracking methods lock it.
sync.Mutex
}
// Compile-time assertion that the mock implements ControlTower.
var _ ControlTower = (*mockControlTower)(nil)
// makeMockControlTower returns a mockControlTower with its tracking maps
// allocated. All notification channels are left nil until a test sets them.
func makeMockControlTower() *mockControlTower {
	tower := &mockControlTower{}
	tower.inflights = make(map[lntypes.Hash]channeldb.InFlightPayment)
	tower.successful = make(map[lntypes.Hash]struct{})
	return tower
}
// InitPayment starts tracking phash as in flight with creation info c.
//
// It first notifies the init channel, if set. It then rejects a hash that
// has already succeeded ("already successful") or is still in flight
// ("in flight").
func (m *mockControlTower) InitPayment(phash lntypes.Hash,
	c *channeldb.PaymentCreationInfo) error {

	m.Lock()
	defer m.Unlock()

	// The notification is sent while holding the lock.
	if m.init != nil {
		m.init <- initArgs{c}
	}

	if _, done := m.successful[phash]; done {
		return fmt.Errorf("already successful")
	}
	if _, pending := m.inflights[phash]; pending {
		return fmt.Errorf("in flight")
	}

	m.inflights[phash] = channeldb.InFlightPayment{Info: c}
	return nil
}
// RegisterAttempt attaches attempt a to the in-flight payment phash.
//
// It notifies the register channel first, if set. It returns a
// "not in flight" error when phash is not currently tracked.
func (m *mockControlTower) RegisterAttempt(phash lntypes.Hash,
	a *channeldb.HTLCAttemptInfo) error {

	m.Lock()
	defer m.Unlock()

	if m.register != nil {
		m.register <- registerArgs{a}
	}

	payment, found := m.inflights[phash]
	if !found {
		return fmt.Errorf("not in flight")
	}

	// The map stores InFlightPayment by value, so the modified copy has to
	// be written back.
	payment.Attempt = a
	m.inflights[phash] = payment
	return nil
}
// Success moves phash from the in-flight set to the successful set.
//
// The preimage is used only for the success notification, which is sent
// first if the channel is set. It never returns an error.
func (m *mockControlTower) Success(phash lntypes.Hash,
	preimg lntypes.Preimage) error {

	m.Lock()
	defer m.Unlock()

	if notify := m.success; notify != nil {
		notify <- successArgs{preimg}
	}

	delete(m.inflights, phash)
	m.successful[phash] = struct{}{}
	return nil
}
// Fail stops tracking phash as in flight.
//
// The reason is used only for the fail notification, which is sent first
// if the channel is set. It never returns an error.
func (m *mockControlTower) Fail(phash lntypes.Hash,
	reason channeldb.FailureReason) error {

	m.Lock()
	defer m.Unlock()

	if notify := m.fail; notify != nil {
		notify <- failArgs{reason}
	}

	delete(m.inflights, phash)
	return nil
}
// FetchInFlightPayments returns a pointer to a distinct copy of every
// payment currently tracked as in flight.
//
// The fetchInFlight channel, if set, is signalled first while the lock is
// held. The returned slice is nil when nothing is in flight. Its order
// follows Go's unspecified map iteration order.
func (m *mockControlTower) FetchInFlightPayments() (
	[]*channeldb.InFlightPayment, error) {

	m.Lock()
	defer m.Unlock()

	if m.fetchInFlight != nil {
		m.fetchInFlight <- struct{}{}
	}

	var fl []*channeldb.InFlightPayment
	for _, ifl := range m.inflights {
		// Copy the range variable before taking its address. Before
		// Go 1.22 a single variable is reused on every iteration, so
		// appending &ifl directly would make every element alias the
		// last payment visited.
		ifl := ifl
		fl = append(fl, &ifl)
	}

	return fl, nil
}
// SubscribePayment is not implemented by this mock. It always reports
// false, a nil channel and a "not implemented" error.
func (m *mockControlTower) SubscribePayment(_ lntypes.Hash) (
	bool, chan PaymentResult, error) {

	err := errors.New("not implemented")
	return false, nil, err
}