diff --git a/lnd_test.go b/lnd_test.go index eed2bc56..73cf7f73 100644 --- a/lnd_test.go +++ b/lnd_test.go @@ -576,10 +576,20 @@ func calcStaticFee(numHTLCs int) btcutil.Amount { func completePaymentRequests(ctx context.Context, client lnrpc.LightningClient, paymentRequests []string, awaitResponse bool) error { - ctx, cancel := context.WithCancel(ctx) + // We start by getting the current state of the client's channels. This + // is needed to ensure the payments actually have been committed before + // we return. + ctxt, _ := context.WithTimeout(ctx, defaultTimeout) + req := &lnrpc.ListChannelsRequest{} + listResp, err := client.ListChannels(ctxt, req) + if err != nil { + return err + } + + ctxc, cancel := context.WithCancel(ctx) defer cancel() - payStream, err := client.SendPayment(ctx) + payStream, err := client.SendPayment(ctxc) if err != nil { return err } @@ -605,11 +615,40 @@ func completePaymentRequests(ctx context.Context, client lnrpc.LightningClient, resp.PaymentError) } } - } else { - // We are not waiting for feedback in the form of a response, but we - // should still wait long enough for the server to receive and handle - // the send before cancelling the request. - time.Sleep(200 * time.Millisecond) + + return nil + } + + // We are not waiting for feedback in the form of a response, but we + // should still wait long enough for the server to receive and handle + // the send before cancelling the request. We wait for the number of + // updates to one of our channels has increased before we return. + err = lntest.WaitPredicate(func() bool { + ctxt, _ = context.WithTimeout(ctx, defaultTimeout) + newListResp, err := client.ListChannels(ctxt, req) + if err != nil { + return false + } + + for _, c1 := range listResp.Channels { + for _, c2 := range newListResp.Channels { + if c1.ChannelPoint != c2.ChannelPoint { + continue + } + + // If this channel has an increased numbr of + // updates, we assume the payments are + // committed, and we can return. + if c2.NumUpdates > c1.NumUpdates { + return true + } + } + } + + return false + }, time.Second*15) + if err != nil { + return err } return nil