package record_test import ( "bytes" "testing" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/tlv" ) type recordEncDecTest struct { name string encRecord func() tlv.RecordProducer decRecord func() tlv.RecordProducer assert func(*testing.T, interface{}) } var ( testTotal = lnwire.MilliSatoshi(45) testAddr = [32]byte{0x01, 0x02} testShare = [32]byte{0x03, 0x04} testSetID = [32]byte{0x05, 0x06} testChildIndex = uint16(17) ) var recordEncDecTests = []recordEncDecTest{ { name: "mpp", encRecord: func() tlv.RecordProducer { return record.NewMPP(testTotal, testAddr) }, decRecord: func() tlv.RecordProducer { return new(record.MPP) }, assert: func(t *testing.T, r interface{}) { mpp := r.(*record.MPP) if mpp.TotalMsat() != testTotal { t.Fatal("incorrect total msat") } if mpp.PaymentAddr() != testAddr { t.Fatal("incorrect payment addr") } }, }, { name: "amp", encRecord: func() tlv.RecordProducer { return record.NewAMP( testShare, testSetID, testChildIndex, ) }, decRecord: func() tlv.RecordProducer { return new(record.AMP) }, assert: func(t *testing.T, r interface{}) { amp := r.(*record.AMP) if amp.RootShare() != testShare { t.Fatal("incorrect root share") } if amp.SetID() != testSetID { t.Fatal("incorrect set id") } if amp.ChildIndex() != testChildIndex { t.Fatal("incorrect child index") } }, }, } // TestRecordEncodeDecode is a generic test framework for custom TLV records. It // asserts that records can encode and decode themselves, and that the value of // the original record matches the decoded record. func TestRecordEncodeDecode(t *testing.T) { for _, test := range recordEncDecTests { test := test t.Run(test.name, func(t *testing.T) { r := test.encRecord() r2 := test.decRecord() encStream := tlv.MustNewStream(r.Record()) decStream := tlv.MustNewStream(r2.Record()) test.assert(t, r) var b bytes.Buffer err := encStream.Encode(&b) if err != nil { t.Fatalf("unable to encode record: %v", err) } err = decStream.Decode(bytes.NewReader(b.Bytes())) if err != nil { t.Fatalf("unable to decode record: %v", err) } test.assert(t, r2) }) } }