diff --git a/breacharbiter_test.go b/breacharbiter_test.go index 358a826a..422e6be7 100644 --- a/breacharbiter_test.go +++ b/breacharbiter_test.go @@ -1604,7 +1604,12 @@ func testBreachSpends(t *testing.T, test breachTest) { publMtx.Lock() err := publErr publMtx.Unlock() - publTx <- tx + + select { + case publTx <- tx: + case <-brar.quit: + return fmt.Errorf("brar quit") + } return err } @@ -1817,7 +1822,11 @@ func TestBreachDelayedJusticeConfirmation(t *testing.T) { // Make PublishTransaction always return succeed. brar.cfg.PublishTransaction = func(tx *wire.MsgTx, _ string) error { - publTx <- tx + select { + case publTx <- tx: + case <-brar.quit: + return fmt.Errorf("brar quit") + } return nil } @@ -1969,13 +1978,34 @@ func TestBreachDelayedJusticeConfirmation(t *testing.T) { require.Len(t, spending, len(justiceTx.TxIn)) require.Len(t, splits, 2) - // Finally notify that they confirm, making the breach arbiter clean - // up. - for _, tx := range splits { - for _, in := range tx.TxIn { - op := &in.PreviousOutPoint - notifier.Spend(op, blockHeight+5, tx) + // Notify that the first split confirm, making the breach arbiter + // publish another TX with the remaining inputs. + for _, in := range splits[0].TxIn { + op := &in.PreviousOutPoint + notifier.Spend(op, blockHeight+5, splits[0]) + } + + select { + + // The published tx should spend the same inputs as our second split. + case tx := <-publTx: + require.Len(t, tx.TxIn, len(splits[1].TxIn)) + for i := range tx.TxIn { + require.Equal( + t, tx.TxIn[i].PreviousOutPoint, + splits[1].TxIn[i].PreviousOutPoint, + ) } + + case <-time.After(5 * time.Second): + t.Fatalf("tx not published") + } + + // Finally notify that the second split confirms, making the breach + // arbiter clean up since all inputs have been swept. + for _, in := range splits[1].TxIn { + op := &in.PreviousOutPoint + notifier.Spend(op, blockHeight+6, splits[1]) } // Assert that the channel is fully resolved. @@ -2080,7 +2110,7 @@ func assertBrarCleanup(t *testing.T, brar *breachArbiter, return fmt.Errorf("channel %v not closed", chanPoint) - }, time.Second) + }, 5*time.Second) if err != nil { t.Fatalf(err.Error()) }