channeldb: update route.Hop serialization to include new EOB related fields

We also include a migration for the existing routes stored on disk.
Olaoluwa Osuntokun 2019-07-30 21:44:50 -07:00
parent 763cb6c09d
commit c78e3aaa9d
6 changed files with 677 additions and 6 deletions
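For orientation, here is a rough sketch of the per-hop on-disk layout after this change, as implied by the serializeHop and deserializeHop hunks further down; the field names below are descriptive only, not taken from the code.

// Per-hop encoding after this commit (descriptive sketch, not a spec):
//
//	pubkey bytes           (unchanged)
//	channel ID             (unchanged)
//	outgoing time lock     (unchanged)
//	amount to forward      (unchanged)
//	legacy payload flag    (bool, new)
//	number of TLV records  (uint32, new; always 0 for legacy payloads)
//	for each record:
//	    record type        (uint64)
//	    record value       (variable-length bytes)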

@@ -103,6 +103,13 @@ var (
number: 9,
migration: migrateOutgoingPayments,
},
{
// The DB version where we started to store legacy
// payload information for all routes, as well as the
// optional TLV records.
number: 10,
migration: migrateRouteSerialization,
},
}
// Big endian is the preferred byte order, due to cursor scans over
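The new entry registers the migration under database version 10, alongside the existing version 9 entry. As a rough, hypothetical sketch of how such version-gated migrations are applied on open (the slice name dbVersions and the getDBVersion/putDBVersion helpers are assumptions, not code from this diff; only the number and migration fields and the func(*bbolt.Tx) error signature come from the hunk above):

// applyMigrations is a hypothetical sketch: run every registered migration
// whose number exceeds the version currently stored on disk, in order,
// inside a single bbolt transaction.
func applyMigrations(db *bbolt.DB) error {
	return db.Update(func(tx *bbolt.Tx) error {
		onDisk, err := getDBVersion(tx) // assumed helper
		if err != nil {
			return err
		}
		for _, v := range dbVersions { // assumed slice name
			if v.number <= onDisk {
				continue
			}
			if err := v.migration(tx); err != nil {
				return err
			}
			if err := putDBVersion(tx, v.number); err != nil { // assumed helper
				return err
			}
		}
		return nil
	})
}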

@@ -0,0 +1,236 @@
package channeldb
import (
"bytes"
"io"
"github.com/coreos/bbolt"
"github.com/lightningnetwork/lnd/routing/route"
)
// migrateRouteSerialization migrates the way we serialize routes across the
// entire database. At the time of writing of this migration, this includes our
// payment attempts, as well as the payment results in mission control.
func migrateRouteSerialization(tx *bbolt.Tx) error {
// First, we'll do all the payment attempts.
rootPaymentBucket := tx.Bucket(paymentsRootBucket)
if rootPaymentBucket == nil {
return nil
}
// As we can't mutate a bucket while we're iterating over it with
// ForEach, we'll need to collect all the known payment hashes in
// memory first.
var payHashes [][]byte
err := rootPaymentBucket.ForEach(func(k, v []byte) error {
if v != nil {
return nil
}
payHashes = append(payHashes, k)
return nil
})
if err != nil {
return err
}
// Now that we have all the payment hashes, we can carry out the
// migration itself.
for _, payHash := range payHashes {
payHashBucket := rootPaymentBucket.Bucket(payHash)
// First, we'll migrate the main (non duplicate) payment to
// this hash.
err := migrateAttemptEncoding(tx, payHashBucket)
if err != nil {
return err
}
// Now that we've migrated the main payment, we'll also check
// for any duplicate payments to the same payment hash.
dupBucket := payHashBucket.Bucket(paymentDuplicateBucket)
// If there's no dup bucket, then we can move on to the next
// payment.
if dupBucket == nil {
continue
}
// Otherwise, we'll now iterate through all the duplicate pay
// hashes and migrate those.
var dupSeqNos [][]byte
err = dupBucket.ForEach(func(k, v []byte) error {
dupSeqNos = append(dupSeqNos, k)
return nil
})
if err != nil {
return err
}
// Now in this second pass, we'll re-serialize their duplicate
// payment attempts under the new encoding.
for _, seqNo := range dupSeqNos {
dupPayHashBucket := dupBucket.Bucket(seqNo)
err := migrateAttemptEncoding(tx, dupPayHashBucket)
if err != nil {
return err
}
}
}
log.Infof("Migration of route/hop serialization complete!")
log.Infof("Migrating to new mission control store by clearing " +
"existing data")
resultsKey := []byte("missioncontrol-results")
err = tx.DeleteBucket(resultsKey)
if err != nil && err != bbolt.ErrBucketNotFound {
return err
}
log.Infof("Migration to new mission control completed!")
return nil
}
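bbolt does not allow mutating a bucket while ForEach is iterating over it, which is why the migration above collects all payment hashes into memory first and only then walks each bucket. A minimal, self-contained illustration of the same collect-then-mutate pattern (the helper name and the prefix-delete behavior are just for this example):

// deleteKeysWithPrefix shows the collect-then-mutate pattern in isolation:
// gather the matching keys first, then issue the deletes once iteration is
// done. Illustrative only, not part of this commit.
func deleteKeysWithPrefix(tx *bbolt.Tx, bucketName, prefix []byte) error {
	bucket := tx.Bucket(bucketName)
	if bucket == nil {
		return nil
	}
	var matches [][]byte
	err := bucket.ForEach(func(k, v []byte) error {
		// Nested buckets have a nil value; skip them.
		if v == nil {
			return nil
		}
		if bytes.HasPrefix(k, prefix) {
			matches = append(matches, k)
		}
		return nil
	})
	if err != nil {
		return err
	}
	for _, k := range matches {
		if err := bucket.Delete(k); err != nil {
			return err
		}
	}
	return nil
}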
// migrateAttemptEncoding migrates payment attempts using the legacy format to
// the new format.
func migrateAttemptEncoding(tx *bbolt.Tx, payHashBucket *bbolt.Bucket) error {
payAttemptBytes := payHashBucket.Get(paymentAttemptInfoKey)
if payAttemptBytes == nil {
return nil
}
// For our migration, we'll first read out the existing payment attempt
// using the legacy serialization of the attempt.
payAttemptReader := bytes.NewReader(payAttemptBytes)
payAttempt, err := deserializePaymentAttemptInfoLegacy(
payAttemptReader,
)
if err != nil {
return err
}
// Now that we have the old attempts, we'll explicitly mark this as
// needing a legacy payload, since after this migration, the modern
// payload will be the default if signalled.
for _, hop := range payAttempt.Route.Hops {
hop.LegacyPayload = true
}
// Finally, we'll write out the payment attempt using the new encoding.
var b bytes.Buffer
err = serializePaymentAttemptInfo(&b, payAttempt)
if err != nil {
return err
}
return payHashBucket.Put(paymentAttemptInfoKey, b.Bytes())
}
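As a usage sketch (not part of this commit), re-encoding the attempt for a single known payment hash would look roughly like the snippet below; db and payHash are assumed to already be in scope, and missing buckets are simply skipped.

// Illustrative only: re-encode one payment's attempt under the new format.
err := db.Update(func(tx *bbolt.Tx) error {
	payments := tx.Bucket(paymentsRootBucket)
	if payments == nil {
		return nil
	}
	payHashBucket := payments.Bucket(payHash[:])
	if payHashBucket == nil {
		return nil
	}
	return migrateAttemptEncoding(tx, payHashBucket)
})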
func deserializePaymentAttemptInfoLegacy(r io.Reader) (*PaymentAttemptInfo, error) {
a := &PaymentAttemptInfo{}
err := ReadElements(r, &a.PaymentID, &a.SessionKey)
if err != nil {
return nil, err
}
a.Route, err = deserializeRouteLegacy(r)
if err != nil {
return nil, err
}
return a, nil
}
func serializePaymentAttemptInfoLegacy(w io.Writer, a *PaymentAttemptInfo) error {
if err := WriteElements(w, a.PaymentID, a.SessionKey); err != nil {
return err
}
if err := serializeRouteLegacy(w, a.Route); err != nil {
return err
}
return nil
}
func deserializeHopLegacy(r io.Reader) (*route.Hop, error) {
h := &route.Hop{}
var pub []byte
if err := ReadElements(r, &pub); err != nil {
return nil, err
}
copy(h.PubKeyBytes[:], pub)
if err := ReadElements(r,
&h.ChannelID, &h.OutgoingTimeLock, &h.AmtToForward,
); err != nil {
return nil, err
}
return h, nil
}
func serializeHopLegacy(w io.Writer, h *route.Hop) error {
if err := WriteElements(w,
h.PubKeyBytes[:], h.ChannelID, h.OutgoingTimeLock,
h.AmtToForward,
); err != nil {
return err
}
return nil
}
func deserializeRouteLegacy(r io.Reader) (route.Route, error) {
rt := route.Route{}
if err := ReadElements(r,
&rt.TotalTimeLock, &rt.TotalAmount,
); err != nil {
return rt, err
}
var pub []byte
if err := ReadElements(r, &pub); err != nil {
return rt, err
}
copy(rt.SourcePubKey[:], pub)
var numHops uint32
if err := ReadElements(r, &numHops); err != nil {
return rt, err
}
var hops []*route.Hop
for i := uint32(0); i < numHops; i++ {
hop, err := deserializeHopLegacy(r)
if err != nil {
return rt, err
}
hops = append(hops, hop)
}
rt.Hops = hops
return rt, nil
}
func serializeRouteLegacy(w io.Writer, r route.Route) error {
if err := WriteElements(w,
r.TotalTimeLock, r.TotalAmount, r.SourcePubKey[:],
); err != nil {
return err
}
if err := WriteElements(w, uint32(len(r.Hops))); err != nil {
return err
}
for _, h := range r.Hops {
if err := serializeHopLegacy(w, h); err != nil {
return err
}
}
return nil
}

@@ -5,14 +5,18 @@ import (
"crypto/sha256"
"encoding/binary"
"fmt"
"math/rand"
"reflect"
"testing"
"time"
"github.com/btcsuite/btcutil"
"github.com/coreos/bbolt"
"github.com/davecgh/go-spew/spew"
"github.com/go-errors/errors"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
)
// TestPaymentStatusesMigration checks that already completed payments will have
@@ -723,3 +727,223 @@ func TestOutgoingPaymentsMigration(t *testing.T) {
migrateOutgoingPayments,
false)
}
func makeRandPaymentCreationInfo() (*PaymentCreationInfo, error) {
var payHash lntypes.Hash
if _, err := rand.Read(payHash[:]); err != nil {
return nil, err
}
return &PaymentCreationInfo{
PaymentHash: payHash,
Value: lnwire.MilliSatoshi(rand.Int63()),
CreationDate: time.Now(),
PaymentRequest: []byte("test"),
}, nil
}
// TestPaymentRouteSerialization tests that we're able to properly migrate
// existing payments on disk that contain the traversed routes to the new
// routing format which supports the TLV payloads. We also test that the
// migration is able to handle duplicate payment attempts.
func TestPaymentRouteSerialization(t *testing.T) {
t.Parallel()
legacyHop1 := &route.Hop{
PubKeyBytes: route.NewVertex(pub),
ChannelID: 12345,
OutgoingTimeLock: 111,
LegacyPayload: true,
AmtToForward: 555,
}
legacyHop2 := &route.Hop{
PubKeyBytes: route.NewVertex(pub),
ChannelID: 12345,
OutgoingTimeLock: 111,
LegacyPayload: true,
AmtToForward: 555,
}
legacyRoute := route.Route{
TotalTimeLock: 123,
TotalAmount: 1234567,
SourcePubKey: route.NewVertex(pub),
Hops: []*route.Hop{legacyHop1, legacyHop2},
}
const numPayments = 4
var oldPayments []*Payment
sharedPayAttempt := PaymentAttemptInfo{
PaymentID: 1,
SessionKey: priv,
Route: legacyRoute,
}
// We'll first add a series of fake payments, using the existing legacy
// serialization format.
beforeMigrationFunc := func(d *DB) {
err := d.Update(func(tx *bbolt.Tx) error {
paymentsBucket, err := tx.CreateBucket(
paymentsRootBucket,
)
if err != nil {
t.Fatalf("unable to create new payments "+
"bucket: %v", err)
}
for i := 0; i < numPayments; i++ {
var seqNum [8]byte
byteOrder.PutUint64(seqNum[:], uint64(i))
// All payments will be randomly generated,
// other than the final payment. We'll force
// the final payment to re-use an existing
// payment hash so we can insert it into the
// duplicate payment hash bucket.
var payInfo *PaymentCreationInfo
if i < numPayments-1 {
payInfo, err = makeRandPaymentCreationInfo()
if err != nil {
t.Fatalf("unable to create "+
"payment: %v", err)
}
} else {
payInfo = oldPayments[0].Info
}
// Next, we'll serialize the creation info and the payment attempt,
// using the legacy encoding for the attempt.
var payInfoBytes bytes.Buffer
err = serializePaymentCreationInfo(
&payInfoBytes, payInfo,
)
if err != nil {
t.Fatalf("unable to encode pay "+
"info: %v", err)
}
var payAttemptBytes bytes.Buffer
err = serializePaymentAttemptInfoLegacy(
&payAttemptBytes, &sharedPayAttempt,
)
if err != nil {
t.Fatalf("unable to encode payment attempt: "+
"%v", err)
}
// Before we write to disk, we'll need to fetch
// the proper bucket. If this is the duplicate
// payment, then we'll grab the dup bucket,
// otherwise, we'll use the top level bucket.
var payHashBucket *bbolt.Bucket
if i < numPayments-1 {
payHashBucket, err = paymentsBucket.CreateBucket(
payInfo.PaymentHash[:],
)
if err != nil {
t.Fatalf("unable to create payment hash "+
"bucket: %v", err)
}
} else {
payHashBucket = paymentsBucket.Bucket(
payInfo.PaymentHash[:],
)
dupPayBucket, err := payHashBucket.CreateBucket(
paymentDuplicateBucket,
)
if err != nil {
t.Fatalf("unable to create "+
"dup hash bucket: %v", err)
}
payHashBucket, err = dupPayBucket.CreateBucket(
seqNum[:],
)
if err != nil {
t.Fatalf("unable to make dup "+
"bucket: %v", err)
}
}
err = payHashBucket.Put(paymentSequenceKey, seqNum[:])
if err != nil {
t.Fatalf("unable to write seqno: %v", err)
}
err = payHashBucket.Put(
paymentCreationInfoKey, payInfoBytes.Bytes(),
)
if err != nil {
t.Fatalf("unable to write creation "+
"info: %v", err)
}
err = payHashBucket.Put(
paymentAttemptInfoKey, payAttemptBytes.Bytes(),
)
if err != nil {
t.Fatalf("unable to write attempt "+
"info: %v", err)
}
oldPayments = append(oldPayments, &Payment{
Info: payInfo,
Attempt: &sharedPayAttempt,
})
}
return nil
})
if err != nil {
t.Fatalf("unable to create test payments: %v", err)
}
}
afterMigrationFunc := func(d *DB) {
newPayments, err := d.FetchPayments()
if err != nil {
t.Fatalf("unable to fetch new payments: %v", err)
}
if len(newPayments) != numPayments {
t.Fatalf("expected %d payments, got %d", numPayments,
len(newPayments))
}
for i, p := range newPayments {
// The order of payments should be preserved.
old := oldPayments[i]
if p.Attempt.PaymentID != old.Attempt.PaymentID {
t.Fatalf("wrong pay ID: expected %v, got %v",
p.Attempt.PaymentID,
old.Attempt.PaymentID)
}
if p.Attempt.Route.TotalFees() != old.Attempt.Route.TotalFees() {
t.Fatalf("Fee mismatch")
}
if p.Attempt.Route.TotalAmount != old.Attempt.Route.TotalAmount {
t.Fatalf("Total amount mismatch")
}
if p.Attempt.Route.TotalTimeLock != old.Attempt.Route.TotalTimeLock {
t.Fatalf("timelock mismatch")
}
if p.Attempt.Route.SourcePubKey != old.Attempt.Route.SourcePubKey {
t.Fatalf("source mismatch: %x vs %x",
p.Attempt.Route.SourcePubKey[:],
old.Attempt.Route.SourcePubKey[:])
}
for i, hop := range p.Attempt.Route.Hops {
if !reflect.DeepEqual(hop, legacyRoute.Hops[i]) {
t.Fatalf("hop mismatch")
}
}
}
}
applyMigration(t,
beforeMigrationFunc,
afterMigrationFunc,
migrateRouteSerialization,
false)
}

@@ -15,6 +15,7 @@ import (
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/routing/route"
"github.com/lightningnetwork/lnd/tlv"
)
func initDB() (*DB, error) {
@@ -137,8 +138,37 @@ func TestPaymentControlSwitchFail(t *testing.T) {
if err != nil {
t.Fatalf("error shouldn't have been received, got: %v", err)
}
err = assertRouteHopRecordsEqual(route, &attempt.Route)
if err != nil {
t.Fatalf("route tlv records not equal: %v", err)
}
for i := 0; i < len(route.Hops); i++ {
for j := 0; j < len(route.Hops[i].TLVRecords); j++ {
expectedRecord := route.Hops[i].TLVRecords[j]
newRecord := attempt.Route.Hops[i].TLVRecords[j]
err := assertHopRecordsEqual(expectedRecord, newRecord)
if err != nil {
t.Fatalf("route record mismatch: %v", err)
}
}
}
for i := 0; i < len(route.Hops); i++ {
// reflect.DeepEqual can't assert that two function closures
// are equal. The underlying tlv.Record uses function closures
// internally, so after we verify that the records match above
// manually, we unset these so we can use reflect.DeepEqual
// below.
route.Hops[i].TLVRecords = nil
attempt.Route.Hops[i].TLVRecords = nil
}
if !reflect.DeepEqual(*route, attempt.Route) {
t.Fatalf("unexpected route returned")
t.Fatalf("unexpected route returned: %v vs %v",
spew.Sdump(attempt.Route), spew.Sdump(*route))
}
assertPaymentStatus(t, db, info.PaymentHash, StatusSucceeded)
@@ -427,7 +457,6 @@ func checkPaymentCreationInfo(bucket *bbolt.Bucket, c *PaymentCreationInfo) error
r := bytes.NewReader(b)
c2, err := deserializePaymentCreationInfo(r)
if err != nil {
return err
}
if !reflect.DeepEqual(c, c2) {
@@ -454,11 +483,34 @@ func checkPaymentAttemptInfo(bucket *bbolt.Bucket, a *PaymentAttemptInfo) error
if err != nil {
return err
}
err = assertRouteHopRecordsEqual(&a.Route, &a2.Route)
if err != nil {
return err
}
recordCache := make(map[int][]tlv.Record)
for i := 0; i < len(a.Route.Hops); i++ {
recordCache[i] = a.Route.Hops[i].TLVRecords
// reflect.DeepEqual can't assert that two function closures
// are equal. The underlying tlv.Record uses function closures
// internally, so after we verify that the records match above
// manually, we unset these so we can use reflect.DeepEqual
// below.
a.Route.Hops[i].TLVRecords = nil
a2.Route.Hops[i].TLVRecords = nil
}
if !reflect.DeepEqual(a, a2) {
return fmt.Errorf("PaymentAttemptInfos don't match: %v vs %v",
spew.Sdump(a), spew.Sdump(a2))
}
for index, records := range recordCache {
a.Route.Hops[index].TLVRecords = records
}
return nil
}

@@ -10,10 +10,12 @@ import (
"time"
"github.com/btcsuite/btcd/btcec"
"github.com/btcsuite/btcd/wire"
"github.com/coreos/bbolt"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
"github.com/lightningnetwork/lnd/tlv"
)
var (
@@ -512,9 +514,47 @@ func serializeHop(w io.Writer, h *route.Hop) error {
return err
}
if err := binary.Write(w, byteOrder, h.LegacyPayload); err != nil {
return err
}
// For legacy payloads, we don't need to write any TLV records, so
// we'll write a zero indicating that our serialized TLV map has no
// records.
if h.LegacyPayload {
return WriteElements(w, uint32(0))
}
// Otherwise, we'll transform our slice of records into a map of the
// raw bytes, then serialize them in-line with a length (number of
// elements) prefix.
mapRecords, err := tlv.RecordsToMap(h.TLVRecords)
if err != nil {
return err
}
numRecords := uint32(len(mapRecords))
if err := WriteElements(w, numRecords); err != nil {
return err
}
for recordType, rawBytes := range mapRecords {
if err := WriteElements(w, recordType); err != nil {
return err
}
if err := wire.WriteVarBytes(w, 0, rawBytes); err != nil {
return err
}
}
return nil
}
// maxOnionPayloadSize is the largest Sphinx payload possible, so we don't need
// to read/write a TLV stream larger than this.
const maxOnionPayloadSize = 1300
func deserializeHop(r io.Reader) (*route.Hop, error) {
h := &route.Hop{}
@@ -530,6 +570,47 @@ func deserializeHop(r io.Reader) (*route.Hop, error) {
return nil, err
}
// TODO(roasbeef): change field to allow LegacyPayload false to be the
// legacy default?
err := binary.Read(r, byteOrder, &h.LegacyPayload)
if err != nil {
return nil, err
}
var numElements uint32
if err := ReadElements(r, &numElements); err != nil {
return nil, err
}
// If there are no elements, then we can return early.
if numElements == 0 {
return h, nil
}
tlvMap := make(map[uint64][]byte)
for i := uint32(0); i < numElements; i++ {
var tlvType uint64
if err := ReadElements(r, &tlvType); err != nil {
return nil, err
}
rawRecordBytes, err := wire.ReadVarBytes(
r, 0, maxOnionPayloadSize, "tlv",
)
if err != nil {
return nil, err
}
tlvMap[tlvType] = rawRecordBytes
}
tlvRecords, err := tlv.MapToRecords(tlvMap)
if err != nil {
return nil, err
}
h.TLVRecords = tlvRecords
return h, nil
}
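To see the new encoding end to end, here is a hedged round-trip sketch (not part of this commit): it serializes a hop carrying a single stub TLV record and reads it back. tlv.StubEncoder and tlv.MakeStaticRecord are the same helpers the tests below use; the zero-valued PubKeyBytes and the helper name are assumptions for the sketch.

// roundTripHopSketch serializes a non-legacy hop with one TLV record and
// decodes it again. The decoded hop carries a single record of type 1 whose
// encoded value is the raw bytes {1, 2, 3}.
func roundTripHopSketch() (*route.Hop, error) {
	payload := []byte{1, 2, 3}
	hop := &route.Hop{
		ChannelID:        12345,
		OutgoingTimeLock: 111,
		AmtToForward:     555,
		TLVRecords: []tlv.Record{
			tlv.MakeStaticRecord(1, nil, 3, tlv.StubEncoder(payload), nil),
		},
	}
	var b bytes.Buffer
	if err := serializeHop(&b, hop); err != nil {
		return nil, err
	}
	return deserializeHop(&b)
}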

@@ -13,17 +13,32 @@ import (
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
"github.com/lightningnetwork/lnd/tlv"
)
var (
priv, _ = btcec.NewPrivateKey(btcec.S256())
pub = priv.PubKey()
tlvBytes = []byte{1, 2, 3}
tlvEncoder = tlv.StubEncoder(tlvBytes)
testHop1 = &route.Hop{
PubKeyBytes: route.NewVertex(pub),
ChannelID: 12345,
OutgoingTimeLock: 111,
AmtToForward: 555,
TLVRecords: []tlv.Record{
tlv.MakeStaticRecord(1, nil, 3, tlvEncoder, nil),
tlv.MakeStaticRecord(2, nil, 3, tlvEncoder, nil),
},
}
testHop2 = &route.Hop{
PubKeyBytes: route.NewVertex(pub),
ChannelID: 12345,
OutgoingTimeLock: 111,
AmtToForward: 555,
LegacyPayload: true,
}
testRoute = route.Route{
@@ -31,8 +46,8 @@ var (
TotalAmount: 1234567,
SourcePubKey: route.NewVertex(pub),
Hops: []*route.Hop{
testHop1,
testHop2,
},
}
)
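tlv.StubEncoder(tlvBytes) returns an encoder that simply writes tlvBytes regardless of the record's value, so both static records on testHop1 encode to the same three bytes; that is exactly what assertHopRecordsEqual checks further down. A small illustration (not part of this commit; the helper name is just for the example):

// encodeStubRecord encodes the first stub record on testHop1; the returned
// bytes are {1, 2, 3}, i.e. tlvBytes.
func encodeStubRecord() ([]byte, error) {
	var buf bytes.Buffer
	if err := testHop1.TLVRecords[0].Encode(&buf); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}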
@@ -191,6 +206,8 @@ func TestSentPaymentSerialization(t *testing.T) {
}
if !reflect.DeepEqual(s, newAttemptInfo) {
s.SessionKey.Curve = nil
newAttemptInfo.SessionKey.Curve = nil
t.Fatalf("Payments do not match after "+
"serialization/deserialization %v vs %v",
spew.Sdump(s), spew.Sdump(newAttemptInfo),
@@ -199,6 +216,46 @@ }
}
func assertRouteHopRecordsEqual(r1, r2 *route.Route) error {
for i := 0; i < len(r1.Hops); i++ {
for j := 0; j < len(r1.Hops[i].TLVRecords); j++ {
expectedRecord := r1.Hops[i].TLVRecords[j]
newRecord := r2.Hops[i].TLVRecords[j]
err := assertHopRecordsEqual(expectedRecord, newRecord)
if err != nil {
return fmt.Errorf("route record mismatch: %v", err)
}
}
}
return nil
}
func assertHopRecordsEqual(h1, h2 tlv.Record) error {
if h1.Type() != h2.Type() {
return fmt.Errorf("wrong type: expected %v, got %v", h1.Type(),
h2.Type())
}
var b bytes.Buffer
if err := h2.Encode(&b); err != nil {
return fmt.Errorf("unable to encode record: %v", err)
}
if !bytes.Equal(b.Bytes(), tlvBytes) {
return fmt.Errorf("wrong raw record: expected %x, got %x",
tlvBytes, b.Bytes())
}
if h1.Size() != h2.Size() {
return fmt.Errorf("wrong size: expected %v, "+
"got %v", h1.Size(), h2.Size())
}
return nil
}
func TestRouteSerialization(t *testing.T) {
t.Parallel()
@@ -213,9 +270,23 @@ func TestRouteSerialization(t *testing.T) {
t.Fatal(err)
}
// First we verify all the records match up properly, as they can't be
// compared directly using reflect.DeepEqual.
err = assertRouteHopRecordsEqual(&testRoute, &route2)
if err != nil {
t.Fatalf("route tlv records don't match: %v", err)
}
// Now that we know the records match up, we'll examine the remainder
// of the route without the TLV records attached as reflect.DeepEqual
// can't properly assert their equality.
testRoute.Hops[0].TLVRecords = nil
testRoute.Hops[1].TLVRecords = nil
route2.Hops[0].TLVRecords = nil
route2.Hops[1].TLVRecords = nil
if !reflect.DeepEqual(testRoute, route2) {
t.Fatalf("routes not equal: \n%v vs \n%v",
spew.Sdump(testRoute), spew.Sdump(route2))
}
}