diff --git a/lnwire/extra_bytes.go b/lnwire/extra_bytes.go new file mode 100644 index 00000000..22fd20bd --- /dev/null +++ b/lnwire/extra_bytes.go @@ -0,0 +1,84 @@ +package lnwire + +import ( + "bytes" + "io" + "io/ioutil" + + "github.com/lightningnetwork/lnd/tlv" +) + +// ExtraOpaqueData is the set of data that was appended to this message, some +// of which we may not actually know how to iterate or parse. By holding onto +// this data, we ensure that we're able to properly validate the set of +// signatures that cover these new fields, and ensure we're able to make +// upgrades to the network in a forwards compatible manner. +type ExtraOpaqueData []byte + +// Encode attempts to encode the raw extra bytes into the passed io.Writer. +func (e *ExtraOpaqueData) Encode(w io.Writer) error { + eBytes := []byte((*e)[:]) + if err := WriteElements(w, eBytes); err != nil { + return err + } + + return nil +} + +// Decode attempts to unpack the raw bytes encoded in the passed io.Reader as a +// set of extra opaque data. +func (e *ExtraOpaqueData) Decode(r io.Reader) error { + // First, we'll attempt to read a set of bytes contained within the + // passed io.Reader (if any exist). + rawBytes, err := ioutil.ReadAll(r) + if err != nil { + return err + } + + // If we _do_ have some bytes, then we'll swap out our backing pointer. + // This ensures that any struct that embeds this type will properly + // store the bytes once this method exits. + if len(rawBytes) > 0 { + *e = ExtraOpaqueData(rawBytes) + } else { + *e = make([]byte, 0) + } + + return nil +} + +// PackRecords attempts to encode the set of tlv records into the target +// ExtraOpaqueData instance. The records will be encoded as a raw TLV stream +// and stored within the backing slice pointer. +func (e *ExtraOpaqueData) PackRecords(records ...tlv.Record) error { + tlvStream, err := tlv.NewStream(records...) + if err != nil { + return err + } + + var extraBytesWriter bytes.Buffer + if err := tlvStream.Encode(&extraBytesWriter); err != nil { + return err + } + + *e = ExtraOpaqueData(extraBytesWriter.Bytes()) + + return nil +} + +// ExtractRecords attempts to decode any types in the internal raw bytes as if +// it were a tlv stream. The set of raw parsed types is returned, and any +// passed records (if found in the stream) will be parsed into the proper +// tlv.Record. +func (e *ExtraOpaqueData) ExtractRecords(records ...tlv.Record) ( + tlv.TypeMap, error) { + + extraBytesReader := bytes.NewReader(*e) + + tlvStream, err := tlv.NewStream(records...) + if err != nil { + return nil, err + } + + return tlvStream.DecodeWithParsedTypes(extraBytesReader) +} diff --git a/lnwire/extra_bytes_test.go b/lnwire/extra_bytes_test.go new file mode 100644 index 00000000..39271d6a --- /dev/null +++ b/lnwire/extra_bytes_test.go @@ -0,0 +1,147 @@ +package lnwire + +import ( + "bytes" + "math/rand" + "reflect" + "testing" + "testing/quick" + + "github.com/lightningnetwork/lnd/tlv" +) + +// TestExtraOpaqueDataEncodeDecode tests that we're able to encode/decode +// arbitrary payloads. +func TestExtraOpaqueDataEncodeDecode(t *testing.T) { + t.Parallel() + + type testCase struct { + // emptyBytes indicates if we should try to encode empty bytes + // or not. + emptyBytes bool + + // inputBytes if emptyBytes is false, then we'll read in this + // set of bytes instead. + inputBytes []byte + } + + // We should be able to read in an arbitrary set of bytes as an + // ExtraOpaqueData, then encode those new bytes into a new instance. + // The final two instances should be identical. + scenario := func(test testCase) bool { + var ( + extraData ExtraOpaqueData + b bytes.Buffer + ) + + copy(extraData[:], test.inputBytes) + + if err := extraData.Encode(&b); err != nil { + t.Fatalf("unable to encode extra data: %v", err) + return false + } + + var newBytes ExtraOpaqueData + if err := newBytes.Decode(&b); err != nil { + t.Fatalf("unable to decode extra bytes: %v", err) + return false + } + + if !bytes.Equal(extraData[:], newBytes[:]) { + t.Fatalf("expected %x, got %x", extraData, + newBytes) + return false + } + + return true + } + + // We'll make a function to generate random test data. Half of the + // time, we'll actually feed in blank bytes. + quickCfg := &quick.Config{ + Values: func(v []reflect.Value, r *rand.Rand) { + + var newTestCase testCase + if r.Int31()%2 == 0 { + newTestCase.emptyBytes = true + } + + if !newTestCase.emptyBytes { + numBytes := r.Int31n(1000) + newTestCase.inputBytes = make([]byte, numBytes) + + _, err := r.Read(newTestCase.inputBytes) + if err != nil { + t.Fatalf("unable to gen random bytes: %v", err) + return + } + } + + v[0] = reflect.ValueOf(newTestCase) + }, + } + + if err := quick.Check(scenario, quickCfg); err != nil { + t.Fatalf("encode+decode test failed: %v", err) + } +} + +// TestExtraOpaqueDataPackUnpackRecords tests that we're able to pack a set of +// tlv.Records into a stream, and unpack them on the other side to obtain the +// same set of records. +func TestExtraOpaqueDataPackUnpackRecords(t *testing.T) { + t.Parallel() + + var ( + type1 tlv.Type = 1 + type2 tlv.Type = 2 + + channelType1 uint8 = 2 + channelType2 uint8 + + hop1 uint32 = 99 + hop2 uint32 + ) + testRecords := []tlv.Record{ + tlv.MakePrimitiveRecord(type1, &channelType1), + tlv.MakePrimitiveRecord(type2, &hop1), + } + + // Now that we have our set of sample records and types, we'll encode + // them into the passed ExtraOpaqueData instance. + var extraBytes ExtraOpaqueData + if err := extraBytes.PackRecords(testRecords...); err != nil { + t.Fatalf("unable to pack records: %v", err) + } + + // We'll now simulate decoding these types _back_ into records on the + // other side. + newRecords := []tlv.Record{ + tlv.MakePrimitiveRecord(type1, &channelType2), + tlv.MakePrimitiveRecord(type2, &hop2), + } + typeMap, err := extraBytes.ExtractRecords(newRecords...) + if err != nil { + t.Fatalf("unable to extract record: %v", err) + } + + // We should find that the new backing values have been populated with + // the proper value. + switch { + case channelType1 != channelType2: + t.Fatalf("wrong record for channel type: expected %v, got %v", + channelType1, channelType2) + + case hop1 != hop2: + t.Fatalf("wrong record for hop: expected %v, got %v", hop1, + hop2) + } + + // Both types we created above should be found in the type map. + if _, ok := typeMap[type1]; !ok { + t.Fatalf("type1 not found in typeMap") + } + if _, ok := typeMap[type2]; !ok { + t.Fatalf("type2 not found in typeMap") + } +} diff --git a/lnwire/lnwire.go b/lnwire/lnwire.go index ca0e449e..c180cad3 100644 --- a/lnwire/lnwire.go +++ b/lnwire/lnwire.go @@ -18,9 +18,16 @@ import ( "github.com/lightningnetwork/lnd/tor" ) -// MaxSliceLength is the maximum allowed length for any opaque byte slices in -// the wire protocol. -const MaxSliceLength = 65535 +const ( + // MaxSliceLength is the maximum allowed length for any opaque byte + // slices in the wire protocol. + MaxSliceLength = 65535 + + // MaxMsgBody is the largest payload any message is allowed to provide. + // This is two less than the MaxSliceLength as each message has a 2 + // byte type that precedes the message body. + MaxMsgBody = 65533 +) // PkScript is simple type definition which represents a raw serialized public // key script. @@ -418,6 +425,10 @@ func WriteElement(w io.Writer, element interface{}) error { if _, err := w.Write(b[:]); err != nil { return err } + + case ExtraOpaqueData: + return e.Encode(w) + default: return fmt.Errorf("unknown type in WriteElement: %T", e) } @@ -824,6 +835,10 @@ func ReadElement(r io.Reader, element interface{}) error { return err } *e = addrBytes[:length] + + case *ExtraOpaqueData: + return e.Decode(r) + default: return fmt.Errorf("unknown type in ReadElement: %T", e) }