diff --git a/channeldb/payment_control.go b/channeldb/payment_control.go index ca5b6998..1ba84668 100644 --- a/channeldb/payment_control.go +++ b/channeldb/payment_control.go @@ -186,38 +186,39 @@ func (p *PaymentControl) InitPayment(paymentHash lntypes.Hash, // RegisterAttempt atomically records the provided HTLCAttemptInfo to the // DB. func (p *PaymentControl) RegisterAttempt(paymentHash lntypes.Hash, - attempt *HTLCAttemptInfo) error { + attempt *HTLCAttemptInfo) (*MPPayment, error) { // Serialize the information before opening the db transaction. var a bytes.Buffer err := serializeHTLCAttemptInfo(&a, attempt) if err != nil { - return err + return nil, err } htlcInfoBytes := a.Bytes() htlcIDBytes := make([]byte, 8) binary.BigEndian.PutUint64(htlcIDBytes, attempt.AttemptID) - return kvdb.Batch(p.db.Backend, func(tx kvdb.RwTx) error { + var payment *MPPayment + err = kvdb.Batch(p.db.Backend, func(tx kvdb.RwTx) error { bucket, err := fetchPaymentBucketUpdate(tx, paymentHash) if err != nil { return err } - payment, err := fetchPayment(bucket) + p, err := fetchPayment(bucket) if err != nil { return err } // Ensure the payment is in-flight. - if err := ensureInFlight(payment); err != nil { + if err := ensureInFlight(p); err != nil { return err } // We cannot register a new attempt if the payment already has // reached a terminal condition: - settle, fail := payment.TerminalInfo() + settle, fail := p.TerminalInfo() if settle != nil || fail != nil { return ErrPaymentTerminal } @@ -225,7 +226,7 @@ func (p *PaymentControl) RegisterAttempt(paymentHash lntypes.Hash, // Make sure any existing shards match the new one with regards // to MPP options. mpp := attempt.Route.FinalHop().MPP - for _, h := range payment.InFlightHTLCs() { + for _, h := range p.InFlightHTLCs() { hMpp := h.Route.FinalHop().MPP switch { @@ -258,13 +259,13 @@ func (p *PaymentControl) RegisterAttempt(paymentHash lntypes.Hash, // If this is a non-MPP attempt, it must match the total amount // exactly. amt := attempt.Route.ReceiverAmt() - if mpp == nil && amt != payment.Info.Value { + if mpp == nil && amt != p.Info.Value { return ErrValueMismatch } // Ensure we aren't sending more than the total payment amount. - sentAmt, _ := payment.SentAmt() - if sentAmt+amt > payment.Info.Value { + sentAmt, _ := p.SentAmt() + if sentAmt+amt > p.Info.Value { return ErrValueExceedsAmt } @@ -282,8 +283,20 @@ func (p *PaymentControl) RegisterAttempt(paymentHash lntypes.Hash, return err } - return htlcBucket.Put(htlcAttemptInfoKey, htlcInfoBytes) + err = htlcBucket.Put(htlcAttemptInfoKey, htlcInfoBytes) + if err != nil { + return err + } + + // Retrieve attempt info for the notification. + payment, err = fetchPayment(bucket) + return err }) + if err != nil { + return nil, err + } + + return payment, err } // SettleAttempt marks the given attempt settled with the preimage. If this is @@ -307,16 +320,15 @@ func (p *PaymentControl) SettleAttempt(hash lntypes.Hash, // FailAttempt marks the given payment attempt failed. func (p *PaymentControl) FailAttempt(hash lntypes.Hash, - attemptID uint64, failInfo *HTLCFailInfo) error { + attemptID uint64, failInfo *HTLCFailInfo) (*MPPayment, error) { var b bytes.Buffer if err := serializeHTLCFailInfo(&b, failInfo); err != nil { - return err + return nil, err } failBytes := b.Bytes() - _, err := p.updateHtlcKey(hash, attemptID, htlcFailInfoKey, failBytes) - return err + return p.updateHtlcKey(hash, attemptID, htlcFailInfoKey, failBytes) } // updateHtlcKey updates a database key for the specified htlc. diff --git a/channeldb/payment_control_test.go b/channeldb/payment_control_test.go index abd2722a..95862f5e 100644 --- a/channeldb/payment_control_test.go +++ b/channeldb/payment_control_test.go @@ -117,13 +117,13 @@ func TestPaymentControlSwitchFail(t *testing.T) { // Record a new attempt. In this test scenario, the attempt fails. // However, this is not communicated to control tower in the current // implementation. It only registers the initiation of the attempt. - err = pControl.RegisterAttempt(info.PaymentHash, attempt) + _, err = pControl.RegisterAttempt(info.PaymentHash, attempt) if err != nil { t.Fatalf("unable to register attempt: %v", err) } htlcReason := HTLCFailUnreadable - err = pControl.FailAttempt( + _, err = pControl.FailAttempt( info.PaymentHash, attempt.AttemptID, &HTLCFailInfo{ Reason: htlcReason, @@ -143,7 +143,7 @@ func TestPaymentControlSwitchFail(t *testing.T) { // Record another attempt. attempt.AttemptID = 1 - err = pControl.RegisterAttempt(info.PaymentHash, attempt) + _, err = pControl.RegisterAttempt(info.PaymentHash, attempt) if err != nil { t.Fatalf("unable to send htlc message: %v", err) } @@ -236,7 +236,7 @@ func TestPaymentControlSwitchDoubleSend(t *testing.T) { } // Record an attempt. - err = pControl.RegisterAttempt(info.PaymentHash, attempt) + _, err = pControl.RegisterAttempt(info.PaymentHash, attempt) if err != nil { t.Fatalf("unable to send htlc message: %v", err) } @@ -375,7 +375,7 @@ func TestPaymentControlDeleteNonInFligt(t *testing.T) { if err != nil { t.Fatalf("unable to send htlc message: %v", err) } - err = pControl.RegisterAttempt(info.PaymentHash, attempt) + _, err = pControl.RegisterAttempt(info.PaymentHash, attempt) if err != nil { t.Fatalf("unable to send htlc message: %v", err) } @@ -387,7 +387,7 @@ func TestPaymentControlDeleteNonInFligt(t *testing.T) { if p.failed { // Fail the payment attempt. htlcFailure := HTLCFailUnreadable - err := pControl.FailAttempt( + _, err := pControl.FailAttempt( info.PaymentHash, attempt.AttemptID, &HTLCFailInfo{ Reason: htlcFailure, @@ -520,7 +520,7 @@ func TestPaymentControlMultiShard(t *testing.T) { a.AttemptID = i attempts = append(attempts, &a) - err = pControl.RegisterAttempt(info.PaymentHash, &a) + _, err = pControl.RegisterAttempt(info.PaymentHash, &a) if err != nil { t.Fatalf("unable to send htlc message: %v", err) } @@ -541,7 +541,7 @@ func TestPaymentControlMultiShard(t *testing.T) { // will be too large. b := *attempt b.AttemptID = 3 - err = pControl.RegisterAttempt(info.PaymentHash, &b) + _, err = pControl.RegisterAttempt(info.PaymentHash, &b) if err != ErrValueExceedsAmt { t.Fatalf("expected ErrValueExceedsAmt, got: %v", err) @@ -550,7 +550,7 @@ func TestPaymentControlMultiShard(t *testing.T) { // Fail the second attempt. a := attempts[1] htlcFail := HTLCFailUnreadable - err = pControl.FailAttempt( + _, err = pControl.FailAttempt( info.PaymentHash, a.AttemptID, &HTLCFailInfo{ Reason: htlcFail, @@ -596,7 +596,7 @@ func TestPaymentControlMultiShard(t *testing.T) { t, pControl, info.PaymentHash, info, nil, htlc, ) } else { - err := pControl.FailAttempt( + _, err := pControl.FailAttempt( info.PaymentHash, a.AttemptID, &HTLCFailInfo{ Reason: htlcFail, @@ -634,7 +634,7 @@ func TestPaymentControlMultiShard(t *testing.T) { // that the payment has reached a terminal condition. b = *attempt b.AttemptID = 3 - err = pControl.RegisterAttempt(info.PaymentHash, &b) + _, err = pControl.RegisterAttempt(info.PaymentHash, &b) if err != ErrPaymentTerminal { t.Fatalf("expected ErrPaymentTerminal, got: %v", err) } @@ -666,7 +666,7 @@ func TestPaymentControlMultiShard(t *testing.T) { ) } else { // Fail the attempt. - err := pControl.FailAttempt( + _, err := pControl.FailAttempt( info.PaymentHash, a.AttemptID, &HTLCFailInfo{ Reason: htlcFail, @@ -708,7 +708,7 @@ func TestPaymentControlMultiShard(t *testing.T) { assertPaymentStatus(t, pControl, info.PaymentHash, finalStatus) // Finally assert we cannot register more attempts. - err = pControl.RegisterAttempt(info.PaymentHash, &b) + _, err = pControl.RegisterAttempt(info.PaymentHash, &b) if err != expRegErr { t.Fatalf("expected error %v, got: %v", expRegErr, err) } @@ -756,7 +756,7 @@ func TestPaymentControlMPPRecordValidation(t *testing.T) { info.Value, [32]byte{1}, ) - err = pControl.RegisterAttempt(info.PaymentHash, attempt) + _, err = pControl.RegisterAttempt(info.PaymentHash, attempt) if err != nil { t.Fatalf("unable to send htlc message: %v", err) } @@ -765,7 +765,7 @@ func TestPaymentControlMPPRecordValidation(t *testing.T) { b := *attempt b.AttemptID = 1 b.Route.FinalHop().MPP = nil - err = pControl.RegisterAttempt(info.PaymentHash, &b) + _, err = pControl.RegisterAttempt(info.PaymentHash, &b) if err != ErrMPPayment { t.Fatalf("expected ErrMPPayment, got: %v", err) } @@ -774,7 +774,7 @@ func TestPaymentControlMPPRecordValidation(t *testing.T) { b.Route.FinalHop().MPP = record.NewMPP( info.Value, [32]byte{2}, ) - err = pControl.RegisterAttempt(info.PaymentHash, &b) + _, err = pControl.RegisterAttempt(info.PaymentHash, &b) if err != ErrMPPPaymentAddrMismatch { t.Fatalf("expected ErrMPPPaymentAddrMismatch, got: %v", err) } @@ -783,7 +783,7 @@ func TestPaymentControlMPPRecordValidation(t *testing.T) { b.Route.FinalHop().MPP = record.NewMPP( info.Value/2, [32]byte{1}, ) - err = pControl.RegisterAttempt(info.PaymentHash, &b) + _, err = pControl.RegisterAttempt(info.PaymentHash, &b) if err != ErrMPPTotalAmountMismatch { t.Fatalf("expected ErrMPPTotalAmountMismatch, got: %v", err) } @@ -801,7 +801,7 @@ func TestPaymentControlMPPRecordValidation(t *testing.T) { } attempt.Route.FinalHop().MPP = nil - err = pControl.RegisterAttempt(info.PaymentHash, attempt) + _, err = pControl.RegisterAttempt(info.PaymentHash, attempt) if err != nil { t.Fatalf("unable to send htlc message: %v", err) } @@ -813,7 +813,7 @@ func TestPaymentControlMPPRecordValidation(t *testing.T) { info.Value, [32]byte{1}, ) - err = pControl.RegisterAttempt(info.PaymentHash, &b) + _, err = pControl.RegisterAttempt(info.PaymentHash, &b) if err != ErrNonMPPayment { t.Fatalf("expected ErrNonMPPayment, got: %v", err) } diff --git a/routing/control_tower.go b/routing/control_tower.go index da7d4d2b..5702eb92 100644 --- a/routing/control_tower.go +++ b/routing/control_tower.go @@ -107,7 +107,8 @@ func (p *controlTower) InitPayment(paymentHash lntypes.Hash, func (p *controlTower) RegisterAttempt(paymentHash lntypes.Hash, attempt *channeldb.HTLCAttemptInfo) error { - return p.db.RegisterAttempt(paymentHash, attempt) + _, err := p.db.RegisterAttempt(paymentHash, attempt) + return err } // SettleAttempt marks the given attempt settled with the preimage. If @@ -133,7 +134,8 @@ func (p *controlTower) SettleAttempt(paymentHash lntypes.Hash, func (p *controlTower) FailAttempt(paymentHash lntypes.Hash, attemptID uint64, failInfo *channeldb.HTLCFailInfo) error { - return p.db.FailAttempt(paymentHash, attemptID, failInfo) + _, err := p.db.FailAttempt(paymentHash, attemptID, failInfo) + return err } // FetchPayment fetches the payment corresponding to the given payment hash.