channeldb: update route.Hop serialization to include new EOB related fields
We also include a migration for the existing routes stored on disk.
This commit is contained in:
parent
763cb6c09d
commit
c78e3aaa9d
@ -103,6 +103,13 @@ var (
|
||||
number: 9,
|
||||
migration: migrateOutgoingPayments,
|
||||
},
|
||||
{
|
||||
// The DB version where we started to store legacy
|
||||
// payload information for all routes, as well as the
|
||||
// optional TLV records.
|
||||
number: 10,
|
||||
migration: migrateRouteSerialization,
|
||||
},
|
||||
}
|
||||
|
||||
// Big endian is the preferred byte order, due to cursor scans over
|
||||
|
236
channeldb/migration_10_route_tlv_records.go
Normal file
236
channeldb/migration_10_route_tlv_records.go
Normal file
@ -0,0 +1,236 @@
|
||||
package channeldb
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
|
||||
"github.com/coreos/bbolt"
|
||||
"github.com/lightningnetwork/lnd/routing/route"
|
||||
)
|
||||
|
||||
// migrateRouteSerialization migrates the way we serialize routes across the
|
||||
// entire database. At the time of writing of this migration, this includes our
|
||||
// payment attempts, as well as the payment results in mission control.
|
||||
func migrateRouteSerialization(tx *bbolt.Tx) error {
|
||||
// First, we'll do all the payment attempts.
|
||||
rootPaymentBucket := tx.Bucket(paymentsRootBucket)
|
||||
if rootPaymentBucket == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// As we can't mutate a bucket while we're iterating over it with
|
||||
// ForEach, we'll need to collect all the known payment hashes in
|
||||
// memory first.
|
||||
var payHashes [][]byte
|
||||
err := rootPaymentBucket.ForEach(func(k, v []byte) error {
|
||||
if v != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
payHashes = append(payHashes, k)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Now that we have all the payment hashes, we can carry out the
|
||||
// migration itself.
|
||||
for _, payHash := range payHashes {
|
||||
payHashBucket := rootPaymentBucket.Bucket(payHash)
|
||||
|
||||
// First, we'll migrate the main (non duplicate) payment to
|
||||
// this hash.
|
||||
err := migrateAttemptEncoding(tx, payHashBucket)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Now that we've migrated the main payment, we'll also check
|
||||
// for any duplicate payments to the same payment hash.
|
||||
dupBucket := payHashBucket.Bucket(paymentDuplicateBucket)
|
||||
|
||||
// If there's no dup bucket, then we can move on to the next
|
||||
// payment.
|
||||
if dupBucket == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Otherwise, we'll now iterate through all the duplicate pay
|
||||
// hashes and migrate those.
|
||||
var dupSeqNos [][]byte
|
||||
err = dupBucket.ForEach(func(k, v []byte) error {
|
||||
dupSeqNos = append(dupSeqNos, k)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Now in this second pass, we'll re-serialize their duplicate
|
||||
// payment attempts under the new encoding.
|
||||
for _, seqNo := range dupSeqNos {
|
||||
dupPayHashBucket := dupBucket.Bucket(seqNo)
|
||||
err := migrateAttemptEncoding(tx, dupPayHashBucket)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.Infof("Migration of route/hop serialization complete!")
|
||||
|
||||
log.Infof("Migrating to new mission control store by clearing " +
|
||||
"existing data")
|
||||
|
||||
resultsKey := []byte("missioncontrol-results")
|
||||
err = tx.DeleteBucket(resultsKey)
|
||||
if err != nil && err != bbolt.ErrBucketNotFound {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Infof("Migration to new mission control completed!")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// migrateAttemptEncoding migrates payment attempts using the legacy format to
|
||||
// the new format.
|
||||
func migrateAttemptEncoding(tx *bbolt.Tx, payHashBucket *bbolt.Bucket) error {
|
||||
payAttemptBytes := payHashBucket.Get(paymentAttemptInfoKey)
|
||||
if payAttemptBytes == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// For our migration, we'll first read out the existing payment attempt
|
||||
// using the legacy serialization of the attempt.
|
||||
payAttemptReader := bytes.NewReader(payAttemptBytes)
|
||||
payAttempt, err := deserializePaymentAttemptInfoLegacy(
|
||||
payAttemptReader,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Now that we have the old attempts, we'll explicitly mark this as
|
||||
// needing a legacy payload, since after this migration, the modern
|
||||
// payload will be the default if signalled.
|
||||
for _, hop := range payAttempt.Route.Hops {
|
||||
hop.LegacyPayload = true
|
||||
}
|
||||
|
||||
// Finally, we'll write out the payment attempt using the new encoding.
|
||||
var b bytes.Buffer
|
||||
err = serializePaymentAttemptInfo(&b, payAttempt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return payHashBucket.Put(paymentAttemptInfoKey, b.Bytes())
|
||||
}
|
||||
|
||||
func deserializePaymentAttemptInfoLegacy(r io.Reader) (*PaymentAttemptInfo, error) {
|
||||
a := &PaymentAttemptInfo{}
|
||||
err := ReadElements(r, &a.PaymentID, &a.SessionKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
a.Route, err = deserializeRouteLegacy(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func serializePaymentAttemptInfoLegacy(w io.Writer, a *PaymentAttemptInfo) error {
|
||||
if err := WriteElements(w, a.PaymentID, a.SessionKey); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := serializeRouteLegacy(w, a.Route); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func deserializeHopLegacy(r io.Reader) (*route.Hop, error) {
|
||||
h := &route.Hop{}
|
||||
|
||||
var pub []byte
|
||||
if err := ReadElements(r, &pub); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
copy(h.PubKeyBytes[:], pub)
|
||||
|
||||
if err := ReadElements(r,
|
||||
&h.ChannelID, &h.OutgoingTimeLock, &h.AmtToForward,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return h, nil
|
||||
}
|
||||
|
||||
func serializeHopLegacy(w io.Writer, h *route.Hop) error {
|
||||
if err := WriteElements(w,
|
||||
h.PubKeyBytes[:], h.ChannelID, h.OutgoingTimeLock,
|
||||
h.AmtToForward,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func deserializeRouteLegacy(r io.Reader) (route.Route, error) {
|
||||
rt := route.Route{}
|
||||
if err := ReadElements(r,
|
||||
&rt.TotalTimeLock, &rt.TotalAmount,
|
||||
); err != nil {
|
||||
return rt, err
|
||||
}
|
||||
|
||||
var pub []byte
|
||||
if err := ReadElements(r, &pub); err != nil {
|
||||
return rt, err
|
||||
}
|
||||
copy(rt.SourcePubKey[:], pub)
|
||||
|
||||
var numHops uint32
|
||||
if err := ReadElements(r, &numHops); err != nil {
|
||||
return rt, err
|
||||
}
|
||||
|
||||
var hops []*route.Hop
|
||||
for i := uint32(0); i < numHops; i++ {
|
||||
hop, err := deserializeHopLegacy(r)
|
||||
if err != nil {
|
||||
return rt, err
|
||||
}
|
||||
hops = append(hops, hop)
|
||||
}
|
||||
rt.Hops = hops
|
||||
|
||||
return rt, nil
|
||||
}
|
||||
|
||||
func serializeRouteLegacy(w io.Writer, r route.Route) error {
|
||||
if err := WriteElements(w,
|
||||
r.TotalTimeLock, r.TotalAmount, r.SourcePubKey[:],
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := WriteElements(w, uint32(len(r.Hops))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, h := range r.Hops {
|
||||
if err := serializeHopLegacy(w, h); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@ -5,14 +5,18 @@ import (
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/btcsuite/btcutil"
|
||||
"github.com/coreos/bbolt"
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/go-errors/errors"
|
||||
"github.com/lightningnetwork/lnd/lntypes"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/routing/route"
|
||||
)
|
||||
|
||||
// TestPaymentStatusesMigration checks that already completed payments will have
|
||||
@ -723,3 +727,223 @@ func TestOutgoingPaymentsMigration(t *testing.T) {
|
||||
migrateOutgoingPayments,
|
||||
false)
|
||||
}
|
||||
|
||||
func makeRandPaymentCreationInfo() (*PaymentCreationInfo, error) {
|
||||
var payHash lntypes.Hash
|
||||
if _, err := rand.Read(payHash[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &PaymentCreationInfo{
|
||||
PaymentHash: payHash,
|
||||
Value: lnwire.MilliSatoshi(rand.Int63()),
|
||||
CreationDate: time.Now(),
|
||||
PaymentRequest: []byte("test"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TestPaymentRouteSerialization tests that we're able to properly migrate
|
||||
// existing payments on disk that contain the traversed routes to the new
|
||||
// routing format which supports the TLV payloads. We also test that the
|
||||
// migration is able to handle duplicate payment attempts.
|
||||
func TestPaymentRouteSerialization(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
legacyHop1 := &route.Hop{
|
||||
PubKeyBytes: route.NewVertex(pub),
|
||||
ChannelID: 12345,
|
||||
OutgoingTimeLock: 111,
|
||||
LegacyPayload: true,
|
||||
AmtToForward: 555,
|
||||
}
|
||||
legacyHop2 := &route.Hop{
|
||||
PubKeyBytes: route.NewVertex(pub),
|
||||
ChannelID: 12345,
|
||||
OutgoingTimeLock: 111,
|
||||
LegacyPayload: true,
|
||||
AmtToForward: 555,
|
||||
}
|
||||
legacyRoute := route.Route{
|
||||
TotalTimeLock: 123,
|
||||
TotalAmount: 1234567,
|
||||
SourcePubKey: route.NewVertex(pub),
|
||||
Hops: []*route.Hop{legacyHop1, legacyHop2},
|
||||
}
|
||||
|
||||
const numPayments = 4
|
||||
var oldPayments []*Payment
|
||||
|
||||
sharedPayAttempt := PaymentAttemptInfo{
|
||||
PaymentID: 1,
|
||||
SessionKey: priv,
|
||||
Route: legacyRoute,
|
||||
}
|
||||
|
||||
// We'll first add a series of fake payments, using the existing legacy
|
||||
// serialization format.
|
||||
beforeMigrationFunc := func(d *DB) {
|
||||
err := d.Update(func(tx *bbolt.Tx) error {
|
||||
paymentsBucket, err := tx.CreateBucket(
|
||||
paymentsRootBucket,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create new payments "+
|
||||
"bucket: %v", err)
|
||||
}
|
||||
|
||||
for i := 0; i < numPayments; i++ {
|
||||
var seqNum [8]byte
|
||||
byteOrder.PutUint64(seqNum[:], uint64(i))
|
||||
|
||||
// All payments will be randomly generated,
|
||||
// other than the final payment. We'll force
|
||||
// the final payment to re-use an existing
|
||||
// payment hash so we can insert it into the
|
||||
// duplicate payment hash bucket.
|
||||
var payInfo *PaymentCreationInfo
|
||||
if i < numPayments-1 {
|
||||
payInfo, err = makeRandPaymentCreationInfo()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create "+
|
||||
"payment: %v", err)
|
||||
}
|
||||
} else {
|
||||
payInfo = oldPayments[0].Info
|
||||
}
|
||||
|
||||
// Next, legacy encoded when needed, we'll
|
||||
// serialize the info and the attempt.
|
||||
var payInfoBytes bytes.Buffer
|
||||
err = serializePaymentCreationInfo(
|
||||
&payInfoBytes, payInfo,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to encode pay "+
|
||||
"info: %v", err)
|
||||
}
|
||||
var payAttemptBytes bytes.Buffer
|
||||
err = serializePaymentAttemptInfoLegacy(
|
||||
&payAttemptBytes, &sharedPayAttempt,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to encode payment attempt: "+
|
||||
"%v", err)
|
||||
}
|
||||
|
||||
// Before we write to disk, we'll need to fetch
|
||||
// the proper bucket. If this is the duplicate
|
||||
// payment, then we'll grab the dup bucket,
|
||||
// otherwise, we'll use the top level bucket.
|
||||
var payHashBucket *bbolt.Bucket
|
||||
if i < numPayments-1 {
|
||||
payHashBucket, err = paymentsBucket.CreateBucket(
|
||||
payInfo.PaymentHash[:],
|
||||
)
|
||||
} else {
|
||||
payHashBucket = paymentsBucket.Bucket(
|
||||
payInfo.PaymentHash[:],
|
||||
)
|
||||
dupPayBucket, err := payHashBucket.CreateBucket(
|
||||
paymentDuplicateBucket,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create "+
|
||||
"dup hash bucket: %v", err)
|
||||
}
|
||||
|
||||
payHashBucket, err = dupPayBucket.CreateBucket(
|
||||
seqNum[:],
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to make dup "+
|
||||
"bucket: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = payHashBucket.Put(paymentSequenceKey, seqNum[:])
|
||||
if err != nil {
|
||||
t.Fatalf("unable to write seqno: %v", err)
|
||||
}
|
||||
|
||||
err = payHashBucket.Put(
|
||||
paymentCreationInfoKey, payInfoBytes.Bytes(),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to write creation "+
|
||||
"info: %v", err)
|
||||
}
|
||||
|
||||
err = payHashBucket.Put(
|
||||
paymentAttemptInfoKey, payAttemptBytes.Bytes(),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to write attempt "+
|
||||
"info: %v", err)
|
||||
}
|
||||
|
||||
oldPayments = append(oldPayments, &Payment{
|
||||
Info: payInfo,
|
||||
Attempt: &sharedPayAttempt,
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create test payments: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
afterMigrationFunc := func(d *DB) {
|
||||
newPayments, err := d.FetchPayments()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to fetch new payments: %v", err)
|
||||
}
|
||||
|
||||
if len(newPayments) != numPayments {
|
||||
t.Fatalf("expected %d payments, got %d", numPayments,
|
||||
len(newPayments))
|
||||
}
|
||||
|
||||
for i, p := range newPayments {
|
||||
// Order of payments should be be preserved.
|
||||
old := oldPayments[i]
|
||||
|
||||
if p.Attempt.PaymentID != old.Attempt.PaymentID {
|
||||
t.Fatalf("wrong pay ID: expected %v, got %v",
|
||||
p.Attempt.PaymentID,
|
||||
old.Attempt.PaymentID)
|
||||
}
|
||||
|
||||
if p.Attempt.Route.TotalFees() != old.Attempt.Route.TotalFees() {
|
||||
t.Fatalf("Fee mismatch")
|
||||
}
|
||||
|
||||
if p.Attempt.Route.TotalAmount != old.Attempt.Route.TotalAmount {
|
||||
t.Fatalf("Total amount mismatch")
|
||||
}
|
||||
|
||||
if p.Attempt.Route.TotalTimeLock != old.Attempt.Route.TotalTimeLock {
|
||||
t.Fatalf("timelock mismatch")
|
||||
}
|
||||
|
||||
if p.Attempt.Route.SourcePubKey != old.Attempt.Route.SourcePubKey {
|
||||
t.Fatalf("source mismatch: %x vs %x",
|
||||
p.Attempt.Route.SourcePubKey[:],
|
||||
old.Attempt.Route.SourcePubKey[:])
|
||||
}
|
||||
|
||||
for i, hop := range p.Attempt.Route.Hops {
|
||||
if !reflect.DeepEqual(hop, legacyRoute.Hops[i]) {
|
||||
t.Fatalf("hop mismatch")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
applyMigration(t,
|
||||
beforeMigrationFunc,
|
||||
afterMigrationFunc,
|
||||
migrateRouteSerialization,
|
||||
false)
|
||||
}
|
||||
|
@ -15,6 +15,7 @@ import (
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/lightningnetwork/lnd/lntypes"
|
||||
"github.com/lightningnetwork/lnd/routing/route"
|
||||
"github.com/lightningnetwork/lnd/tlv"
|
||||
)
|
||||
|
||||
func initDB() (*DB, error) {
|
||||
@ -137,8 +138,37 @@ func TestPaymentControlSwitchFail(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("error shouldn't have been received, got: %v", err)
|
||||
}
|
||||
|
||||
err = assertRouteHopRecordsEqual(route, &attempt.Route)
|
||||
if err != nil {
|
||||
t.Fatalf("route tlv records not equal: %v", err)
|
||||
}
|
||||
|
||||
for i := 0; i < len(route.Hops); i++ {
|
||||
for j := 0; j < len(route.Hops[i].TLVRecords); j++ {
|
||||
expectedRecord := route.Hops[i].TLVRecords[j]
|
||||
newRecord := attempt.Route.Hops[i].TLVRecords[j]
|
||||
|
||||
err := assertHopRecordsEqual(expectedRecord, newRecord)
|
||||
if err != nil {
|
||||
t.Fatalf("route record mismatch: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < len(route.Hops); i++ {
|
||||
// reflect.DeepEqual can't assert that two function closures
|
||||
// are equal. The underlying tlv.Record uses function closures
|
||||
// internally, so after we verify that the records match above
|
||||
// manually, we unset these so we can use reflect.DeepEqual
|
||||
// below.
|
||||
route.Hops[i].TLVRecords = nil
|
||||
attempt.Route.Hops[i].TLVRecords = nil
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(*route, attempt.Route) {
|
||||
t.Fatalf("unexpected route returned")
|
||||
t.Fatalf("unexpected route returned: %v vs %v",
|
||||
spew.Sdump(attempt.Route), spew.Sdump(*route))
|
||||
}
|
||||
|
||||
assertPaymentStatus(t, db, info.PaymentHash, StatusSucceeded)
|
||||
@ -427,7 +457,6 @@ func checkPaymentCreationInfo(bucket *bbolt.Bucket, c *PaymentCreationInfo) erro
|
||||
r := bytes.NewReader(b)
|
||||
c2, err := deserializePaymentCreationInfo(r)
|
||||
if err != nil {
|
||||
fmt.Println("creation info err: ", err)
|
||||
return err
|
||||
}
|
||||
if !reflect.DeepEqual(c, c2) {
|
||||
@ -454,11 +483,34 @@ func checkPaymentAttemptInfo(bucket *bbolt.Bucket, a *PaymentAttemptInfo) error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = assertRouteHopRecordsEqual(&a.Route, &a2.Route)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
recordCache := make(map[int][]tlv.Record)
|
||||
for i := 0; i < len(a.Route.Hops); i++ {
|
||||
recordCache[i] = a.Route.Hops[i].TLVRecords
|
||||
|
||||
// reflect.DeepEqual can't assert that two function closures
|
||||
// are equal. The underlying tlv.Record uses function closures
|
||||
// internally, so after we verify that the records match above
|
||||
// manually, we unset these so we can use reflect.DeepEqual
|
||||
// below.
|
||||
a.Route.Hops[i].TLVRecords = nil
|
||||
a2.Route.Hops[i].TLVRecords = nil
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(a, a2) {
|
||||
return fmt.Errorf("PaymentAttemptInfos don't match: %v vs %v",
|
||||
spew.Sdump(a), spew.Sdump(a2))
|
||||
}
|
||||
|
||||
for index, records := range recordCache {
|
||||
a.Route.Hops[index].TLVRecords = records
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -10,10 +10,12 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec"
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
"github.com/coreos/bbolt"
|
||||
"github.com/lightningnetwork/lnd/lntypes"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/routing/route"
|
||||
"github.com/lightningnetwork/lnd/tlv"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -512,9 +514,47 @@ func serializeHop(w io.Writer, h *route.Hop) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := binary.Write(w, byteOrder, h.LegacyPayload); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// For legacy payloads, we don't need to write any TLV records, so
|
||||
// we'll write a zero indicating the our serialized TLV map has no
|
||||
// records.
|
||||
if h.LegacyPayload {
|
||||
return WriteElements(w, uint32(0))
|
||||
}
|
||||
|
||||
// Otherwise, we'll transform our slice of records into a map of the
|
||||
// raw bytes, then serialize them in-line with a length (number of
|
||||
// elements) prefix.
|
||||
mapRecords, err := tlv.RecordsToMap(h.TLVRecords)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
numRecords := uint32(len(mapRecords))
|
||||
if err := WriteElements(w, numRecords); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for recordType, rawBytes := range mapRecords {
|
||||
if err := WriteElements(w, recordType); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := wire.WriteVarBytes(w, 0, rawBytes); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// maxOnionPayloadSize is the largest Sphinx payload possible, so we don't need
|
||||
// to read/write a TLV stream larger than this.
|
||||
const maxOnionPayloadSize = 1300
|
||||
|
||||
func deserializeHop(r io.Reader) (*route.Hop, error) {
|
||||
h := &route.Hop{}
|
||||
|
||||
@ -530,6 +570,47 @@ func deserializeHop(r io.Reader) (*route.Hop, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO(roasbeef): change field to allow LegacyPayload false to be the
|
||||
// legacy default?
|
||||
err := binary.Read(r, byteOrder, &h.LegacyPayload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var numElements uint32
|
||||
if err := ReadElements(r, &numElements); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If there're no elements, then we can return early.
|
||||
if numElements == 0 {
|
||||
return h, nil
|
||||
}
|
||||
|
||||
tlvMap := make(map[uint64][]byte)
|
||||
for i := uint32(0); i < numElements; i++ {
|
||||
var tlvType uint64
|
||||
if err := ReadElements(r, &tlvType); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rawRecordBytes, err := wire.ReadVarBytes(
|
||||
r, 0, maxOnionPayloadSize, "tlv",
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tlvMap[tlvType] = rawRecordBytes
|
||||
}
|
||||
|
||||
tlvRecords, err := tlv.MapToRecords(tlvMap)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
h.TLVRecords = tlvRecords
|
||||
|
||||
return h, nil
|
||||
}
|
||||
|
||||
|
@ -13,17 +13,32 @@ import (
|
||||
"github.com/lightningnetwork/lnd/lntypes"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/routing/route"
|
||||
"github.com/lightningnetwork/lnd/tlv"
|
||||
)
|
||||
|
||||
var (
|
||||
priv, _ = btcec.NewPrivateKey(btcec.S256())
|
||||
pub = priv.PubKey()
|
||||
|
||||
testHop = &route.Hop{
|
||||
tlvBytes = []byte{1, 2, 3}
|
||||
tlvEncoder = tlv.StubEncoder(tlvBytes)
|
||||
testHop1 = &route.Hop{
|
||||
PubKeyBytes: route.NewVertex(pub),
|
||||
ChannelID: 12345,
|
||||
OutgoingTimeLock: 111,
|
||||
AmtToForward: 555,
|
||||
TLVRecords: []tlv.Record{
|
||||
tlv.MakeStaticRecord(1, nil, 3, tlvEncoder, nil),
|
||||
tlv.MakeStaticRecord(2, nil, 3, tlvEncoder, nil),
|
||||
},
|
||||
}
|
||||
|
||||
testHop2 = &route.Hop{
|
||||
PubKeyBytes: route.NewVertex(pub),
|
||||
ChannelID: 12345,
|
||||
OutgoingTimeLock: 111,
|
||||
AmtToForward: 555,
|
||||
LegacyPayload: true,
|
||||
}
|
||||
|
||||
testRoute = route.Route{
|
||||
@ -31,8 +46,8 @@ var (
|
||||
TotalAmount: 1234567,
|
||||
SourcePubKey: route.NewVertex(pub),
|
||||
Hops: []*route.Hop{
|
||||
testHop,
|
||||
testHop,
|
||||
testHop1,
|
||||
testHop2,
|
||||
},
|
||||
}
|
||||
)
|
||||
@ -191,6 +206,8 @@ func TestSentPaymentSerialization(t *testing.T) {
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(s, newAttemptInfo) {
|
||||
s.SessionKey.Curve = nil
|
||||
newAttemptInfo.SessionKey.Curve = nil
|
||||
t.Fatalf("Payments do not match after "+
|
||||
"serialization/deserialization %v vs %v",
|
||||
spew.Sdump(s), spew.Sdump(newAttemptInfo),
|
||||
@ -199,6 +216,46 @@ func TestSentPaymentSerialization(t *testing.T) {
|
||||
|
||||
}
|
||||
|
||||
func assertRouteHopRecordsEqual(r1, r2 *route.Route) error {
|
||||
for i := 0; i < len(r1.Hops); i++ {
|
||||
for j := 0; j < len(r1.Hops[i].TLVRecords); j++ {
|
||||
expectedRecord := r1.Hops[i].TLVRecords[j]
|
||||
newRecord := r2.Hops[i].TLVRecords[j]
|
||||
|
||||
err := assertHopRecordsEqual(expectedRecord, newRecord)
|
||||
if err != nil {
|
||||
return fmt.Errorf("route record mismatch: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func assertHopRecordsEqual(h1, h2 tlv.Record) error {
|
||||
if h1.Type() != h2.Type() {
|
||||
return fmt.Errorf("wrong type: expected %v, got %v", h1.Type(),
|
||||
h2.Type())
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := h2.Encode(&b); err != nil {
|
||||
return fmt.Errorf("unable to encode record: %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(b.Bytes(), tlvBytes) {
|
||||
return fmt.Errorf("wrong raw record: expected %x, got %x",
|
||||
tlvBytes, b.Bytes())
|
||||
}
|
||||
|
||||
if h1.Size() != h2.Size() {
|
||||
return fmt.Errorf("wrong size: expected %v, "+
|
||||
"got %v", h1.Size(), h2.Size())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestRouteSerialization(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@ -213,9 +270,23 @@ func TestRouteSerialization(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// First we verify all the records match up porperly, as they aren't
|
||||
// able to be properly compared using reflect.DeepEqual.
|
||||
err = assertRouteHopRecordsEqual(&testRoute, &route2)
|
||||
if err != nil {
|
||||
t.Fatalf("route tlv records don't match: %v", err)
|
||||
}
|
||||
|
||||
// Now that we know the records match up, we'll examine the remainder
|
||||
// of the route without the TLV records attached as reflect.DeepEqual
|
||||
// can't properly assert their equality.
|
||||
testRoute.Hops[0].TLVRecords = nil
|
||||
testRoute.Hops[1].TLVRecords = nil
|
||||
route2.Hops[0].TLVRecords = nil
|
||||
route2.Hops[1].TLVRecords = nil
|
||||
|
||||
if !reflect.DeepEqual(testRoute, route2) {
|
||||
t.Fatalf("routes not equal: \n%v vs \n%v",
|
||||
spew.Sdump(testRoute), spew.Sdump(route2))
|
||||
}
|
||||
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user