diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index df059914..1f681264 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -916,3 +916,58 @@ func (m *mockNotifier) RegisterSpendNtfn(outpoint *wire.OutPoint, _ []byte, Spend: make(chan *chainntnfs.SpendDetail), }, nil } + +type mockCircuitMap struct { + lookup chan *PaymentCircuit +} + +var _ CircuitMap = (*mockCircuitMap)(nil) + +func (m *mockCircuitMap) OpenCircuits(...Keystone) error { + return nil +} + +func (m *mockCircuitMap) TrimOpenCircuits(chanID lnwire.ShortChannelID, + start uint64) error { + return nil +} + +func (m *mockCircuitMap) DeleteCircuits(inKeys ...CircuitKey) error { + return nil +} + +func (m *mockCircuitMap) CommitCircuits( + circuit ...*PaymentCircuit) (*CircuitFwdActions, error) { + + return nil, nil +} + +func (m *mockCircuitMap) CloseCircuit(outKey CircuitKey) (*PaymentCircuit, + error) { + return nil, nil +} + +func (m *mockCircuitMap) FailCircuit(inKey CircuitKey) (*PaymentCircuit, + error) { + return nil, nil +} + +func (m *mockCircuitMap) LookupCircuit(inKey CircuitKey) *PaymentCircuit { + return <-m.lookup +} + +func (m *mockCircuitMap) LookupOpenCircuit(outKey CircuitKey) *PaymentCircuit { + return nil +} + +func (m *mockCircuitMap) LookupByPaymentHash(hash [32]byte) []*PaymentCircuit { + return nil +} + +func (m *mockCircuitMap) NumPending() int { + return 0 +} + +func (m *mockCircuitMap) NumOpen() int { + return 0 +} diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index af7c0fca..6a0a24e9 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -14,6 +14,7 @@ import ( "github.com/btcsuite/fastsha256" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/ticker" ) @@ -2125,3 +2126,115 @@ func TestUpdateFailMalformedHTLCErrorConversion(t *testing.T) { assertPaymentFailure(t) }) } + +// TestSwitchGetPaymentResult tests that the switch interacts as expected with +// the circuit map and network result store when looking up the result of a +// payment ID. This is important for not to lose results under concurrent +// lookup and receiving results. +func TestSwitchGetPaymentResult(t *testing.T) { + t.Parallel() + + const paymentID = 123 + var preimg lntypes.Preimage + preimg[0] = 3 + + s, err := initSwitchWithDB(testStartingHeight, nil) + if err != nil { + t.Fatalf("unable to init switch: %v", err) + } + if err := s.Start(); err != nil { + t.Fatalf("unable to start switch: %v", err) + } + defer s.Stop() + + lookup := make(chan *PaymentCircuit, 1) + s.circuits = &mockCircuitMap{ + lookup: lookup, + } + + // If the payment circuit is not found in the circuit map, the payment + // result must be found in the store if available. Since we haven't + // added anything to the store yet, ErrPaymentIDNotFound should be + // returned. + lookup <- nil + _, err = s.GetPaymentResult( + paymentID, lntypes.Hash{}, newMockDeobfuscator(), + ) + if err != ErrPaymentIDNotFound { + t.Fatalf("expected ErrPaymentIDNotFound, got %v", err) + } + + // Next let the lookup find the circuit in the circuit map. It should + // subscribe to payment results, and return the result when available. + lookup <- &PaymentCircuit{} + resultChan, err := s.GetPaymentResult( + paymentID, lntypes.Hash{}, newMockDeobfuscator(), + ) + if err != nil { + t.Fatalf("unable to get payment result: %v", err) + } + + // Add the result to the store. + n := &networkResult{ + msg: &lnwire.UpdateFulfillHTLC{ + PaymentPreimage: preimg, + }, + unencrypted: true, + isResolution: true, + } + + err = s.networkResults.storeResult(paymentID, n) + if err != nil { + t.Fatalf("unable to store result: %v", err) + } + + // The result should be availble. + select { + case res, ok := <-resultChan: + if !ok { + t.Fatalf("channel was closed") + } + + if res.Error != nil { + t.Fatalf("got unexpected error result") + } + + if res.Preimage != preimg { + t.Fatalf("expected preimg %v, got %v", + preimg, res.Preimage) + } + + case <-time.After(1 * time.Second): + t.Fatalf("result not received") + } + + // As a final test, try to get the result again. Now that is no longer + // in the circuit map, it should be immediately available from the + // store. + lookup <- nil + resultChan, err = s.GetPaymentResult( + paymentID, lntypes.Hash{}, newMockDeobfuscator(), + ) + if err != nil { + t.Fatalf("unable to get payment result: %v", err) + } + + select { + case res, ok := <-resultChan: + if !ok { + t.Fatalf("channel was closed") + } + + if res.Error != nil { + t.Fatalf("got unexpected error result") + } + + if res.Preimage != preimg { + t.Fatalf("expected preimg %v, got %v", + preimg, res.Preimage) + } + + case <-time.After(1 * time.Second): + t.Fatalf("result not received") + } +}