From 62a9c2c3acc334172fcef1ad063f709a07ea8547 Mon Sep 17 00:00:00 2001 From: Joost Jager Date: Tue, 10 Sep 2019 15:34:02 +0200 Subject: [PATCH] channeldb/test: make route comparison a pure function Previously the route to compare was modified in order for DeepEqual to function properly. This created problems when tests were ran in parallel. --- channeldb/payment_control_test.go | 62 ++----------------------------- channeldb/payments_test.go | 58 +++++++++++++++++++---------- 2 files changed, 42 insertions(+), 78 deletions(-) diff --git a/channeldb/payment_control_test.go b/channeldb/payment_control_test.go index b51c9bdf..479e6898 100644 --- a/channeldb/payment_control_test.go +++ b/channeldb/payment_control_test.go @@ -15,7 +15,6 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/routing/route" - "github.com/lightningnetwork/lnd/tlv" ) func initDB() (*DB, error) { @@ -139,36 +138,10 @@ func TestPaymentControlSwitchFail(t *testing.T) { t.Fatalf("error shouldn't have been received, got: %v", err) } - err = assertRouteHopRecordsEqual(route, &attempt.Route) + err = assertRouteEqual(route, &attempt.Route) if err != nil { - t.Fatalf("route tlv records not equal: %v", err) - } - - for i := 0; i < len(route.Hops); i++ { - for j := 0; j < len(route.Hops[i].TLVRecords); j++ { - expectedRecord := route.Hops[i].TLVRecords[j] - newRecord := attempt.Route.Hops[i].TLVRecords[j] - - err := assertHopRecordsEqual(expectedRecord, newRecord) - if err != nil { - t.Fatalf("route record mismatch: %v", err) - } - } - } - - for i := 0; i < len(route.Hops); i++ { - // reflect.DeepEqual can't assert that two function closures - // are equal. The underlying tlv.Record uses function closures - // internally, so after we verify that the records match above - // manually, we unset these so we can use reflect.DeepEqual - // below. - route.Hops[i].TLVRecords = nil - attempt.Route.Hops[i].TLVRecords = nil - } - - if !reflect.DeepEqual(*route, attempt.Route) { - t.Fatalf("unexpected route returned: %v vs %v", - spew.Sdump(attempt.Route), spew.Sdump(*route)) + t.Fatalf("unexpected route returned: %v vs %v: %v", + spew.Sdump(attempt.Route), spew.Sdump(*route), err) } assertPaymentStatus(t, db, info.PaymentHash, StatusSucceeded) @@ -484,34 +457,7 @@ func checkPaymentAttemptInfo(bucket *bbolt.Bucket, a *PaymentAttemptInfo) error return err } - err = assertRouteHopRecordsEqual(&a.Route, &a2.Route) - if err != nil { - return err - } - - recordCache := make(map[int][]tlv.Record) - for i := 0; i < len(a.Route.Hops); i++ { - recordCache[i] = a.Route.Hops[i].TLVRecords - - // reflect.DeepEqual can't assert that two function closures - // are equal. The underlying tlv.Record uses function closures - // internally, so after we verify that the records match above - // manually, we unset these so we can use reflect.DeepEqual - // below. - a.Route.Hops[i].TLVRecords = nil - a2.Route.Hops[i].TLVRecords = nil - } - - if !reflect.DeepEqual(a, a2) { - return fmt.Errorf("PaymentAttemptInfos don't match: %v vs %v", - spew.Sdump(a), spew.Sdump(a2)) - } - - for index, records := range recordCache { - a.Route.Hops[index].TLVRecords = records - } - - return nil + return assertRouteEqual(&a.Route, &a2.Route) } func checkSettleInfo(bucket *bbolt.Bucket, preimg lntypes.Preimage) error { diff --git a/channeldb/payments_test.go b/channeldb/payments_test.go index 109d0adf..a979a956 100644 --- a/channeldb/payments_test.go +++ b/channeldb/payments_test.go @@ -2,6 +2,7 @@ package channeldb import ( "bytes" + "errors" "fmt" "math/rand" "reflect" @@ -207,14 +208,15 @@ func TestSentPaymentSerialization(t *testing.T) { // First we verify all the records match up porperly, as they aren't // able to be properly compared using reflect.DeepEqual. - assertRouteHopRecordsEqual(&s.Route, &newAttemptInfo.Route) + err = assertRouteEqual(&s.Route, &newAttemptInfo.Route) + if err != nil { + t.Fatalf("Routes do not match after "+ + "serialization/deserialization: %v", err) + } - // With the hop recrods, equal, we'll now blank them out as - // reflect.DeepEqual can't properly compare tlv.Record instances. - newAttemptInfo.Route.Hops[0].TLVRecords = nil - newAttemptInfo.Route.Hops[1].TLVRecords = nil - s.Route.Hops[0].TLVRecords = nil - s.Route.Hops[1].TLVRecords = nil + // Clear routes to allow DeepEqual to compare the remaining fields. + newAttemptInfo.Route = route.Route{} + s.Route = route.Route{} if !reflect.DeepEqual(s, newAttemptInfo) { s.SessionKey.Curve = nil @@ -224,7 +226,35 @@ func TestSentPaymentSerialization(t *testing.T) { spew.Sdump(s), spew.Sdump(newAttemptInfo), ) } +} +// assertRouteEquals compares to routes for equality and returns an error if +// they are not equal. +func assertRouteEqual(a, b *route.Route) error { + err := assertRouteHopRecordsEqual(a, b) + if err != nil { + return err + } + + // TLV records have already been compared and need to be cleared to + // properly compare the remaining fields using DeepEqual. + copyRouteNoHops := func(r *route.Route) *route.Route { + copy := *r + copy.Hops = make([]*route.Hop, len(r.Hops)) + for i, hop := range r.Hops { + hopCopy := *hop + hopCopy.TLVRecords = nil + copy.Hops[i] = &hopCopy + } + return © + } + + if !reflect.DeepEqual(copyRouteNoHops(a), copyRouteNoHops(b)) { + return fmt.Errorf("PaymentAttemptInfos don't match: %v vs %v", + spew.Sdump(a), spew.Sdump(b)) + } + + return nil } func assertRouteHopRecordsEqual(r1, r2 *route.Route) error { @@ -294,20 +324,8 @@ func TestRouteSerialization(t *testing.T) { // First we verify all the records match up porperly, as they aren't // able to be properly compared using reflect.DeepEqual. - err = assertRouteHopRecordsEqual(&testRoute, &route2) + err = assertRouteEqual(&testRoute, &route2) if err != nil { - t.Fatalf("route tlv records don't match: %v", err) - } - - // Now that we know the records match up, we'll examine the remainder - // of the route without the TLV records attached as reflect.DeepEqual - // can't properly assert their equality. - testRoute.Hops[0].TLVRecords = nil - testRoute.Hops[1].TLVRecords = nil - route2.Hops[0].TLVRecords = nil - route2.Hops[1].TLVRecords = nil - - if !reflect.DeepEqual(testRoute, route2) { t.Fatalf("routes not equal: \n%v vs \n%v", spew.Sdump(testRoute), spew.Sdump(route2)) }