lnd.xprv/routing/mock_test.go
carla b5a2d75465
htlcswitch+routing: type check on ClearTextError
Update the type check used for checking local payment
failures to check on the ClearTextError interface rather
than on the ForwardingError type. This change prepares
for splitting payment errors up into Link and Forwarding
errors.
2020-01-14 15:07:42 +02:00

313 lines
6.2 KiB
Go

package routing
import (
"fmt"
"sync"
"github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/htlcswitch"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
"github.com/lightningnetwork/lnd/zpay32"
)
// mockPaymentAttemptDispatcher is a test implementation of
// PaymentAttemptDispatcher that resolves every HTLC through a user supplied
// callback and remembers the outcome per payment ID.
type mockPaymentAttemptDispatcher struct {
// onPayment is called by SendHTLC with the first hop of the attempt. It
// returns the preimage on success or an error describing the failure.
// When nil, SendHTLC is a no-op.
onPayment func(firstHop lnwire.ShortChannelID) ([32]byte, error)
// results holds the outcome recorded by SendHTLC, keyed by payment ID.
// It is allocated lazily by SendHTLC.
results map[uint64]*htlcswitch.PaymentResult
}
// Compile-time check that the mock satisfies PaymentAttemptDispatcher.
var _ PaymentAttemptDispatcher = (*mockPaymentAttemptDispatcher)(nil)
// SendHTLC runs the configured onPayment callback for the given first hop and
// stores the outcome under pid so that GetPaymentResult can return it later.
//
// If no callback is configured the call succeeds without recording anything.
// An error from the callback that implements htlcswitch.ClearTextError is
// stored as the payment result and nil is returned; any other error is
// returned directly to the caller and nothing is recorded.
func (m *mockPaymentAttemptDispatcher) SendHTLC(firstHop lnwire.ShortChannelID,
	pid uint64,
	_ *lnwire.UpdateAddHTLC) error {

	if m.onPayment == nil {
		return nil
	}

	// Lazily allocate the result store on first use.
	if m.results == nil {
		m.results = make(map[uint64]*htlcswitch.PaymentResult)
	}

	preimage, err := m.onPayment(firstHop)
	if err == nil {
		m.results[pid] = &htlcswitch.PaymentResult{Preimage: preimage}
		return nil
	}

	// Only clear text errors are delivered through the payment result;
	// everything else fails the send itself.
	clearErr, isClearText := err.(htlcswitch.ClearTextError)
	if !isClearText {
		return err
	}

	m.results[pid] = &htlcswitch.PaymentResult{Error: clearErr}
	return nil
}
// GetPaymentResult returns a buffered channel that already contains the
// result recorded for paymentID by SendHTLC. If no result was recorded it
// returns htlcswitch.ErrPaymentIDNotFound.
func (m *mockPaymentAttemptDispatcher) GetPaymentResult(paymentID uint64,
	_ lntypes.Hash, _ htlcswitch.ErrorDecrypter) (
	<-chan *htlcswitch.PaymentResult, error) {

	result, found := m.results[paymentID]
	if !found {
		return nil, htlcswitch.ErrPaymentIDNotFound
	}

	// A buffer of one lets us hand the result over without blocking.
	resultChan := make(chan *htlcswitch.PaymentResult, 1)
	resultChan <- result

	return resultChan, nil
}
// setPaymentResult installs f as the callback that SendHTLC uses to decide
// the outcome of each payment attempt.
func (m *mockPaymentAttemptDispatcher) setPaymentResult(
	f func(firstHop lnwire.ShortChannelID) ([32]byte, error)) {

	m.onPayment = f
}
// mockPaymentSessionSource is a test implementation of PaymentSessionSource
// that hands out sessions backed by a fixed list of routes.
type mockPaymentSessionSource struct {
// routes is passed to every session created by NewPaymentSession.
routes []*route.Route
}
// Compile-time check that the mock satisfies PaymentSessionSource.
var _ PaymentSessionSource = (*mockPaymentSessionSource)(nil)
// NewPaymentSession returns a mock session seeded with the source's routes.
// The route hints and target are ignored, and the returned error is always
// nil.
func (m *mockPaymentSessionSource) NewPaymentSession(
	routeHints [][]zpay32.HopHint,
	target route.Vertex) (PaymentSession, error) {

	session := &mockPaymentSession{routes: m.routes}
	return session, nil
}
// NewPaymentSessionForRoute is a stub: it ignores preBuiltRoute and always
// returns a nil PaymentSession.
func (m *mockPaymentSessionSource) NewPaymentSessionForRoute(
	preBuiltRoute *route.Route) PaymentSession {

	return nil
}
// NewPaymentSessionEmpty returns a mock session that holds no routes.
func (m *mockPaymentSessionSource) NewPaymentSessionEmpty() PaymentSession {
	return new(mockPaymentSession)
}
// mockMissionControl is a stateless test implementation of MissionController
// whose methods ignore their arguments and return fixed values.
type mockMissionControl struct {
}
// Compile-time check that the mock satisfies MissionController.
var _ MissionController = (*mockMissionControl)(nil)
// ReportPaymentFail is a stub that ignores the reported failure. It always
// returns a nil failure reason and a nil error.
func (m *mockMissionControl) ReportPaymentFail(paymentID uint64,
	rt *route.Route, failureSourceIdx *int,
	failure lnwire.FailureMessage) (*channeldb.FailureReason, error) {

	return nil, nil
}
// ReportPaymentSuccess is a stub that ignores the reported success and always
// returns nil.
func (m *mockMissionControl) ReportPaymentSuccess(
	paymentID uint64, rt *route.Route) error {

	return nil
}
// GetProbability is a stub that ignores its arguments and always reports a
// probability of zero.
func (m *mockMissionControl) GetProbability(
	fromNode, toNode route.Vertex, amt lnwire.MilliSatoshi) float64 {

	return 0.0
}
// mockPaymentSession is a test implementation of PaymentSession that serves
// routes from a pre-populated list.
type mockPaymentSession struct {
// routes is the queue of routes still to be handed out; RequestRoute
// removes the first entry on every successful call.
routes []*route.Route
}
// Compile-time check that the mock satisfies PaymentSession.
var _ PaymentSession = (*mockPaymentSession)(nil)
// RequestRoute pops and returns the next queued route. The payment, height
// and final CLTV delta are ignored. Once the queue is exhausted it returns a
// "no routes" error.
func (m *mockPaymentSession) RequestRoute(payment *LightningPayment,
	height uint32, finalCltvDelta uint16) (*route.Route, error) {

	if len(m.routes) == 0 {
		return nil, fmt.Errorf("no routes")
	}

	// Split off the head of the queue and keep the remainder.
	next, remaining := m.routes[0], m.routes[1:]
	m.routes = remaining

	return next, nil
}
// mockPayer is a test implementation of PaymentAttemptDispatcher whose
// behavior is driven step by step through channels.
type mockPayer struct {
// sendResult supplies the error (possibly nil) returned by SendHTLC.
sendResult chan error
// paymentResultErr supplies an error to be returned by GetPaymentResult.
paymentResultErr chan error
// paymentResult supplies a result to be delivered by GetPaymentResult.
paymentResult chan *htlcswitch.PaymentResult
// quit unblocks SendHTLC and GetPaymentResult with a "test quitting"
// error when it becomes readable.
quit chan struct{}
}
// Compile-time check that the mock satisfies PaymentAttemptDispatcher.
var _ PaymentAttemptDispatcher = (*mockPayer)(nil)
// SendHTLC blocks until a value arrives on sendResult and returns it, or
// returns a "test quitting" error once the quit channel becomes readable. All
// arguments are ignored.
func (m *mockPayer) SendHTLC(_ lnwire.ShortChannelID,
	paymentID uint64,
	_ *lnwire.UpdateAddHTLC) error {

	select {
	case <-m.quit:
		return fmt.Errorf("test quitting")

	case sendErr := <-m.sendResult:
		return sendErr
	}
}
// GetPaymentResult blocks until the test provides an outcome. A value from
// paymentResult is wrapped in a buffered channel and returned; a value from
// paymentResultErr is returned as the error; and a readable quit channel
// yields a "test quitting" error. All arguments are ignored.
func (m *mockPayer) GetPaymentResult(paymentID uint64, _ lntypes.Hash,
	_ htlcswitch.ErrorDecrypter) (<-chan *htlcswitch.PaymentResult, error) {

	select {
	case <-m.quit:
		return nil, fmt.Errorf("test quitting")

	case resultErr := <-m.paymentResultErr:
		return nil, resultErr

	case result := <-m.paymentResult:
		// A buffer of one lets us pre-load the result without blocking.
		out := make(chan *htlcswitch.PaymentResult, 1)
		out <- result
		return out, nil
	}
}
// initArgs carries the creation info passed to mockControlTower.InitPayment;
// it is the value sent on the tower's init channel.
type initArgs struct {
c *channeldb.PaymentCreationInfo
}
// registerArgs carries the attempt info passed to
// mockControlTower.RegisterAttempt; it is the value sent on the tower's
// register channel.
type registerArgs struct {
a *channeldb.PaymentAttemptInfo
}
// successArgs carries the preimage passed to mockControlTower.Success; it is
// the value sent on the tower's success channel.
type successArgs struct {
preimg lntypes.Preimage
}
// failArgs carries the failure reason passed to mockControlTower.Fail; it is
// the value sent on the tower's fail channel.
type failArgs struct {
reason channeldb.FailureReason
}
// mockControlTower is an in-memory test implementation of ControlTower. Its
// methods take the embedded mutex and, when the matching channel field is
// non-nil, send a notification on it before updating state.
type mockControlTower struct {
// inflights holds the payments that were initiated and have not yet
// succeeded or failed, keyed by payment hash.
inflights map[lntypes.Hash]channeldb.InFlightPayment
// successful is the set of payment hashes that were marked successful.
successful map[lntypes.Hash]struct{}
// The channels below are optional. When set, the corresponding method
// sends on them while holding the mutex, so the send blocks until it is
// received or buffered.
init chan initArgs
register chan registerArgs
success chan successArgs
fail chan failArgs
fetchInFlight chan struct{}
sync.Mutex
}
// Compile-time check that the mock satisfies ControlTower.
var _ ControlTower = (*mockControlTower)(nil)
// makeMockControlTower returns a mock control tower with empty in-flight and
// successful maps. All notification channels are left nil.
func makeMockControlTower() *mockControlTower {
	tower := new(mockControlTower)
	tower.inflights = make(map[lntypes.Hash]channeldb.InFlightPayment)
	tower.successful = make(map[lntypes.Hash]struct{})

	return tower
}
// InitPayment records phash as in flight with creation info c. If the init
// channel is set, the arguments are sent on it first. It returns an error if
// the hash has already succeeded or is already in flight.
func (m *mockControlTower) InitPayment(phash lntypes.Hash,
	c *channeldb.PaymentCreationInfo) error {

	m.Lock()
	defer m.Unlock()

	if m.init != nil {
		m.init <- initArgs{c: c}
	}

	_, alreadyDone := m.successful[phash]
	_, alreadyPending := m.inflights[phash]

	// A completed payment takes precedence over one still in flight.
	switch {
	case alreadyDone:
		return fmt.Errorf("already successful")

	case alreadyPending:
		return fmt.Errorf("in flight")
	}

	m.inflights[phash] = channeldb.InFlightPayment{Info: c}

	return nil
}
// RegisterAttempt attaches attempt a to the in-flight payment for phash. If
// the register channel is set, the arguments are sent on it first. It returns
// an error if the payment is not in flight.
func (m *mockControlTower) RegisterAttempt(phash lntypes.Hash,
	a *channeldb.PaymentAttemptInfo) error {

	m.Lock()
	defer m.Unlock()

	if m.register != nil {
		m.register <- registerArgs{a: a}
	}

	payment, found := m.inflights[phash]
	if !found {
		return fmt.Errorf("not in flight")
	}

	// The map stores values, so write the modified copy back.
	payment.Attempt = a
	m.inflights[phash] = payment

	return nil
}
// Success moves phash from the in-flight set to the successful set. If the
// success channel is set, the preimage is sent on it first. It always returns
// nil.
func (m *mockControlTower) Success(phash lntypes.Hash,
	preimg lntypes.Preimage) error {

	m.Lock()
	defer m.Unlock()

	if m.success != nil {
		m.success <- successArgs{preimg: preimg}
	}

	m.successful[phash] = struct{}{}
	delete(m.inflights, phash)

	return nil
}
// Fail removes phash from the in-flight set. If the fail channel is set, the
// failure reason is sent on it first. It always returns nil.
func (m *mockControlTower) Fail(phash lntypes.Hash,
	reason channeldb.FailureReason) error {

	m.Lock()
	defer m.Unlock()

	if m.fail != nil {
		m.fail <- failArgs{reason: reason}
	}

	delete(m.inflights, phash)

	return nil
}
// FetchInFlightPayments returns one pointer per payment currently tracked as
// in flight. Each pointer refers to its own copy of the stored payment, in
// unspecified (map iteration) order. If the fetchInFlight channel is set, an
// empty struct is sent on it first. The returned slice is nil when nothing is
// in flight and the error is always nil.
func (m *mockControlTower) FetchInFlightPayments() (
	[]*channeldb.InFlightPayment, error) {

	m.Lock()
	defer m.Unlock()

	if m.fetchInFlight != nil {
		m.fetchInFlight <- struct{}{}
	}

	var fl []*channeldb.InFlightPayment
	for _, ifl := range m.inflights {
		// Take a per-iteration copy before taking its address. Before
		// Go 1.22 the range variable is shared by all iterations, so
		// appending &ifl directly would make every element point at
		// the same (last visited) payment.
		ifl := ifl
		fl = append(fl, &ifl)
	}

	return fl, nil
}
// SubscribePayment is not supported by the mock. It always returns false, a
// nil channel and a "not implemented" error.
func (m *mockControlTower) SubscribePayment(paymentHash lntypes.Hash) (
	bool, chan PaymentResult, error) {

	err := errors.New("not implemented")
	return false, nil, err
}