diff --git a/channeldb/payments.go b/channeldb/payments.go index e51f2ce1..213bd824 100644 --- a/channeldb/payments.go +++ b/channeldb/payments.go @@ -8,6 +8,7 @@ import ( "github.com/coreos/bbolt" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" ) var ( @@ -328,3 +329,84 @@ func deserializeOutgoingPayment(r io.Reader) (*OutgoingPayment, error) { return p, nil } + +func serializeHop(w io.Writer, h *route.Hop) error { + if err := WriteElements(w, + h.PubKeyBytes[:], h.ChannelID, h.OutgoingTimeLock, + h.AmtToForward, + ); err != nil { + return err + } + + return nil +} + +func deserializeHop(r io.Reader) (*route.Hop, error) { + h := &route.Hop{} + + var pub []byte + if err := ReadElements(r, &pub); err != nil { + return nil, err + } + copy(h.PubKeyBytes[:], pub) + + if err := ReadElements(r, + &h.ChannelID, &h.OutgoingTimeLock, &h.AmtToForward, + ); err != nil { + return nil, err + } + + return h, nil +} + +func serializeRoute(w io.Writer, r route.Route) error { + if err := WriteElements(w, + r.TotalTimeLock, r.TotalFees, r.TotalAmount, r.SourcePubKey[:], + ); err != nil { + return err + } + + if err := WriteElements(w, uint32(len(r.Hops))); err != nil { + return err + } + + for _, h := range r.Hops { + if err := serializeHop(w, h); err != nil { + return err + } + } + + return nil +} + +func deserializeRoute(r io.Reader) (route.Route, error) { + rt := route.Route{} + if err := ReadElements(r, + &rt.TotalTimeLock, &rt.TotalFees, &rt.TotalAmount, + ); err != nil { + return rt, err + } + + var pub []byte + if err := ReadElements(r, &pub); err != nil { + return rt, err + } + copy(rt.SourcePubKey[:], pub) + + var numHops uint32 + if err := ReadElements(r, &numHops); err != nil { + return rt, err + } + + var hops []*route.Hop + for i := uint32(0); i < numHops; i++ { + hop, err := deserializeHop(r) + if err != nil { + return rt, err + } + hops = append(hops, hop) + } + rt.Hops = hops + + return rt, nil +} diff --git a/channeldb/payments_test.go b/channeldb/payments_test.go index d13e039d..bea87d17 100644 --- a/channeldb/payments_test.go +++ b/channeldb/payments_test.go @@ -8,8 +8,33 @@ import ( "testing" "time" + "github.com/btcsuite/btcd/btcec" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" +) + +var ( + priv, _ = btcec.NewPrivateKey(btcec.S256()) + pub = priv.PubKey() + + testHop = &route.Hop{ + PubKeyBytes: route.NewVertex(pub), + ChannelID: 12345, + OutgoingTimeLock: 111, + AmtToForward: 555, + } + + testRoute = route.Route{ + TotalTimeLock: 123, + TotalFees: 999, + TotalAmount: 1234567, + SourcePubKey: route.NewVertex(pub), + Hops: []*route.Hop{ + testHop, + testHop, + }, + } ) func makeFakePayment() *OutgoingPayment { @@ -251,3 +276,24 @@ func TestPaymentStatusWorkflow(t *testing.T) { } } } + +func TestRouteSerialization(t *testing.T) { + t.Parallel() + + var b bytes.Buffer + if err := serializeRoute(&b, testRoute); err != nil { + t.Fatal(err) + } + + r := bytes.NewReader(b.Bytes()) + route2, err := deserializeRoute(r) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(testRoute, route2) { + t.Fatalf("routes not equal: \n%v vs \n%v", + spew.Sdump(testRoute), spew.Sdump(route2)) + } + +}