diff --git a/lnwire/accept_channel.go b/lnwire/accept_channel.go index da9daa69..57f2ad40 100644 --- a/lnwire/accept_channel.go +++ b/lnwire/accept_channel.go @@ -1,6 +1,7 @@ package lnwire import ( + "fmt" "io" "github.com/btcsuite/btcd/btcec" @@ -92,6 +93,17 @@ type AcceptChannel struct { // and has a length prefix, so a zero will be written if it is not set // and its length followed by the script will be written if it is set. UpfrontShutdownScript DeliveryAddress + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + // + // NOTE: Since the upfront shutdown script MUST be present (though can + // be zero-length) if any TLV data is available, the script will be + // extracted and removed from this blob when decoding. ExtraData will + // contain all TLV records _except_ the DeliveryAddress record in that + // case. + ExtraData ExtraOpaqueData } // A compile time check to ensure AcceptChannel implements the lnwire.Message @@ -104,6 +116,15 @@ var _ Message = (*AcceptChannel)(nil) // // This is part of the lnwire.Message interface. func (a *AcceptChannel) Encode(w io.Writer, pver uint32) error { + // Since the upfront script is encoded as a TLV record, concatenate it + // with the ExtraData, and write them as one. + tlvRecords, err := packShutdownScript( + a.UpfrontShutdownScript, a.ExtraData, + ) + if err != nil { + return err + } + return WriteElements(w, a.PendingChannelID[:], a.DustLimit, @@ -119,7 +140,7 @@ func (a *AcceptChannel) Encode(w io.Writer, pver uint32) error { a.DelayedPaymentPoint, a.HtlcPoint, a.FirstCommitmentPoint, - a.UpfrontShutdownScript, + tlvRecords, ) } @@ -150,15 +171,82 @@ func (a *AcceptChannel) Decode(r io.Reader, pver uint32) error { return err } - // Check for the optional upfront shutdown script field. If it is not there, - // silence the EOF error. - err = ReadElement(r, &a.UpfrontShutdownScript) - if err != nil && err != io.EOF { + // For backwards compatibility, the optional extra data blob for + // AcceptChannel must contain an entry for the upfront shutdown script. + // We'll read it out and attempt to parse it. + var tlvRecords ExtraOpaqueData + if err := ReadElements(r, &tlvRecords); err != nil { return err } + + a.UpfrontShutdownScript, a.ExtraData, err = parseShutdownScript( + tlvRecords, + ) + if err != nil { + return err + } + return nil } +// packShutdownScript takes an upfront shutdown script and an opaque data blob +// and concatenates them. +func packShutdownScript(addr DeliveryAddress, extraData ExtraOpaqueData) ( + ExtraOpaqueData, error) { + + // We'll always write the upfront shutdown script record, regardless of + // the script being empty. + var tlvRecords ExtraOpaqueData + + // Pack it into a data blob as a TLV record. + err := tlvRecords.PackRecords(addr.NewRecord()) + if err != nil { + return nil, fmt.Errorf("unable to pack upfront shutdown "+ + "script as TLV record: %v", err) + } + + // Concatenate the remaining blob with the shutdown script record. + tlvRecords = append(tlvRecords, extraData...) + return tlvRecords, nil +} + +// parseShutdownScript reads and extract the upfront shutdown script from the +// passe data blob. It returns the script, if any, and the remainder of the +// data blob. +// +// This can be used to parse extra data for the OpenChannel and AcceptChannel +// messages, where the shutdown script is mandatory if extra TLV data is +// present. +func parseShutdownScript(tlvRecords ExtraOpaqueData) (DeliveryAddress, + ExtraOpaqueData, error) { + + // If no TLV data is present there can't be any script available. + if len(tlvRecords) == 0 { + return nil, tlvRecords, nil + } + + // Otherwise the shutdown script MUST be present. + var addr DeliveryAddress + tlvs, err := tlvRecords.ExtractRecords(addr.NewRecord()) + if err != nil { + return nil, nil, err + } + + // Not among TLV records, this means the data was invalid. + if _, ok := tlvs[DeliveryAddrType]; !ok { + return nil, nil, fmt.Errorf("no shutdown script in non-empty " + + "data blob") + } + + // Now that we have retrieved the address (which can be zero-length), + // we'll remove the bytes encoding it from the TLV data before + // returning it. + addrLen := len(addr) + tlvRecords = tlvRecords[addrLen+2:] + + return addr, tlvRecords, nil +} + // MsgType returns the MessageType code which uniquely identifies this message // as an AcceptChannel on the wire. // @@ -172,11 +260,5 @@ func (a *AcceptChannel) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (a *AcceptChannel) MaxPayloadLength(uint32) uint32 { - // 32 + (8 * 4) + (4 * 1) + (2 * 2) + (33 * 6) - var length uint32 = 270 // base length - - // Upfront shutdown script max length. - length += 2 + deliveryAddressMaxSize - - return length + return MaxMsgBody } diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index 0c8cf475..f9c48d38 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -376,6 +376,15 @@ func TestLightningWireProtocol(t *testing.T) { req.UpfrontShutdownScript = []byte{} } + // 1/2 chance how having more TLV data after the + // shutdown script. + if r.Intn(2) == 0 { + // TLV type 1 of length 2. + req.ExtraData = []byte{1, 2, 0xff, 0xff} + } else { + req.ExtraData = []byte{} + } + v[0] = reflect.ValueOf(req) }, MsgAcceptChannel: func(v []reflect.Value, r *rand.Rand) { @@ -436,6 +445,14 @@ func TestLightningWireProtocol(t *testing.T) { } else { req.UpfrontShutdownScript = []byte{} } + // 1/2 chance how having more TLV data after the + // shutdown script. + if r.Intn(2) == 0 { + // TLV type 1 of length 2. + req.ExtraData = []byte{1, 2, 0xff, 0xff} + } else { + req.ExtraData = []byte{} + } v[0] = reflect.ValueOf(req) }, diff --git a/lnwire/open_channel.go b/lnwire/open_channel.go index a165ef75..70dbe790 100644 --- a/lnwire/open_channel.go +++ b/lnwire/open_channel.go @@ -128,6 +128,17 @@ type OpenChannel struct { // and has a length prefix, so a zero will be written if it is not set // and its length followed by the script will be written if it is set. UpfrontShutdownScript DeliveryAddress + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + // + // NOTE: Since the upfront shutdown script MUST be present (though can + // be zero-length) if any TLV data is available, the script will be + // extracted and removed from this blob when decoding. ExtraData will + // contain all TLV records _except_ the DeliveryAddress record in that + // case. + ExtraData ExtraOpaqueData } // A compile time check to ensure OpenChannel implements the lnwire.Message @@ -140,6 +151,15 @@ var _ Message = (*OpenChannel)(nil) // // This is part of the lnwire.Message interface. func (o *OpenChannel) Encode(w io.Writer, pver uint32) error { + // Since the upfront script is encoded as a TLV record, concatenate it + // with the ExtraData, and write them as one. + tlvRecords, err := packShutdownScript( + o.UpfrontShutdownScript, o.ExtraData, + ) + if err != nil { + return err + } + return WriteElements(w, o.ChainHash[:], o.PendingChannelID[:], @@ -159,7 +179,7 @@ func (o *OpenChannel) Encode(w io.Writer, pver uint32) error { o.HtlcPoint, o.FirstCommitmentPoint, o.ChannelFlags, - o.UpfrontShutdownScript, + tlvRecords, ) } @@ -169,7 +189,8 @@ func (o *OpenChannel) Encode(w io.Writer, pver uint32) error { // // This is part of the lnwire.Message interface. func (o *OpenChannel) Decode(r io.Reader, pver uint32) error { - if err := ReadElements(r, + // Read all the mandatory fields in the open message. + err := ReadElements(r, o.ChainHash[:], o.PendingChannelID[:], &o.FundingAmount, @@ -188,14 +209,23 @@ func (o *OpenChannel) Decode(r io.Reader, pver uint32) error { &o.HtlcPoint, &o.FirstCommitmentPoint, &o.ChannelFlags, - ); err != nil { + ) + if err != nil { return err } - // Check for the optional upfront shutdown script field. If it is not there, - // silence the EOF error. - err := ReadElement(r, &o.UpfrontShutdownScript) - if err != nil && err != io.EOF { + // For backwards compatibility, the optional extra data blob for + // OpenChannel must contain an entry for the upfront shutdown script. + // We'll read it out and attempt to parse it. + var tlvRecords ExtraOpaqueData + if err := ReadElements(r, &tlvRecords); err != nil { + return err + } + + o.UpfrontShutdownScript, o.ExtraData, err = parseShutdownScript( + tlvRecords, + ) + if err != nil { return err } @@ -215,11 +245,5 @@ func (o *OpenChannel) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (o *OpenChannel) MaxPayloadLength(uint32) uint32 { - // (32 * 2) + (8 * 6) + (4 * 1) + (2 * 2) + (33 * 6) + 1 - var length uint32 = 319 // base length - - // Upfront shutdown script max length. - length += 2 + deliveryAddressMaxSize - - return length + return MaxMsgBody } diff --git a/lnwire/shutdown.go b/lnwire/shutdown.go index e27681e4..8def329c 100644 --- a/lnwire/shutdown.go +++ b/lnwire/shutdown.go @@ -22,20 +22,6 @@ type Shutdown struct { ExtraData ExtraOpaqueData } -// DeliveryAddress is used to communicate the address to which funds from a -// closed channel should be sent. The address can be a p2wsh, p2pkh, p2sh or -// p2wpkh. -type DeliveryAddress []byte - -// deliveryAddressMaxSize is the maximum expected size in bytes of a -// DeliveryAddress based on the types of scripts we know. -// Following are the known scripts and their sizes in bytes. -// - pay to witness script hash: 34 -// - pay to pubkey hash: 25 -// - pay to script hash: 22 -// - pay to witness pubkey hash: 22. -const deliveryAddressMaxSize = 34 - // NewShutdown creates a new Shutdown message. func NewShutdown(cid ChannelID, addr DeliveryAddress) *Shutdown { return &Shutdown{ diff --git a/lnwire/typed_delivery_addr.go b/lnwire/typed_delivery_addr.go new file mode 100644 index 00000000..9ad53b1a --- /dev/null +++ b/lnwire/typed_delivery_addr.go @@ -0,0 +1,41 @@ +package lnwire + +import ( + "github.com/lightningnetwork/lnd/tlv" +) + +const ( + // DeliveryAddrType is the TLV record type for delivery addreses within + // the name space of the OpenChannel and AcceptChannel messages. + DeliveryAddrType = 0 + + // deliveryAddressMaxSize is the maximum expected size in bytes of a + // DeliveryAddress based on the types of scripts we know. + // Following are the known scripts and their sizes in bytes. + // - pay to witness script hash: 34 + // - pay to pubkey hash: 25 + // - pay to script hash: 22 + // - pay to witness pubkey hash: 22. + deliveryAddressMaxSize = 34 +) + +// DeliveryAddress is used to communicate the address to which funds from a +// closed channel should be sent. The address can be a p2wsh, p2pkh, p2sh or +// p2wpkh. +type DeliveryAddress []byte + +// NewRecord returns a TLV record that can be used to encode the delivery +// address within the ExtraData TLV stream. This was intorudced in order to +// allow the OpenChannel/AcceptChannel messages to properly be extended with +// TLV types. +func (d *DeliveryAddress) NewRecord() tlv.Record { + addrBytes := (*[]byte)(d) + + return tlv.MakeDynamicRecord( + DeliveryAddrType, addrBytes, + func() uint64 { + return uint64(len(*addrBytes)) + }, + tlv.EVarBytes, tlv.DVarBytes, + ) +} diff --git a/lnwire/typed_delivery_addr_test.go b/lnwire/typed_delivery_addr_test.go new file mode 100644 index 00000000..d5d9c703 --- /dev/null +++ b/lnwire/typed_delivery_addr_test.go @@ -0,0 +1,37 @@ +package lnwire + +import ( + "bytes" + "testing" +) + +// TestDeliveryAddressEncodeDecode tests that we're able to properly +// encode and decode delivery addresses within TLV streams. +func TestDeliveryAddressEncodeDecode(t *testing.T) { + t.Parallel() + + addr := DeliveryAddress( + bytes.Repeat([]byte("a"), deliveryAddressMaxSize), + ) + + var extraData ExtraOpaqueData + err := extraData.PackRecords(addr.NewRecord()) + if err != nil { + t.Fatal(err) + } + + var addr2 DeliveryAddress + tlvs, err := extraData.ExtractRecords(addr2.NewRecord()) + if err != nil { + t.Fatal(err) + } + + if _, ok := tlvs[DeliveryAddrType]; !ok { + t.Fatalf("DeliveryAddrType not found in records") + } + + if !bytes.Equal(addr, addr2) { + t.Fatalf("addr mismatch: expected %x, got %x", addr[:], + addr2[:]) + } +}