diff --git a/htlcswitch/payment_result.go b/htlcswitch/payment_result.go index 5cfce845..02359d6d 100644 --- a/htlcswitch/payment_result.go +++ b/htlcswitch/payment_result.go @@ -2,7 +2,9 @@ package htlcswitch import ( "errors" + "io" + "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/lnwire" ) @@ -46,3 +48,34 @@ type networkResult struct { // which the failure reason might not be included. isResolution bool } + +// serializeNetworkResult serializes the networkResult. +func serializeNetworkResult(w io.Writer, n *networkResult) error { + if _, err := lnwire.WriteMessage(w, n.msg, 0); err != nil { + return err + } + + return channeldb.WriteElements(w, n.unencrypted, n.isResolution) +} + +// deserializeNetworkResult deserializes the networkResult. +func deserializeNetworkResult(r io.Reader) (*networkResult, error) { + var ( + err error + ) + + n := &networkResult{} + + n.msg, err = lnwire.ReadMessage(r, 0) + if err != nil { + return nil, err + } + + if err := channeldb.ReadElements(r, + &n.unencrypted, &n.isResolution, + ); err != nil { + return nil, err + } + + return n, nil +} diff --git a/htlcswitch/payment_result_test.go b/htlcswitch/payment_result_test.go new file mode 100644 index 00000000..4b45bc9a --- /dev/null +++ b/htlcswitch/payment_result_test.go @@ -0,0 +1,90 @@ +package htlcswitch + +import ( + "bytes" + "math/rand" + "reflect" + "testing" + + "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwire" +) + +// TestNetworkResultSerialization checks that NetworkResults are properly +// (de)serialized. +func TestNetworkResultSerialization(t *testing.T) { + t.Parallel() + + var preimage lntypes.Preimage + if _, err := rand.Read(preimage[:]); err != nil { + t.Fatalf("unable gen rand preimag: %v", err) + } + + var chanID lnwire.ChannelID + if _, err := rand.Read(chanID[:]); err != nil { + t.Fatalf("unable gen rand chanid: %v", err) + } + + var reason [256]byte + if _, err := rand.Read(reason[:]); err != nil { + t.Fatalf("unable gen rand reason: %v", err) + } + + settle := &lnwire.UpdateFulfillHTLC{ + ChanID: chanID, + ID: 2, + PaymentPreimage: preimage, + } + + fail := &lnwire.UpdateFailHTLC{ + ChanID: chanID, + ID: 1, + Reason: []byte{}, + } + + fail2 := &lnwire.UpdateFailHTLC{ + ChanID: chanID, + ID: 1, + Reason: reason[:], + } + + testCases := []*networkResult{ + { + msg: settle, + }, + { + msg: fail, + unencrypted: false, + isResolution: false, + }, + { + msg: fail, + unencrypted: false, + isResolution: true, + }, + { + msg: fail2, + unencrypted: true, + isResolution: false, + }, + } + + for _, p := range testCases { + var buf bytes.Buffer + if err := serializeNetworkResult(&buf, p); err != nil { + t.Fatalf("serialize failed: %v", err) + } + + r := bytes.NewReader(buf.Bytes()) + p1, err := deserializeNetworkResult(r) + if err != nil { + t.Fatalf("unable to deserizlize: %v", err) + } + + if !reflect.DeepEqual(p, p1) { + t.Fatalf("not equal. %v vs %v", spew.Sdump(p), + spew.Sdump(p1)) + } + } +}