diff --git a/.golangci.yml b/.golangci.yml index ee9246b2..5ffd56ac 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -8,6 +8,7 @@ run: skip-dirs: - channeldb/migration_01_to_11 + - channeldb/migration/lnwire21 build-tags: - autopilotrpc diff --git a/channeldb/channel.go b/channeldb/channel.go index 468847c7..d36ded21 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -1791,14 +1791,19 @@ type LogUpdate struct { UpdateMsg lnwire.Message } -// Encode writes a log update to the provided io.Writer. -func (l *LogUpdate) Encode(w io.Writer) error { +// serializeLogUpdate writes a log update to the provided io.Writer. +func serializeLogUpdate(w io.Writer, l *LogUpdate) error { return WriteElements(w, l.LogIndex, l.UpdateMsg) } -// Decode reads a log update from the provided io.Reader. -func (l *LogUpdate) Decode(r io.Reader) error { - return ReadElements(r, &l.LogIndex, &l.UpdateMsg) +// deserializeLogUpdate reads a log update from the provided io.Reader. +func deserializeLogUpdate(r io.Reader) (*LogUpdate, error) { + l := &LogUpdate{} + if err := ReadElements(r, &l.LogIndex, &l.UpdateMsg); err != nil { + return nil, err + } + + return l, nil } // CircuitKey is used by a channel to uniquely identify the HTLCs it receives @@ -1960,12 +1965,12 @@ func deserializeLogUpdates(r io.Reader) ([]LogUpdate, error) { return logUpdates, nil } -func serializeCommitDiff(w io.Writer, diff *CommitDiff) error { +func serializeCommitDiff(w io.Writer, diff *CommitDiff) error { // nolint: dupl if err := serializeChanCommit(w, &diff.Commitment); err != nil { return err } - if err := diff.CommitSig.Encode(w, 0); err != nil { + if err := WriteElements(w, diff.CommitSig); err != nil { return err } @@ -2011,10 +2016,16 @@ func deserializeCommitDiff(r io.Reader) (*CommitDiff, error) { return nil, err } - d.CommitSig = &lnwire.CommitSig{} - if err := d.CommitSig.Decode(r, 0); err != nil { + var msg lnwire.Message + if err := ReadElements(r, &msg); err != nil { return nil, err } + commitSig, ok := msg.(*lnwire.CommitSig) + if !ok { + return nil, fmt.Errorf("expected lnwire.CommitSig, instead "+ + "read: %T", msg) + } + d.CommitSig = commitSig d.LogUpdates, err = deserializeLogUpdates(r) if err != nil { diff --git a/channeldb/channel_test.go b/channeldb/channel_test.go index 5d203889..43747372 100644 --- a/channeldb/channel_test.go +++ b/channeldb/channel_test.go @@ -607,7 +607,8 @@ func TestChannelStateTransition(t *testing.T) { { LogIndex: 2, UpdateMsg: &lnwire.UpdateAddHTLC{ - ChanID: lnwire.ChannelID{1, 2, 3}, + ChanID: lnwire.ChannelID{1, 2, 3}, + ExtraData: make([]byte, 0), }, }, } @@ -628,7 +629,9 @@ func TestChannelStateTransition(t *testing.T) { if !reflect.DeepEqual( dbUnsignedAckedUpdates[0], unsignedAckedUpdates[0], ) { - t.Fatalf("unexpected update") + t.Fatalf("unexpected update: expected %v, got %v", + spew.Sdump(unsignedAckedUpdates[0]), + spew.Sdump(dbUnsignedAckedUpdates)) } // The balances, new update, the HTLCs and the changes to the fake @@ -670,22 +673,25 @@ func TestChannelStateTransition(t *testing.T) { wireSig, wireSig, }, + ExtraData: make([]byte, 0), }, LogUpdates: []LogUpdate{ { LogIndex: 1, UpdateMsg: &lnwire.UpdateAddHTLC{ - ID: 1, - Amount: lnwire.NewMSatFromSatoshis(100), - Expiry: 25, + ID: 1, + Amount: lnwire.NewMSatFromSatoshis(100), + Expiry: 25, + ExtraData: make([]byte, 0), }, }, { LogIndex: 2, UpdateMsg: &lnwire.UpdateAddHTLC{ - ID: 2, - Amount: lnwire.NewMSatFromSatoshis(200), - Expiry: 50, + ID: 2, + Amount: lnwire.NewMSatFromSatoshis(200), + Expiry: 50, + ExtraData: make([]byte, 0), }, }, }, diff --git a/channeldb/codec.go b/channeldb/codec.go index f6903175..424f7c6e 100644 --- a/channeldb/codec.go +++ b/channeldb/codec.go @@ -1,6 +1,7 @@ package channeldb import ( + "bytes" "encoding/binary" "fmt" "io" @@ -178,7 +179,17 @@ func WriteElement(w io.Writer, element interface{}) error { } case lnwire.Message: - if _, err := lnwire.WriteMessage(w, e, 0); err != nil { + var msgBuf bytes.Buffer + if _, err := lnwire.WriteMessage(&msgBuf, e, 0); err != nil { + return err + } + + msgLen := uint16(len(msgBuf.Bytes())) + if err := WriteElements(w, msgLen); err != nil { + return err + } + + if _, err := w.Write(msgBuf.Bytes()); err != nil { return err } @@ -394,7 +405,13 @@ func ReadElement(r io.Reader, element interface{}) error { *e = bytes case *lnwire.Message: - msg, err := lnwire.ReadMessage(r, 0) + var msgLen uint16 + if err := ReadElement(r, &msgLen); err != nil { + return err + } + + msgReader := io.LimitReader(r, int64(msgLen)) + msg, err := lnwire.ReadMessage(msgReader, 0) if err != nil { return err } diff --git a/channeldb/db.go b/channeldb/db.go index d2618e10..76d0c37a 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -18,6 +18,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb/migration13" "github.com/lightningnetwork/lnd/channeldb/migration16" "github.com/lightningnetwork/lnd/channeldb/migration20" + "github.com/lightningnetwork/lnd/channeldb/migration21" "github.com/lightningnetwork/lnd/channeldb/migration_01_to_11" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/lnwire" @@ -182,6 +183,12 @@ var ( number: 20, migration: migration20.MigrateOutpointIndex, }, + { + // Migrate to length prefixed wire messages everywhere + // in the database. + number: 21, + migration: migration21.MigrateDatabaseWireMessages, + }, } // Big endian is the preferred byte order, due to cursor scans over diff --git a/channeldb/forwarding_package.go b/channeldb/forwarding_package.go index dced6e95..fc080b0c 100644 --- a/channeldb/forwarding_package.go +++ b/channeldb/forwarding_package.go @@ -420,7 +420,7 @@ func NewChannelPackager(source lnwire.ShortChannelID) *ChannelPackager { } // AddFwdPkg writes a newly locked in forwarding package to disk. -func (*ChannelPackager) AddFwdPkg(tx kvdb.RwTx, fwdPkg *FwdPkg) error { +func (*ChannelPackager) AddFwdPkg(tx kvdb.RwTx, fwdPkg *FwdPkg) error { // nolint: dupl fwdPkgBkt, err := tx.CreateTopLevelBucket(fwdPackagesKey) if err != nil { return err @@ -487,7 +487,7 @@ func (*ChannelPackager) AddFwdPkg(tx kvdb.RwTx, fwdPkg *FwdPkg) error { // putLogUpdate writes an htlc to the provided `bkt`, using `index` as the key. func putLogUpdate(bkt kvdb.RwBucket, idx uint16, htlc *LogUpdate) error { var b bytes.Buffer - if err := htlc.Encode(&b); err != nil { + if err := serializeLogUpdate(&b, htlc); err != nil { return err } @@ -541,7 +541,7 @@ func loadChannelFwdPkgs(tx kvdb.RTx, source lnwire.ShortChannelID) ([]*FwdPkg, e return fwdPkgs, nil } -// loadFwPkg reads the packager's fwd pkg at a given height, and determines the +// loadFwdPkg reads the packager's fwd pkg at a given height, and determines the // appropriate FwdState. func loadFwdPkg(fwdPkgBkt kvdb.RBucket, source lnwire.ShortChannelID, height uint64) (*FwdPkg, error) { @@ -652,12 +652,12 @@ func loadFwdPkg(fwdPkgBkt kvdb.RBucket, source lnwire.ShortChannelID, func loadHtlcs(bkt kvdb.RBucket) ([]LogUpdate, error) { var htlcs []LogUpdate if err := bkt.ForEach(func(_, v []byte) error { - var htlc LogUpdate - if err := htlc.Decode(bytes.NewReader(v)); err != nil { + htlc, err := deserializeLogUpdate(bytes.NewReader(v)) + if err != nil { return err } - htlcs = append(htlcs, htlc) + htlcs = append(htlcs, *htlc) return nil }); err != nil { diff --git a/channeldb/migration/lnwire21/accept_channel.go b/channeldb/migration/lnwire21/accept_channel.go new file mode 100644 index 00000000..da9daa69 --- /dev/null +++ b/channeldb/migration/lnwire21/accept_channel.go @@ -0,0 +1,182 @@ +package lnwire + +import ( + "io" + + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcutil" +) + +// AcceptChannel is the message Bob sends to Alice after she initiates the +// single funder channel workflow via an AcceptChannel message. Once Alice +// receives Bob's response, then she has all the items necessary to construct +// the funding transaction, and both commitment transactions. +type AcceptChannel struct { + // PendingChannelID serves to uniquely identify the future channel + // created by the initiated single funder workflow. + PendingChannelID [32]byte + + // DustLimit is the specific dust limit the sender of this message + // would like enforced on their version of the commitment transaction. + // Any output below this value will be "trimmed" from the commitment + // transaction, with the amount of the HTLC going to dust. + DustLimit btcutil.Amount + + // MaxValueInFlight represents the maximum amount of coins that can be + // pending within the channel at any given time. If the amount of funds + // in limbo exceeds this amount, then the channel will be failed. + MaxValueInFlight MilliSatoshi + + // ChannelReserve is the amount of BTC that the receiving party MUST + // maintain a balance above at all times. This is a safety mechanism to + // ensure that both sides always have skin in the game during the + // channel's lifetime. + ChannelReserve btcutil.Amount + + // HtlcMinimum is the smallest HTLC that the sender of this message + // will accept. + HtlcMinimum MilliSatoshi + + // MinAcceptDepth is the minimum depth that the initiator of the + // channel should wait before considering the channel open. + MinAcceptDepth uint32 + + // CsvDelay is the number of blocks to use for the relative time lock + // in the pay-to-self output of both commitment transactions. + CsvDelay uint16 + + // MaxAcceptedHTLCs is the total number of incoming HTLC's that the + // sender of this channel will accept. + // + // TODO(roasbeef): acks the initiator's, same with max in flight? + MaxAcceptedHTLCs uint16 + + // FundingKey is the key that should be used on behalf of the sender + // within the 2-of-2 multi-sig output that it contained within the + // funding transaction. + FundingKey *btcec.PublicKey + + // RevocationPoint is the base revocation point for the sending party. + // Any commitment transaction belonging to the receiver of this message + // should use this key and their per-commitment point to derive the + // revocation key for the commitment transaction. + RevocationPoint *btcec.PublicKey + + // PaymentPoint is the base payment point for the sending party. This + // key should be combined with the per commitment point for a + // particular commitment state in order to create the key that should + // be used in any output that pays directly to the sending party, and + // also within the HTLC covenant transactions. + PaymentPoint *btcec.PublicKey + + // DelayedPaymentPoint is the delay point for the sending party. This + // key should be combined with the per commitment point to derive the + // keys that are used in outputs of the sender's commitment transaction + // where they claim funds. + DelayedPaymentPoint *btcec.PublicKey + + // HtlcPoint is the base point used to derive the set of keys for this + // party that will be used within the HTLC public key scripts. This + // value is combined with the receiver's revocation base point in order + // to derive the keys that are used within HTLC scripts. + HtlcPoint *btcec.PublicKey + + // FirstCommitmentPoint is the first commitment point for the sending + // party. This value should be combined with the receiver's revocation + // base point in order to derive the revocation keys that are placed + // within the commitment transaction of the sender. + FirstCommitmentPoint *btcec.PublicKey + + // UpfrontShutdownScript is the script to which the channel funds should + // be paid when mutually closing the channel. This field is optional, and + // and has a length prefix, so a zero will be written if it is not set + // and its length followed by the script will be written if it is set. + UpfrontShutdownScript DeliveryAddress +} + +// A compile time check to ensure AcceptChannel implements the lnwire.Message +// interface. +var _ Message = (*AcceptChannel)(nil) + +// Encode serializes the target AcceptChannel into the passed io.Writer +// implementation. Serialization will observe the rules defined by the passed +// protocol version. +// +// This is part of the lnwire.Message interface. +func (a *AcceptChannel) Encode(w io.Writer, pver uint32) error { + return WriteElements(w, + a.PendingChannelID[:], + a.DustLimit, + a.MaxValueInFlight, + a.ChannelReserve, + a.HtlcMinimum, + a.MinAcceptDepth, + a.CsvDelay, + a.MaxAcceptedHTLCs, + a.FundingKey, + a.RevocationPoint, + a.PaymentPoint, + a.DelayedPaymentPoint, + a.HtlcPoint, + a.FirstCommitmentPoint, + a.UpfrontShutdownScript, + ) +} + +// Decode deserializes the serialized AcceptChannel stored in the passed +// io.Reader into the target AcceptChannel using the deserialization rules +// defined by the passed protocol version. +// +// This is part of the lnwire.Message interface. +func (a *AcceptChannel) Decode(r io.Reader, pver uint32) error { + // Read all the mandatory fields in the accept message. + err := ReadElements(r, + a.PendingChannelID[:], + &a.DustLimit, + &a.MaxValueInFlight, + &a.ChannelReserve, + &a.HtlcMinimum, + &a.MinAcceptDepth, + &a.CsvDelay, + &a.MaxAcceptedHTLCs, + &a.FundingKey, + &a.RevocationPoint, + &a.PaymentPoint, + &a.DelayedPaymentPoint, + &a.HtlcPoint, + &a.FirstCommitmentPoint, + ) + if err != nil { + return err + } + + // Check for the optional upfront shutdown script field. If it is not there, + // silence the EOF error. + err = ReadElement(r, &a.UpfrontShutdownScript) + if err != nil && err != io.EOF { + return err + } + return nil +} + +// MsgType returns the MessageType code which uniquely identifies this message +// as an AcceptChannel on the wire. +// +// This is part of the lnwire.Message interface. +func (a *AcceptChannel) MsgType() MessageType { + return MsgAcceptChannel +} + +// MaxPayloadLength returns the maximum allowed payload length for a +// AcceptChannel message. +// +// This is part of the lnwire.Message interface. +func (a *AcceptChannel) MaxPayloadLength(uint32) uint32 { + // 32 + (8 * 4) + (4 * 1) + (2 * 2) + (33 * 6) + var length uint32 = 270 // base length + + // Upfront shutdown script max length. + length += 2 + deliveryAddressMaxSize + + return length +} diff --git a/channeldb/migration/lnwire21/announcement_signatures.go b/channeldb/migration/lnwire21/announcement_signatures.go new file mode 100644 index 00000000..639704de --- /dev/null +++ b/channeldb/migration/lnwire21/announcement_signatures.go @@ -0,0 +1,108 @@ +package lnwire + +import ( + "io" + "io/ioutil" +) + +// AnnounceSignatures is a direct message between two endpoints of a +// channel and serves as an opt-in mechanism to allow the announcement of +// the channel to the rest of the network. It contains the necessary +// signatures by the sender to construct the channel announcement message. +type AnnounceSignatures struct { + // ChannelID is the unique description of the funding transaction. + // Channel id is better for users and debugging and short channel id is + // used for quick test on existence of the particular utxo inside the + // block chain, because it contains information about block. + ChannelID ChannelID + + // ShortChannelID is the unique description of the funding + // transaction. It is constructed with the most significant 3 bytes + // as the block height, the next 3 bytes indicating the transaction + // index within the block, and the least significant two bytes + // indicating the output index which pays to the channel. + ShortChannelID ShortChannelID + + // NodeSignature is the signature which contains the signed announce + // channel message, by this signature we proof that we possess of the + // node pub key and creating the reference node_key -> bitcoin_key. + NodeSignature Sig + + // BitcoinSignature is the signature which contains the signed node + // public key, by this signature we proof that we possess of the + // bitcoin key and and creating the reverse reference bitcoin_key -> + // node_key. + BitcoinSignature Sig + + // ExtraOpaqueData is the set of data that was appended to this + // message, some of which we may not actually know how to iterate or + // parse. By holding onto this data, we ensure that we're able to + // properly validate the set of signatures that cover these new fields, + // and ensure we're able to make upgrades to the network in a forwards + // compatible manner. + ExtraOpaqueData []byte +} + +// A compile time check to ensure AnnounceSignatures implements the +// lnwire.Message interface. +var _ Message = (*AnnounceSignatures)(nil) + +// Decode deserializes a serialized AnnounceSignatures stored in the passed +// io.Reader observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (a *AnnounceSignatures) Decode(r io.Reader, pver uint32) error { + err := ReadElements(r, + &a.ChannelID, + &a.ShortChannelID, + &a.NodeSignature, + &a.BitcoinSignature, + ) + if err != nil { + return err + } + + // Now that we've read out all the fields that we explicitly know of, + // we'll collect the remainder into the ExtraOpaqueData field. If there + // aren't any bytes, then we'll snip off the slice to avoid carrying + // around excess capacity. + a.ExtraOpaqueData, err = ioutil.ReadAll(r) + if err != nil { + return err + } + if len(a.ExtraOpaqueData) == 0 { + a.ExtraOpaqueData = nil + } + + return nil +} + +// Encode serializes the target AnnounceSignatures into the passed io.Writer +// observing the protocol version specified. +// +// This is part of the lnwire.Message interface. +func (a *AnnounceSignatures) Encode(w io.Writer, pver uint32) error { + return WriteElements(w, + a.ChannelID, + a.ShortChannelID, + a.NodeSignature, + a.BitcoinSignature, + a.ExtraOpaqueData, + ) +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (a *AnnounceSignatures) MsgType() MessageType { + return MsgAnnounceSignatures +} + +// MaxPayloadLength returns the maximum allowed payload size for this message +// observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (a *AnnounceSignatures) MaxPayloadLength(pver uint32) uint32 { + return 65533 +} diff --git a/channeldb/migration/lnwire21/channel_announcement.go b/channeldb/migration/lnwire21/channel_announcement.go new file mode 100644 index 00000000..46efeed8 --- /dev/null +++ b/channeldb/migration/lnwire21/channel_announcement.go @@ -0,0 +1,160 @@ +package lnwire + +import ( + "bytes" + "io" + "io/ioutil" + + "github.com/btcsuite/btcd/chaincfg/chainhash" +) + +// ChannelAnnouncement message is used to announce the existence of a channel +// between two peers in the overlay, which is propagated by the discovery +// service over broadcast handler. +type ChannelAnnouncement struct { + // This signatures are used by nodes in order to create cross + // references between node's channel and node. Requiring both nodes + // to sign indicates they are both willing to route other payments via + // this node. + NodeSig1 Sig + NodeSig2 Sig + + // This signatures are used by nodes in order to create cross + // references between node's channel and node. Requiring the bitcoin + // signatures proves they control the channel. + BitcoinSig1 Sig + BitcoinSig2 Sig + + // Features is the feature vector that encodes the features supported + // by the target node. This field can be used to signal the type of the + // channel, or modifications to the fields that would normally follow + // this vector. + Features *RawFeatureVector + + // ChainHash denotes the target chain that this channel was opened + // within. This value should be the genesis hash of the target chain. + ChainHash chainhash.Hash + + // ShortChannelID is the unique description of the funding transaction, + // or where exactly it's located within the target blockchain. + ShortChannelID ShortChannelID + + // The public keys of the two nodes who are operating the channel, such + // that is NodeID1 the numerically-lesser than NodeID2 (ascending + // numerical order). + NodeID1 [33]byte + NodeID2 [33]byte + + // Public keys which corresponds to the keys which was declared in + // multisig funding transaction output. + BitcoinKey1 [33]byte + BitcoinKey2 [33]byte + + // ExtraOpaqueData is the set of data that was appended to this + // message, some of which we may not actually know how to iterate or + // parse. By holding onto this data, we ensure that we're able to + // properly validate the set of signatures that cover these new fields, + // and ensure we're able to make upgrades to the network in a forwards + // compatible manner. + ExtraOpaqueData []byte +} + +// A compile time check to ensure ChannelAnnouncement implements the +// lnwire.Message interface. +var _ Message = (*ChannelAnnouncement)(nil) + +// Decode deserializes a serialized ChannelAnnouncement stored in the passed +// io.Reader observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (a *ChannelAnnouncement) Decode(r io.Reader, pver uint32) error { + err := ReadElements(r, + &a.NodeSig1, + &a.NodeSig2, + &a.BitcoinSig1, + &a.BitcoinSig2, + &a.Features, + a.ChainHash[:], + &a.ShortChannelID, + &a.NodeID1, + &a.NodeID2, + &a.BitcoinKey1, + &a.BitcoinKey2, + ) + if err != nil { + return err + } + + // Now that we've read out all the fields that we explicitly know of, + // we'll collect the remainder into the ExtraOpaqueData field. If there + // aren't any bytes, then we'll snip off the slice to avoid carrying + // around excess capacity. + a.ExtraOpaqueData, err = ioutil.ReadAll(r) + if err != nil { + return err + } + if len(a.ExtraOpaqueData) == 0 { + a.ExtraOpaqueData = nil + } + + return nil +} + +// Encode serializes the target ChannelAnnouncement into the passed io.Writer +// observing the protocol version specified. +// +// This is part of the lnwire.Message interface. +func (a *ChannelAnnouncement) Encode(w io.Writer, pver uint32) error { + return WriteElements(w, + a.NodeSig1, + a.NodeSig2, + a.BitcoinSig1, + a.BitcoinSig2, + a.Features, + a.ChainHash[:], + a.ShortChannelID, + a.NodeID1, + a.NodeID2, + a.BitcoinKey1, + a.BitcoinKey2, + a.ExtraOpaqueData, + ) +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (a *ChannelAnnouncement) MsgType() MessageType { + return MsgChannelAnnouncement +} + +// MaxPayloadLength returns the maximum allowed payload size for this message +// observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (a *ChannelAnnouncement) MaxPayloadLength(pver uint32) uint32 { + return 65533 +} + +// DataToSign is used to retrieve part of the announcement message which should +// be signed. +func (a *ChannelAnnouncement) DataToSign() ([]byte, error) { + // We should not include the signatures itself. + var w bytes.Buffer + err := WriteElements(&w, + a.Features, + a.ChainHash[:], + a.ShortChannelID, + a.NodeID1, + a.NodeID2, + a.BitcoinKey1, + a.BitcoinKey2, + a.ExtraOpaqueData, + ) + if err != nil { + return nil, err + } + + return w.Bytes(), nil +} diff --git a/channeldb/migration/lnwire21/channel_id.go b/channeldb/migration/lnwire21/channel_id.go new file mode 100644 index 00000000..0a9e0822 --- /dev/null +++ b/channeldb/migration/lnwire21/channel_id.go @@ -0,0 +1,91 @@ +package lnwire + +import ( + "encoding/binary" + "encoding/hex" + "math" + + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" +) + +const ( + // MaxFundingTxOutputs is the maximum number of allowed outputs on a + // funding transaction within the protocol. This is due to the fact + // that we use 2-bytes to encode the index within the funding output + // during the funding workflow. Funding transaction with more outputs + // than this are considered invalid within the protocol. + MaxFundingTxOutputs = math.MaxUint16 +) + +// ChannelID is a series of 32-bytes that uniquely identifies all channels +// within the network. The ChannelID is computed using the outpoint of the +// funding transaction (the txid, and output index). Given a funding output the +// ChannelID can be calculated by XOR'ing the big-endian serialization of the +// txid and the big-endian serialization of the output index, truncated to +// 2 bytes. +type ChannelID [32]byte + +// ConnectionWideID is an all-zero ChannelID, which is used to represent a +// message intended for all channels to specific peer. +var ConnectionWideID = ChannelID{} + +// String returns the string representation of the ChannelID. This is just the +// hex string encoding of the ChannelID itself. +func (c ChannelID) String() string { + return hex.EncodeToString(c[:]) +} + +// NewChanIDFromOutPoint converts a target OutPoint into a ChannelID that is +// usable within the network. In order to convert the OutPoint into a ChannelID, +// we XOR the lower 2-bytes of the txid within the OutPoint with the big-endian +// serialization of the Index of the OutPoint, truncated to 2-bytes. +func NewChanIDFromOutPoint(op *wire.OutPoint) ChannelID { + // First we'll copy the txid of the outpoint into our channel ID slice. + var cid ChannelID + copy(cid[:], op.Hash[:]) + + // With the txid copied over, we'll now XOR the lower 2-bytes of the + // partial channelID with big-endian serialization of output index. + xorTxid(&cid, uint16(op.Index)) + + return cid +} + +// xorTxid performs the transformation needed to transform an OutPoint into a +// ChannelID. To do this, we expect the cid parameter to contain the txid +// unaltered and the outputIndex to be the output index +func xorTxid(cid *ChannelID, outputIndex uint16) { + var buf [2]byte + binary.BigEndian.PutUint16(buf[:], outputIndex) + + cid[30] ^= buf[0] + cid[31] ^= buf[1] +} + +// GenPossibleOutPoints generates all the possible outputs given a channel ID. +// In order to generate these possible outpoints, we perform a brute-force +// search through the candidate output index space, performing a reverse +// mapping from channelID back to OutPoint. +func (c *ChannelID) GenPossibleOutPoints() [MaxFundingTxOutputs]wire.OutPoint { + var possiblePoints [MaxFundingTxOutputs]wire.OutPoint + for i := uint16(0); i < MaxFundingTxOutputs; i++ { + cidCopy := *c + xorTxid(&cidCopy, i) + + possiblePoints[i] = wire.OutPoint{ + Hash: chainhash.Hash(cidCopy), + Index: uint32(i), + } + } + + return possiblePoints +} + +// IsChanPoint returns true if the OutPoint passed corresponds to the target +// ChannelID. +func (c ChannelID) IsChanPoint(op *wire.OutPoint) bool { + candidateCid := NewChanIDFromOutPoint(op) + + return candidateCid == c +} diff --git a/channeldb/migration/lnwire21/channel_reestablish.go b/channeldb/migration/lnwire21/channel_reestablish.go new file mode 100644 index 00000000..6fa8f8ac --- /dev/null +++ b/channeldb/migration/lnwire21/channel_reestablish.go @@ -0,0 +1,166 @@ +package lnwire + +import ( + "io" + + "github.com/btcsuite/btcd/btcec" +) + +// ChannelReestablish is a message sent between peers that have an existing +// open channel upon connection reestablishment. This message allows both sides +// to report their local state, and their current knowledge of the state of the +// remote commitment chain. If a deviation is detected and can be recovered +// from, then the necessary messages will be retransmitted. If the level of +// desynchronization is irreconcilable, then the channel will be force closed. +type ChannelReestablish struct { + // ChanID is the channel ID of the channel state we're attempting to + // synchronize with the remote party. + ChanID ChannelID + + // NextLocalCommitHeight is the next local commitment height of the + // sending party. If the height of the sender's commitment chain from + // the receiver's Pov is one less that this number, then the sender + // should re-send the *exact* same proposed commitment. + // + // In other words, the receiver should re-send their last sent + // commitment iff: + // + // * NextLocalCommitHeight == remoteCommitChain.Height + // + // This covers the case of a lost commitment which was sent by the + // sender of this message, but never received by the receiver of this + // message. + NextLocalCommitHeight uint64 + + // RemoteCommitTailHeight is the height of the receiving party's + // unrevoked commitment from the PoV of the sender of this message. If + // the height of the receiver's commitment is *one more* than this + // value, then their prior RevokeAndAck message should be + // retransmitted. + // + // In other words, the receiver should re-send their last sent + // RevokeAndAck message iff: + // + // * localCommitChain.tail().Height == RemoteCommitTailHeight + 1 + // + // This covers the case of a lost revocation, wherein the receiver of + // the message sent a revocation for a prior state, but the sender of + // the message never fully processed it. + RemoteCommitTailHeight uint64 + + // LastRemoteCommitSecret is the last commitment secret that the + // receiving node has sent to the sending party. This will be the + // secret of the last revoked commitment transaction. Including this + // provides proof that the sending node at least knows of this state, + // as they couldn't have produced it if it wasn't sent, as the value + // can be authenticated by querying the shachain or the receiving + // party. + LastRemoteCommitSecret [32]byte + + // LocalUnrevokedCommitPoint is the commitment point used in the + // current un-revoked commitment transaction of the sending party. + LocalUnrevokedCommitPoint *btcec.PublicKey +} + +// A compile time check to ensure ChannelReestablish implements the +// lnwire.Message interface. +var _ Message = (*ChannelReestablish)(nil) + +// Encode serializes the target ChannelReestablish into the passed io.Writer +// observing the protocol version specified. +// +// This is part of the lnwire.Message interface. +func (a *ChannelReestablish) Encode(w io.Writer, pver uint32) error { + err := WriteElements(w, + a.ChanID, + a.NextLocalCommitHeight, + a.RemoteCommitTailHeight, + ) + if err != nil { + return err + } + + // If the commit point wasn't sent, then we won't write out any of the + // remaining fields as they're optional. + if a.LocalUnrevokedCommitPoint == nil { + return nil + } + + // Otherwise, we'll write out the remaining elements. + return WriteElements(w, a.LastRemoteCommitSecret[:], + a.LocalUnrevokedCommitPoint) +} + +// Decode deserializes a serialized ChannelReestablish stored in the passed +// io.Reader observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (a *ChannelReestablish) Decode(r io.Reader, pver uint32) error { + err := ReadElements(r, + &a.ChanID, + &a.NextLocalCommitHeight, + &a.RemoteCommitTailHeight, + ) + if err != nil { + return err + } + + // This message has currently defined optional fields. As a result, + // we'll only proceed if there's still bytes remaining within the + // reader. + // + // We'll manually parse out the optional fields in order to be able to + // still utilize the io.Reader interface. + + // We'll first attempt to read the optional commit secret, if we're at + // the EOF, then this means the field wasn't included so we can exit + // early. + var buf [32]byte + _, err = io.ReadFull(r, buf[:32]) + if err == io.EOF { + return nil + } else if err != nil { + return err + } + + // If the field is present, then we'll copy it over and proceed. + copy(a.LastRemoteCommitSecret[:], buf[:]) + + // We'll conclude by parsing out the commitment point. We don't check + // the error in this case, as it has included the commit secret, then + // they MUST also include the commit point. + return ReadElement(r, &a.LocalUnrevokedCommitPoint) +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (a *ChannelReestablish) MsgType() MessageType { + return MsgChannelReestablish +} + +// MaxPayloadLength returns the maximum allowed payload size for this message +// observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (a *ChannelReestablish) MaxPayloadLength(pver uint32) uint32 { + var length uint32 + + // ChanID - 32 bytes + length += 32 + + // NextLocalCommitHeight - 8 bytes + length += 8 + + // RemoteCommitTailHeight - 8 bytes + length += 8 + + // LastRemoteCommitSecret - 32 bytes + length += 32 + + // LocalUnrevokedCommitPoint - 33 bytes + length += 33 + + return length +} diff --git a/channeldb/migration/lnwire21/channel_update.go b/channeldb/migration/lnwire21/channel_update.go new file mode 100644 index 00000000..037f3d55 --- /dev/null +++ b/channeldb/migration/lnwire21/channel_update.go @@ -0,0 +1,258 @@ +package lnwire + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + + "github.com/btcsuite/btcd/chaincfg/chainhash" +) + +// ChanUpdateMsgFlags is a bitfield that signals whether optional fields are +// present in the ChannelUpdate. +type ChanUpdateMsgFlags uint8 + +const ( + // ChanUpdateOptionMaxHtlc is a bit that indicates whether the + // optional htlc_maximum_msat field is present in this ChannelUpdate. + ChanUpdateOptionMaxHtlc ChanUpdateMsgFlags = 1 << iota +) + +// String returns the bitfield flags as a string. +func (c ChanUpdateMsgFlags) String() string { + return fmt.Sprintf("%08b", c) +} + +// HasMaxHtlc returns true if the htlc_maximum_msat option bit is set in the +// message flags. +func (c ChanUpdateMsgFlags) HasMaxHtlc() bool { + return c&ChanUpdateOptionMaxHtlc != 0 +} + +// ChanUpdateChanFlags is a bitfield that signals various options concerning a +// particular channel edge. Each bit is to be examined in order to determine +// how the ChannelUpdate message is to be interpreted. +type ChanUpdateChanFlags uint8 + +const ( + // ChanUpdateDirection indicates the direction of a channel update. If + // this bit is set to 0 if Node1 (the node with the "smaller" Node ID) + // is updating the channel, and to 1 otherwise. + ChanUpdateDirection ChanUpdateChanFlags = 1 << iota + + // ChanUpdateDisabled is a bit that indicates if the channel edge + // selected by the ChanUpdateDirection bit is to be treated as being + // disabled. + ChanUpdateDisabled +) + +// IsDisabled determines whether the channel flags has the disabled bit set. +func (c ChanUpdateChanFlags) IsDisabled() bool { + return c&ChanUpdateDisabled == ChanUpdateDisabled +} + +// String returns the bitfield flags as a string. +func (c ChanUpdateChanFlags) String() string { + return fmt.Sprintf("%08b", c) +} + +// ChannelUpdate message is used after channel has been initially announced. +// Each side independently announces its fees and minimum expiry for HTLCs and +// other parameters. Also this message is used to redeclare initially set +// channel parameters. +type ChannelUpdate struct { + // Signature is used to validate the announced data and prove the + // ownership of node id. + Signature Sig + + // ChainHash denotes the target chain that this channel was opened + // within. This value should be the genesis hash of the target chain. + // Along with the short channel ID, this uniquely identifies the + // channel globally in a blockchain. + ChainHash chainhash.Hash + + // ShortChannelID is the unique description of the funding transaction. + ShortChannelID ShortChannelID + + // Timestamp allows ordering in the case of multiple announcements. We + // should ignore the message if timestamp is not greater than + // the last-received. + Timestamp uint32 + + // MessageFlags is a bitfield that describes whether optional fields + // are present in this update. Currently, the least-significant bit + // must be set to 1 if the optional field MaxHtlc is present. + MessageFlags ChanUpdateMsgFlags + + // ChannelFlags is a bitfield that describes additional meta-data + // concerning how the update is to be interpreted. Currently, the + // least-significant bit must be set to 0 if the creating node + // corresponds to the first node in the previously sent channel + // announcement and 1 otherwise. If the second bit is set, then the + // channel is set to be disabled. + ChannelFlags ChanUpdateChanFlags + + // TimeLockDelta is the minimum number of blocks this node requires to + // be added to the expiry of HTLCs. This is a security parameter + // determined by the node operator. This value represents the required + // gap between the time locks of the incoming and outgoing HTLC's set + // to this node. + TimeLockDelta uint16 + + // HtlcMinimumMsat is the minimum HTLC value which will be accepted. + HtlcMinimumMsat MilliSatoshi + + // BaseFee is the base fee that must be used for incoming HTLC's to + // this particular channel. This value will be tacked onto the required + // for a payment independent of the size of the payment. + BaseFee uint32 + + // FeeRate is the fee rate that will be charged per millionth of a + // satoshi. + FeeRate uint32 + + // HtlcMaximumMsat is the maximum HTLC value which will be accepted. + HtlcMaximumMsat MilliSatoshi + + // ExtraOpaqueData is the set of data that was appended to this + // message, some of which we may not actually know how to iterate or + // parse. By holding onto this data, we ensure that we're able to + // properly validate the set of signatures that cover these new fields, + // and ensure we're able to make upgrades to the network in a forwards + // compatible manner. + ExtraOpaqueData []byte +} + +// A compile time check to ensure ChannelUpdate implements the lnwire.Message +// interface. +var _ Message = (*ChannelUpdate)(nil) + +// Decode deserializes a serialized ChannelUpdate stored in the passed +// io.Reader observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (a *ChannelUpdate) Decode(r io.Reader, pver uint32) error { + err := ReadElements(r, + &a.Signature, + a.ChainHash[:], + &a.ShortChannelID, + &a.Timestamp, + &a.MessageFlags, + &a.ChannelFlags, + &a.TimeLockDelta, + &a.HtlcMinimumMsat, + &a.BaseFee, + &a.FeeRate, + ) + if err != nil { + return err + } + + // Now check whether the max HTLC field is present and read it if so. + if a.MessageFlags.HasMaxHtlc() { + if err := ReadElements(r, &a.HtlcMaximumMsat); err != nil { + return err + } + } + + // Now that we've read out all the fields that we explicitly know of, + // we'll collect the remainder into the ExtraOpaqueData field. If there + // aren't any bytes, then we'll snip off the slice to avoid carrying + // around excess capacity. + a.ExtraOpaqueData, err = ioutil.ReadAll(r) + if err != nil { + return err + } + if len(a.ExtraOpaqueData) == 0 { + a.ExtraOpaqueData = nil + } + + return nil +} + +// Encode serializes the target ChannelUpdate into the passed io.Writer +// observing the protocol version specified. +// +// This is part of the lnwire.Message interface. +func (a *ChannelUpdate) Encode(w io.Writer, pver uint32) error { + err := WriteElements(w, + a.Signature, + a.ChainHash[:], + a.ShortChannelID, + a.Timestamp, + a.MessageFlags, + a.ChannelFlags, + a.TimeLockDelta, + a.HtlcMinimumMsat, + a.BaseFee, + a.FeeRate, + ) + if err != nil { + return err + } + + // Now append optional fields if they are set. Currently, the only + // optional field is max HTLC. + if a.MessageFlags.HasMaxHtlc() { + if err := WriteElements(w, a.HtlcMaximumMsat); err != nil { + return err + } + } + + // Finally, append any extra opaque data. + return WriteElements(w, a.ExtraOpaqueData) +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (a *ChannelUpdate) MsgType() MessageType { + return MsgChannelUpdate +} + +// MaxPayloadLength returns the maximum allowed payload size for this message +// observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (a *ChannelUpdate) MaxPayloadLength(pver uint32) uint32 { + return 65533 +} + +// DataToSign is used to retrieve part of the announcement message which should +// be signed. +func (a *ChannelUpdate) DataToSign() ([]byte, error) { + + // We should not include the signatures itself. + var w bytes.Buffer + err := WriteElements(&w, + a.ChainHash[:], + a.ShortChannelID, + a.Timestamp, + a.MessageFlags, + a.ChannelFlags, + a.TimeLockDelta, + a.HtlcMinimumMsat, + a.BaseFee, + a.FeeRate, + ) + if err != nil { + return nil, err + } + + // Now append optional fields if they are set. Currently, the only + // optional field is max HTLC. + if a.MessageFlags.HasMaxHtlc() { + if err := WriteElements(&w, a.HtlcMaximumMsat); err != nil { + return nil, err + } + } + + // Finally, append any extra opaque data. + if err := WriteElements(&w, a.ExtraOpaqueData); err != nil { + return nil, err + } + + return w.Bytes(), nil +} diff --git a/channeldb/migration/lnwire21/closing_signed.go b/channeldb/migration/lnwire21/closing_signed.go new file mode 100644 index 00000000..91b90646 --- /dev/null +++ b/channeldb/migration/lnwire21/closing_signed.go @@ -0,0 +1,88 @@ +package lnwire + +import ( + "io" + + "github.com/btcsuite/btcutil" +) + +// ClosingSigned is sent by both parties to a channel once the channel is clear +// of HTLCs, and is primarily concerned with negotiating fees for the close +// transaction. Each party provides a signature for a transaction with a fee +// that they believe is fair. The process terminates when both sides agree on +// the same fee, or when one side force closes the channel. +// +// NOTE: The responder is able to send a signature without any additional +// messages as all transactions are assembled observing BIP 69 which defines a +// canonical ordering for input/outputs. Therefore, both sides are able to +// arrive at an identical closure transaction as they know the order of the +// inputs/outputs. +type ClosingSigned struct { + // ChannelID serves to identify which channel is to be closed. + ChannelID ChannelID + + // FeeSatoshis is the total fee in satoshis that the party to the + // channel would like to propose for the close transaction. + FeeSatoshis btcutil.Amount + + // Signature is for the proposed channel close transaction. + Signature Sig +} + +// NewClosingSigned creates a new empty ClosingSigned message. +func NewClosingSigned(cid ChannelID, fs btcutil.Amount, + sig Sig) *ClosingSigned { + + return &ClosingSigned{ + ChannelID: cid, + FeeSatoshis: fs, + Signature: sig, + } +} + +// A compile time check to ensure ClosingSigned implements the lnwire.Message +// interface. +var _ Message = (*ClosingSigned)(nil) + +// Decode deserializes a serialized ClosingSigned message stored in the passed +// io.Reader observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (c *ClosingSigned) Decode(r io.Reader, pver uint32) error { + return ReadElements(r, &c.ChannelID, &c.FeeSatoshis, &c.Signature) +} + +// Encode serializes the target ClosingSigned into the passed io.Writer +// observing the protocol version specified. +// +// This is part of the lnwire.Message interface. +func (c *ClosingSigned) Encode(w io.Writer, pver uint32) error { + return WriteElements(w, c.ChannelID, c.FeeSatoshis, c.Signature) +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (c *ClosingSigned) MsgType() MessageType { + return MsgClosingSigned +} + +// MaxPayloadLength returns the maximum allowed payload size for a +// ClosingSigned complete message observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (c *ClosingSigned) MaxPayloadLength(uint32) uint32 { + var length uint32 + + // ChannelID - 32 bytes + length += 32 + + // FeeSatoshis - 8 bytes + length += 8 + + // Signature - 64 bytes + length += 64 + + return length +} diff --git a/channeldb/migration/lnwire21/commit_sig.go b/channeldb/migration/lnwire21/commit_sig.go new file mode 100644 index 00000000..f15a9738 --- /dev/null +++ b/channeldb/migration/lnwire21/commit_sig.go @@ -0,0 +1,95 @@ +package lnwire + +import ( + "io" +) + +// CommitSig is sent by either side to stage any pending HTLC's in the +// receiver's pending set into a new commitment state. Implicitly, the new +// commitment transaction constructed which has been signed by CommitSig +// includes all HTLC's in the remote node's pending set. A CommitSig message +// may be sent after a series of UpdateAddHTLC/UpdateFulfillHTLC messages in +// order to batch add several HTLC's with a single signature covering all +// implicitly accepted HTLC's. +type CommitSig struct { + // ChanID uniquely identifies to which currently active channel this + // CommitSig applies to. + ChanID ChannelID + + // CommitSig is Alice's signature for Bob's new commitment transaction. + // Alice is able to send this signature without requesting any + // additional data due to the piggybacking of Bob's next revocation + // hash in his prior RevokeAndAck message, as well as the canonical + // ordering used for all inputs/outputs within commitment transactions. + // If initiating a new commitment state, this signature should ONLY + // cover all of the sending party's pending log updates, and the log + // updates of the remote party that have been ACK'd. + CommitSig Sig + + // HtlcSigs is a signature for each relevant HTLC output within the + // created commitment. The order of the signatures is expected to be + // identical to the placement of the HTLC's within the BIP 69 sorted + // commitment transaction. For each outgoing HTLC (from the PoV of the + // sender of this message), a signature for an HTLC timeout transaction + // should be signed, for each incoming HTLC the HTLC timeout + // transaction should be signed. + HtlcSigs []Sig +} + +// NewCommitSig creates a new empty CommitSig message. +func NewCommitSig() *CommitSig { + return &CommitSig{} +} + +// A compile time check to ensure CommitSig implements the lnwire.Message +// interface. +var _ Message = (*CommitSig)(nil) + +// Decode deserializes a serialized CommitSig message stored in the +// passed io.Reader observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (c *CommitSig) Decode(r io.Reader, pver uint32) error { + return ReadElements(r, + &c.ChanID, + &c.CommitSig, + &c.HtlcSigs, + ) +} + +// Encode serializes the target CommitSig into the passed io.Writer +// observing the protocol version specified. +// +// This is part of the lnwire.Message interface. +func (c *CommitSig) Encode(w io.Writer, pver uint32) error { + return WriteElements(w, + c.ChanID, + c.CommitSig, + c.HtlcSigs, + ) +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (c *CommitSig) MsgType() MessageType { + return MsgCommitSig +} + +// MaxPayloadLength returns the maximum allowed payload size for a +// CommitSig complete message observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (c *CommitSig) MaxPayloadLength(uint32) uint32 { + // 32 + 64 + 2 + max_allowed_htlcs + return MaxMessagePayload +} + +// TargetChanID returns the channel id of the link for which this message is +// intended. +// +// NOTE: Part of peer.LinkUpdater interface. +func (c *CommitSig) TargetChanID() ChannelID { + return c.ChanID +} diff --git a/channeldb/migration/lnwire21/error.go b/channeldb/migration/lnwire21/error.go new file mode 100644 index 00000000..19911d1f --- /dev/null +++ b/channeldb/migration/lnwire21/error.go @@ -0,0 +1,139 @@ +package lnwire + +import ( + "fmt" + "io" +) + +// FundingError represents a set of errors that can be encountered and sent +// during the funding workflow. +type FundingError uint8 + +const ( + // ErrMaxPendingChannels is returned by remote peer when the number of + // active pending channels exceeds their maximum policy limit. + ErrMaxPendingChannels FundingError = 1 + + // ErrSynchronizingChain is returned by a remote peer that receives a + // channel update or a funding request while it's still syncing to the + // latest state of the blockchain. + ErrSynchronizingChain FundingError = 2 + + // ErrChanTooLarge is returned by a remote peer that receives a + // FundingOpen request for a channel that is above their current + // soft-limit. + ErrChanTooLarge FundingError = 3 +) + +// String returns a human readable version of the target FundingError. +func (e FundingError) String() string { + switch e { + case ErrMaxPendingChannels: + return "Number of pending channels exceed maximum" + case ErrSynchronizingChain: + return "Synchronizing blockchain" + case ErrChanTooLarge: + return "channel too large" + default: + return "unknown error" + } +} + +// Error returns the human readable version of the target FundingError. +// +// NOTE: Satisfies the Error interface. +func (e FundingError) Error() string { + return e.String() +} + +// ErrorData is a set of bytes associated with a particular sent error. A +// receiving node SHOULD only print out data verbatim if the string is composed +// solely of printable ASCII characters. For reference, the printable character +// set includes byte values 32 through 127 inclusive. +type ErrorData []byte + +// Error represents a generic error bound to an exact channel. The message +// format is purposefully general in order to allow expression of a wide array +// of possible errors. Each Error message is directed at a particular open +// channel referenced by ChannelPoint. +type Error struct { + // ChanID references the active channel in which the error occurred + // within. If the ChanID is all zeros, then this error applies to the + // entire established connection. + ChanID ChannelID + + // Data is the attached error data that describes the exact failure + // which caused the error message to be sent. + Data ErrorData +} + +// NewError creates a new Error message. +func NewError() *Error { + return &Error{} +} + +// A compile time check to ensure Error implements the lnwire.Message +// interface. +var _ Message = (*Error)(nil) + +// Error returns the string representation to Error. +// +// NOTE: Satisfies the error interface. +func (c *Error) Error() string { + errMsg := "non-ascii data" + if isASCII(c.Data) { + errMsg = string(c.Data) + } + + return fmt.Sprintf("chan_id=%v, err=%v", c.ChanID, errMsg) +} + +// Decode deserializes a serialized Error message stored in the passed +// io.Reader observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (c *Error) Decode(r io.Reader, pver uint32) error { + return ReadElements(r, + &c.ChanID, + &c.Data, + ) +} + +// Encode serializes the target Error into the passed io.Writer observing the +// protocol version specified. +// +// This is part of the lnwire.Message interface. +func (c *Error) Encode(w io.Writer, pver uint32) error { + return WriteElements(w, + c.ChanID, + c.Data, + ) +} + +// MsgType returns the integer uniquely identifying an Error message on the +// wire. +// +// This is part of the lnwire.Message interface. +func (c *Error) MsgType() MessageType { + return MsgError +} + +// MaxPayloadLength returns the maximum allowed payload size for an Error +// complete message observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (c *Error) MaxPayloadLength(uint32) uint32 { + // 32 + 2 + 65501 + return MaxMessagePayload +} + +// isASCII is a helper method that checks whether all bytes in `data` would be +// printable ASCII characters if interpreted as a string. +func isASCII(data []byte) bool { + for _, c := range data { + if c < 32 || c > 126 { + return false + } + } + return true +} diff --git a/channeldb/migration/lnwire21/features.go b/channeldb/migration/lnwire21/features.go new file mode 100644 index 00000000..279e72c9 --- /dev/null +++ b/channeldb/migration/lnwire21/features.go @@ -0,0 +1,482 @@ +package lnwire + +import ( + "encoding/binary" + "errors" + "io" +) + +var ( + // ErrFeaturePairExists signals an error in feature vector construction + // where the opposing bit in a feature pair has already been set. + ErrFeaturePairExists = errors.New("feature pair exists") +) + +// FeatureBit represents a feature that can be enabled in either a local or +// global feature vector at a specific bit position. Feature bits follow the +// "it's OK to be odd" rule, where features at even bit positions must be known +// to a node receiving them from a peer while odd bits do not. In accordance, +// feature bits are usually assigned in pairs, first being assigned an odd bit +// position which may later be changed to the preceding even position once +// knowledge of the feature becomes required on the network. +type FeatureBit uint16 + +const ( + // DataLossProtectRequired is a feature bit that indicates that a peer + // *requires* the other party know about the data-loss-protect optional + // feature. If the remote peer does not know of such a feature, then + // the sending peer SHOLUD disconnect them. The data-loss-protect + // feature allows a peer that's lost partial data to recover their + // settled funds of the latest commitment state. + DataLossProtectRequired FeatureBit = 0 + + // DataLossProtectOptional is an optional feature bit that indicates + // that the sending peer knows of this new feature and can activate it + // it. The data-loss-protect feature allows a peer that's lost partial + // data to recover their settled funds of the latest commitment state. + DataLossProtectOptional FeatureBit = 1 + + // InitialRoutingSync is a local feature bit meaning that the receiving + // node should send a complete dump of routing information when a new + // connection is established. + InitialRoutingSync FeatureBit = 3 + + // UpfrontShutdownScriptRequired is a feature bit which indicates that a + // peer *requires* that the remote peer accept an upfront shutdown script to + // which payout is enforced on cooperative closes. + UpfrontShutdownScriptRequired FeatureBit = 4 + + // UpfrontShutdownScriptOptional is an optional feature bit which indicates + // that the peer will accept an upfront shutdown script to which payout is + // enforced on cooperative closes. + UpfrontShutdownScriptOptional FeatureBit = 5 + + // GossipQueriesRequired is a feature bit that indicates that the + // receiving peer MUST know of the set of features that allows nodes to + // more efficiently query the network view of peers on the network for + // reconciliation purposes. + GossipQueriesRequired FeatureBit = 6 + + // GossipQueriesOptional is an optional feature bit that signals that + // the setting peer knows of the set of features that allows more + // efficient network view reconciliation. + GossipQueriesOptional FeatureBit = 7 + + // TLVOnionPayloadRequired is a feature bit that indicates a node is + // able to decode the new TLV information included in the onion packet. + TLVOnionPayloadRequired FeatureBit = 8 + + // TLVOnionPayloadOptional is an optional feature bit that indicates a + // node is able to decode the new TLV information included in the onion + // packet. + TLVOnionPayloadOptional FeatureBit = 9 + + // StaticRemoteKeyRequired is a required feature bit that signals that + // within one's commitment transaction, the key used for the remote + // party's non-delay output should not be tweaked. + StaticRemoteKeyRequired FeatureBit = 12 + + // StaticRemoteKeyOptional is an optional feature bit that signals that + // within one's commitment transaction, the key used for the remote + // party's non-delay output should not be tweaked. + StaticRemoteKeyOptional FeatureBit = 13 + + // PaymentAddrRequired is a required feature bit that signals that a + // node requires payment addresses, which are used to mitigate probing + // attacks on the receiver of a payment. + PaymentAddrRequired FeatureBit = 14 + + // PaymentAddrOptional is an optional feature bit that signals that a + // node supports payment addresses, which are used to mitigate probing + // attacks on the receiver of a payment. + PaymentAddrOptional FeatureBit = 15 + + // MPPOptional is a required feature bit that signals that the receiver + // of a payment requires settlement of an invoice with more than one + // HTLC. + MPPRequired FeatureBit = 16 + + // MPPOptional is an optional feature bit that signals that the receiver + // of a payment supports settlement of an invoice with more than one + // HTLC. + MPPOptional FeatureBit = 17 + + // WumboChannelsRequired is a required feature bit that signals that a + // node is willing to accept channels larger than 2^24 satoshis. + WumboChannelsRequired FeatureBit = 18 + + // WumboChannelsOptional is an optional feature bit that signals that a + // node is willing to accept channels larger than 2^24 satoshis. + WumboChannelsOptional FeatureBit = 19 + + // AnchorsRequired is a required feature bit that signals that the node + // requires channels to be made using commitments having anchor + // outputs. + AnchorsRequired FeatureBit = 20 + + // AnchorsOptional is an optional feature bit that signals that the + // node supports channels to be made using commitments having anchor + // outputs. + AnchorsOptional FeatureBit = 21 + + // AnchorsZeroFeeHtlcTxRequired is a required feature bit that signals + // that the node requires channels having zero-fee second-level HTLC + // transactions, which also imply anchor commitments. + AnchorsZeroFeeHtlcTxRequired FeatureBit = 22 + + // AnchorsZeroFeeHtlcTxRequired is an optional feature bit that signals + // that the node supports channels having zero-fee second-level HTLC + // transactions, which also imply anchor commitments. + AnchorsZeroFeeHtlcTxOptional FeatureBit = 23 + + // maxAllowedSize is a maximum allowed size of feature vector. + // + // NOTE: Within the protocol, the maximum allowed message size is 65535 + // bytes for all messages. Accounting for the overhead within the feature + // message to signal the type of message, that leaves us with 65533 bytes + // for the init message itself. Next, we reserve 4 bytes to encode the + // lengths of both the local and global feature vectors, so 65529 bytes + // for the local and global features. Knocking off one byte for the sake + // of the calculation, that leads us to 32764 bytes for each feature + // vector, or 131056 different features. + maxAllowedSize = 32764 +) + +// IsRequired returns true if the feature bit is even, and false otherwise. +func (b FeatureBit) IsRequired() bool { + return b&0x01 == 0x00 +} + +// Features is a mapping of known feature bits to a descriptive name. All known +// feature bits must be assigned a name in this mapping, and feature bit pairs +// must be assigned together for correct behavior. +var Features = map[FeatureBit]string{ + DataLossProtectRequired: "data-loss-protect", + DataLossProtectOptional: "data-loss-protect", + InitialRoutingSync: "initial-routing-sync", + UpfrontShutdownScriptRequired: "upfront-shutdown-script", + UpfrontShutdownScriptOptional: "upfront-shutdown-script", + GossipQueriesRequired: "gossip-queries", + GossipQueriesOptional: "gossip-queries", + TLVOnionPayloadRequired: "tlv-onion", + TLVOnionPayloadOptional: "tlv-onion", + StaticRemoteKeyOptional: "static-remote-key", + StaticRemoteKeyRequired: "static-remote-key", + PaymentAddrOptional: "payment-addr", + PaymentAddrRequired: "payment-addr", + MPPOptional: "multi-path-payments", + MPPRequired: "multi-path-payments", + AnchorsRequired: "anchor-commitments", + AnchorsOptional: "anchor-commitments", + AnchorsZeroFeeHtlcTxRequired: "anchors-zero-fee-htlc-tx", + AnchorsZeroFeeHtlcTxOptional: "anchors-zero-fee-htlc-tx", + WumboChannelsRequired: "wumbo-channels", + WumboChannelsOptional: "wumbo-channels", +} + +// RawFeatureVector represents a set of feature bits as defined in BOLT-09. A +// RawFeatureVector itself just stores a set of bit flags but can be used to +// construct a FeatureVector which binds meaning to each bit. Feature vectors +// can be serialized and deserialized to/from a byte representation that is +// transmitted in Lightning network messages. +type RawFeatureVector struct { + features map[FeatureBit]bool +} + +// NewRawFeatureVector creates a feature vector with all of the feature bits +// given as arguments enabled. +func NewRawFeatureVector(bits ...FeatureBit) *RawFeatureVector { + fv := &RawFeatureVector{features: make(map[FeatureBit]bool)} + for _, bit := range bits { + fv.Set(bit) + } + return fv +} + +// Merges sets all feature bits in other on the receiver's feature vector. +func (fv *RawFeatureVector) Merge(other *RawFeatureVector) error { + for bit := range other.features { + err := fv.SafeSet(bit) + if err != nil { + return err + } + } + return nil +} + +// Clone makes a copy of a feature vector. +func (fv *RawFeatureVector) Clone() *RawFeatureVector { + newFeatures := NewRawFeatureVector() + for bit := range fv.features { + newFeatures.Set(bit) + } + return newFeatures +} + +// IsSet returns whether a particular feature bit is enabled in the vector. +func (fv *RawFeatureVector) IsSet(feature FeatureBit) bool { + return fv.features[feature] +} + +// Set marks a feature as enabled in the vector. +func (fv *RawFeatureVector) Set(feature FeatureBit) { + fv.features[feature] = true +} + +// SafeSet sets the chosen feature bit in the feature vector, but returns an +// error if the opposing feature bit is already set. This ensures both that we +// are creating properly structured feature vectors, and in some cases, that +// peers are sending properly encoded ones, i.e. it can't be both optional and +// required. +func (fv *RawFeatureVector) SafeSet(feature FeatureBit) error { + if _, ok := fv.features[feature^1]; ok { + return ErrFeaturePairExists + } + + fv.Set(feature) + return nil +} + +// Unset marks a feature as disabled in the vector. +func (fv *RawFeatureVector) Unset(feature FeatureBit) { + delete(fv.features, feature) +} + +// SerializeSize returns the number of bytes needed to represent feature vector +// in byte format. +func (fv *RawFeatureVector) SerializeSize() int { + // We calculate byte-length via the largest bit index. + return fv.serializeSize(8) +} + +// SerializeSize32 returns the number of bytes needed to represent feature +// vector in base32 format. +func (fv *RawFeatureVector) SerializeSize32() int { + // We calculate base32-length via the largest bit index. + return fv.serializeSize(5) +} + +// serializeSize returns the number of bytes required to encode the feature +// vector using at most width bits per encoded byte. +func (fv *RawFeatureVector) serializeSize(width int) int { + // Find the largest feature bit index + max := -1 + for feature := range fv.features { + index := int(feature) + if index > max { + max = index + } + } + if max == -1 { + return 0 + } + + return max/width + 1 +} + +// Encode writes the feature vector in byte representation. Every feature +// encoded as a bit, and the bit vector is serialized using the least number of +// bytes. Since the bit vector length is variable, the first two bytes of the +// serialization represent the length. +func (fv *RawFeatureVector) Encode(w io.Writer) error { + // Write length of feature vector. + var l [2]byte + length := fv.SerializeSize() + binary.BigEndian.PutUint16(l[:], uint16(length)) + if _, err := w.Write(l[:]); err != nil { + return err + } + + return fv.encode(w, length, 8) +} + +// EncodeBase256 writes the feature vector in base256 representation. Every +// feature is encoded as a bit, and the bit vector is serialized using the least +// number of bytes. +func (fv *RawFeatureVector) EncodeBase256(w io.Writer) error { + length := fv.SerializeSize() + return fv.encode(w, length, 8) +} + +// EncodeBase32 writes the feature vector in base32 representation. Every feature +// is encoded as a bit, and the bit vector is serialized using the least number of +// bytes. +func (fv *RawFeatureVector) EncodeBase32(w io.Writer) error { + length := fv.SerializeSize32() + return fv.encode(w, length, 5) +} + +// encode writes the feature vector +func (fv *RawFeatureVector) encode(w io.Writer, length, width int) error { + // Generate the data and write it. + data := make([]byte, length) + for feature := range fv.features { + byteIndex := int(feature) / width + bitIndex := int(feature) % width + data[length-byteIndex-1] |= 1 << uint(bitIndex) + } + + _, err := w.Write(data) + return err +} + +// Decode reads the feature vector from its byte representation. Every feature +// is encoded as a bit, and the bit vector is serialized using the least number +// of bytes. Since the bit vector length is variable, the first two bytes of the +// serialization represent the length. +func (fv *RawFeatureVector) Decode(r io.Reader) error { + // Read the length of the feature vector. + var l [2]byte + if _, err := io.ReadFull(r, l[:]); err != nil { + return err + } + length := binary.BigEndian.Uint16(l[:]) + + return fv.decode(r, int(length), 8) +} + +// DecodeBase256 reads the feature vector from its base256 representation. Every +// feature encoded as a bit, and the bit vector is serialized using the least +// number of bytes. +func (fv *RawFeatureVector) DecodeBase256(r io.Reader, length int) error { + return fv.decode(r, length, 8) +} + +// DecodeBase32 reads the feature vector from its base32 representation. Every +// feature encoded as a bit, and the bit vector is serialized using the least +// number of bytes. +func (fv *RawFeatureVector) DecodeBase32(r io.Reader, length int) error { + return fv.decode(r, length, 5) +} + +// decode reads a feature vector from the next length bytes of the io.Reader, +// assuming each byte has width feature bits encoded per byte. +func (fv *RawFeatureVector) decode(r io.Reader, length, width int) error { + // Read the feature vector data. + data := make([]byte, length) + if _, err := io.ReadFull(r, data); err != nil { + return err + } + + // Set feature bits from parsed data. + bitsNumber := len(data) * width + for i := 0; i < bitsNumber; i++ { + byteIndex := int(i / width) + bitIndex := uint(i % width) + if (data[length-byteIndex-1]>>bitIndex)&1 == 1 { + fv.Set(FeatureBit(i)) + } + } + + return nil +} + +// FeatureVector represents a set of enabled features. The set stores +// information on enabled flags and metadata about the feature names. A feature +// vector is serializable to a compact byte representation that is included in +// Lightning network messages. +type FeatureVector struct { + *RawFeatureVector + featureNames map[FeatureBit]string +} + +// NewFeatureVector constructs a new FeatureVector from a raw feature vector +// and mapping of feature definitions. If the feature vector argument is nil, a +// new one will be constructed with no enabled features. +func NewFeatureVector(featureVector *RawFeatureVector, + featureNames map[FeatureBit]string) *FeatureVector { + + if featureVector == nil { + featureVector = NewRawFeatureVector() + } + return &FeatureVector{ + RawFeatureVector: featureVector, + featureNames: featureNames, + } +} + +// EmptyFeatureVector returns a feature vector with no bits set. +func EmptyFeatureVector() *FeatureVector { + return NewFeatureVector(nil, Features) +} + +// HasFeature returns whether a particular feature is included in the set. The +// feature can be seen as set either if the bit is set directly OR the queried +// bit has the same meaning as its corresponding even/odd bit, which is set +// instead. The second case is because feature bits are generally assigned in +// pairs where both the even and odd position represent the same feature. +func (fv *FeatureVector) HasFeature(feature FeatureBit) bool { + return fv.IsSet(feature) || + (fv.isFeatureBitPair(feature) && fv.IsSet(feature^1)) +} + +// RequiresFeature returns true if the referenced feature vector *requires* +// that the given required bit be set. This method can be used with both +// optional and required feature bits as a parameter. +func (fv *FeatureVector) RequiresFeature(feature FeatureBit) bool { + // If we weren't passed a required feature bit, then we'll flip the + // lowest bit to query for the required version of the feature. This + // lets callers pass in both the optional and required bits. + if !feature.IsRequired() { + feature ^= 1 + } + + return fv.IsSet(feature) +} + +// UnknownRequiredFeatures returns a list of feature bits set in the vector +// that are unknown and in an even bit position. Feature bits with an even +// index must be known to a node receiving the feature vector in a message. +func (fv *FeatureVector) UnknownRequiredFeatures() []FeatureBit { + var unknown []FeatureBit + for feature := range fv.features { + if feature%2 == 0 && !fv.IsKnown(feature) { + unknown = append(unknown, feature) + } + } + return unknown +} + +// Name returns a string identifier for the feature represented by this bit. If +// the bit does not represent a known feature, this returns a string indicating +// as such. +func (fv *FeatureVector) Name(bit FeatureBit) string { + name, known := fv.featureNames[bit] + if !known { + return "unknown" + } + return name +} + +// IsKnown returns whether this feature bit represents a known feature. +func (fv *FeatureVector) IsKnown(bit FeatureBit) bool { + _, known := fv.featureNames[bit] + return known +} + +// isFeatureBitPair returns whether this feature bit and its corresponding +// even/odd bit both represent the same feature. This may often be the case as +// bits are generally assigned in pairs, first being assigned an odd bit +// position then being promoted to an even bit position once the network is +// ready. +func (fv *FeatureVector) isFeatureBitPair(bit FeatureBit) bool { + name1, known1 := fv.featureNames[bit] + name2, known2 := fv.featureNames[bit^1] + return known1 && known2 && name1 == name2 +} + +// Features returns the set of raw features contained in the feature vector. +func (fv *FeatureVector) Features() map[FeatureBit]struct{} { + fs := make(map[FeatureBit]struct{}, len(fv.RawFeatureVector.features)) + for b := range fv.RawFeatureVector.features { + fs[b] = struct{}{} + } + return fs +} + +// Clone copies a feature vector, carrying over its feature bits. The feature +// names are not copied. +func (fv *FeatureVector) Clone() *FeatureVector { + features := fv.RawFeatureVector.Clone() + return NewFeatureVector(features, fv.featureNames) +} diff --git a/channeldb/migration/lnwire21/funding_created.go b/channeldb/migration/lnwire21/funding_created.go new file mode 100644 index 00000000..c14321ec --- /dev/null +++ b/channeldb/migration/lnwire21/funding_created.go @@ -0,0 +1,66 @@ +package lnwire + +import ( + "io" + + "github.com/btcsuite/btcd/wire" +) + +// FundingCreated is sent from Alice (the initiator) to Bob (the responder), +// once Alice receives Bob's contributions as well as his channel constraints. +// Once bob receives this message, he'll gain access to an immediately +// broadcastable commitment transaction and will reply with a signature for +// Alice's version of the commitment transaction. +type FundingCreated struct { + // PendingChannelID serves to uniquely identify the future channel + // created by the initiated single funder workflow. + PendingChannelID [32]byte + + // FundingPoint is the outpoint of the funding transaction created by + // Alice. With this, Bob is able to generate both his version and + // Alice's version of the commitment transaction. + FundingPoint wire.OutPoint + + // CommitSig is Alice's signature from Bob's version of the commitment + // transaction. + CommitSig Sig +} + +// A compile time check to ensure FundingCreated implements the lnwire.Message +// interface. +var _ Message = (*FundingCreated)(nil) + +// Encode serializes the target FundingCreated into the passed io.Writer +// implementation. Serialization will observe the rules defined by the passed +// protocol version. +// +// This is part of the lnwire.Message interface. +func (f *FundingCreated) Encode(w io.Writer, pver uint32) error { + return WriteElements(w, f.PendingChannelID[:], f.FundingPoint, f.CommitSig) +} + +// Decode deserializes the serialized FundingCreated stored in the passed +// io.Reader into the target FundingCreated using the deserialization rules +// defined by the passed protocol version. +// +// This is part of the lnwire.Message interface. +func (f *FundingCreated) Decode(r io.Reader, pver uint32) error { + return ReadElements(r, f.PendingChannelID[:], &f.FundingPoint, &f.CommitSig) +} + +// MsgType returns the uint32 code which uniquely identifies this message as a +// FundingCreated on the wire. +// +// This is part of the lnwire.Message interface. +func (f *FundingCreated) MsgType() MessageType { + return MsgFundingCreated +} + +// MaxPayloadLength returns the maximum allowed payload length for a +// FundingCreated message. +// +// This is part of the lnwire.Message interface. +func (f *FundingCreated) MaxPayloadLength(uint32) uint32 { + // 32 + 32 + 2 + 64 + return 130 +} diff --git a/channeldb/migration/lnwire21/funding_locked.go b/channeldb/migration/lnwire21/funding_locked.go new file mode 100644 index 00000000..c441b0be --- /dev/null +++ b/channeldb/migration/lnwire21/funding_locked.go @@ -0,0 +1,83 @@ +package lnwire + +import ( + "io" + + "github.com/btcsuite/btcd/btcec" +) + +// FundingLocked is the message that both parties to a new channel creation +// send once they have observed the funding transaction being confirmed on the +// blockchain. FundingLocked contains the signatures necessary for the channel +// participants to advertise the existence of the channel to the rest of the +// network. +type FundingLocked struct { + // ChanID is the outpoint of the channel's funding transaction. This + // can be used to query for the channel in the database. + ChanID ChannelID + + // NextPerCommitmentPoint is the secret that can be used to revoke the + // next commitment transaction for the channel. + NextPerCommitmentPoint *btcec.PublicKey +} + +// NewFundingLocked creates a new FundingLocked message, populating it with the +// necessary IDs and revocation secret. +func NewFundingLocked(cid ChannelID, npcp *btcec.PublicKey) *FundingLocked { + return &FundingLocked{ + ChanID: cid, + NextPerCommitmentPoint: npcp, + } +} + +// A compile time check to ensure FundingLocked implements the lnwire.Message +// interface. +var _ Message = (*FundingLocked)(nil) + +// Decode deserializes the serialized FundingLocked message stored in the +// passed io.Reader into the target FundingLocked using the deserialization +// rules defined by the passed protocol version. +// +// This is part of the lnwire.Message interface. +func (c *FundingLocked) Decode(r io.Reader, pver uint32) error { + return ReadElements(r, + &c.ChanID, + &c.NextPerCommitmentPoint) +} + +// Encode serializes the target FundingLocked message into the passed io.Writer +// implementation. Serialization will observe the rules defined by the passed +// protocol version. +// +// This is part of the lnwire.Message interface. +func (c *FundingLocked) Encode(w io.Writer, pver uint32) error { + return WriteElements(w, + c.ChanID, + c.NextPerCommitmentPoint) +} + +// MsgType returns the uint32 code which uniquely identifies this message as a +// FundingLocked message on the wire. +// +// This is part of the lnwire.Message interface. +func (c *FundingLocked) MsgType() MessageType { + return MsgFundingLocked +} + +// MaxPayloadLength returns the maximum allowed payload length for a +// FundingLocked message. This is calculated by summing the max length of all +// the fields within a FundingLocked message. +// +// This is part of the lnwire.Message interface. +func (c *FundingLocked) MaxPayloadLength(uint32) uint32 { + var length uint32 + + // ChanID - 32 bytes + length += 32 + + // NextPerCommitmentPoint - 33 bytes + length += 33 + + // 65 bytes + return length +} diff --git a/channeldb/migration/lnwire21/funding_signed.go b/channeldb/migration/lnwire21/funding_signed.go new file mode 100644 index 00000000..620f8b37 --- /dev/null +++ b/channeldb/migration/lnwire21/funding_signed.go @@ -0,0 +1,55 @@ +package lnwire + +import "io" + +// FundingSigned is sent from Bob (the responder) to Alice (the initiator) +// after receiving the funding outpoint and her signature for Bob's version of +// the commitment transaction. +type FundingSigned struct { + // ChannelPoint is the particular active channel that this + // FundingSigned is bound to. + ChanID ChannelID + + // CommitSig is Bob's signature for Alice's version of the commitment + // transaction. + CommitSig Sig +} + +// A compile time check to ensure FundingSigned implements the lnwire.Message +// interface. +var _ Message = (*FundingSigned)(nil) + +// Encode serializes the target FundingSigned into the passed io.Writer +// implementation. Serialization will observe the rules defined by the passed +// protocol version. +// +// This is part of the lnwire.Message interface. +func (f *FundingSigned) Encode(w io.Writer, pver uint32) error { + return WriteElements(w, f.ChanID, f.CommitSig) +} + +// Decode deserializes the serialized FundingSigned stored in the passed +// io.Reader into the target FundingSigned using the deserialization rules +// defined by the passed protocol version. +// +// This is part of the lnwire.Message interface. +func (f *FundingSigned) Decode(r io.Reader, pver uint32) error { + return ReadElements(r, &f.ChanID, &f.CommitSig) +} + +// MsgType returns the uint32 code which uniquely identifies this message as a +// FundingSigned on the wire. +// +// This is part of the lnwire.Message interface. +func (f *FundingSigned) MsgType() MessageType { + return MsgFundingSigned +} + +// MaxPayloadLength returns the maximum allowed payload length for a +// FundingSigned message. +// +// This is part of the lnwire.Message interface. +func (f *FundingSigned) MaxPayloadLength(uint32) uint32 { + // 32 + 64 + return 96 +} diff --git a/channeldb/migration/lnwire21/gossip_timestamp_range.go b/channeldb/migration/lnwire21/gossip_timestamp_range.go new file mode 100644 index 00000000..3c28cd05 --- /dev/null +++ b/channeldb/migration/lnwire21/gossip_timestamp_range.go @@ -0,0 +1,80 @@ +package lnwire + +import ( + "io" + + "github.com/btcsuite/btcd/chaincfg/chainhash" +) + +// GossipTimestampRange is a message that allows the sender to restrict the set +// of future gossip announcements sent by the receiver. Nodes should send this +// if they have the gossip-queries feature bit active. Nodes are able to send +// new GossipTimestampRange messages to replace the prior window. +type GossipTimestampRange struct { + // ChainHash denotes the chain that the sender wishes to restrict the + // set of received announcements of. + ChainHash chainhash.Hash + + // FirstTimestamp is the timestamp of the earliest announcement message + // that should be sent by the receiver. + FirstTimestamp uint32 + + // TimestampRange is the horizon beyond the FirstTimestamp that any + // announcement messages should be sent for. The receiving node MUST + // NOT send any announcements that have a timestamp greater than + // FirstTimestamp + TimestampRange. + TimestampRange uint32 +} + +// NewGossipTimestampRange creates a new empty GossipTimestampRange message. +func NewGossipTimestampRange() *GossipTimestampRange { + return &GossipTimestampRange{} +} + +// A compile time check to ensure GossipTimestampRange implements the +// lnwire.Message interface. +var _ Message = (*GossipTimestampRange)(nil) + +// Decode deserializes a serialized GossipTimestampRange message stored in the +// passed io.Reader observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (g *GossipTimestampRange) Decode(r io.Reader, pver uint32) error { + return ReadElements(r, + g.ChainHash[:], + &g.FirstTimestamp, + &g.TimestampRange, + ) +} + +// Encode serializes the target GossipTimestampRange into the passed io.Writer +// observing the protocol version specified. +// +// This is part of the lnwire.Message interface. +func (g *GossipTimestampRange) Encode(w io.Writer, pver uint32) error { + return WriteElements(w, + g.ChainHash[:], + g.FirstTimestamp, + g.TimestampRange, + ) +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (g *GossipTimestampRange) MsgType() MessageType { + return MsgGossipTimestampRange +} + +// MaxPayloadLength returns the maximum allowed payload size for a +// GossipTimestampRange complete message observing the specified protocol +// version. +// +// This is part of the lnwire.Message interface. +func (g *GossipTimestampRange) MaxPayloadLength(uint32) uint32 { + // 32 + 4 + 4 + // + // TODO(roasbeef): update to 8 byte timestmaps? + return 40 +} diff --git a/channeldb/migration/lnwire21/init_message.go b/channeldb/migration/lnwire21/init_message.go new file mode 100644 index 00000000..e1ddbb01 --- /dev/null +++ b/channeldb/migration/lnwire21/init_message.go @@ -0,0 +1,73 @@ +package lnwire + +import "io" + +// Init is the first message reveals the features supported or required by this +// node. Nodes wait for receipt of the other's features to simplify error +// diagnosis where features are incompatible. Each node MUST wait to receive +// init before sending any other messages. +type Init struct { + // GlobalFeatures is a legacy feature vector used for backwards + // compatibility with older nodes. Any features defined here should be + // merged with those presented in Features. + GlobalFeatures *RawFeatureVector + + // Features is a feature vector containing the features supported by + // the remote node. + // + // NOTE: Older nodes may place some features in GlobalFeatures, but all + // new features are to be added in Features. When handling an Init + // message, any GlobalFeatures should be merged into the unified + // Features field. + Features *RawFeatureVector +} + +// NewInitMessage creates new instance of init message object. +func NewInitMessage(gf *RawFeatureVector, f *RawFeatureVector) *Init { + return &Init{ + GlobalFeatures: gf, + Features: f, + } +} + +// A compile time check to ensure Init implements the lnwire.Message +// interface. +var _ Message = (*Init)(nil) + +// Decode deserializes a serialized Init message stored in the passed +// io.Reader observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (msg *Init) Decode(r io.Reader, pver uint32) error { + return ReadElements(r, + &msg.GlobalFeatures, + &msg.Features, + ) +} + +// Encode serializes the target Init into the passed io.Writer observing +// the protocol version specified. +// +// This is part of the lnwire.Message interface. +func (msg *Init) Encode(w io.Writer, pver uint32) error { + return WriteElements(w, + msg.GlobalFeatures, + msg.Features, + ) +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (msg *Init) MsgType() MessageType { + return MsgInit +} + +// MaxPayloadLength returns the maximum allowed payload size for an Init +// complete message observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (msg *Init) MaxPayloadLength(uint32) uint32 { + return 2 + 2 + maxAllowedSize + 2 + maxAllowedSize +} diff --git a/channeldb/migration/lnwire21/lnwire.go b/channeldb/migration/lnwire21/lnwire.go new file mode 100644 index 00000000..ca0e449e --- /dev/null +++ b/channeldb/migration/lnwire21/lnwire.go @@ -0,0 +1,845 @@ +package lnwire + +import ( + "bytes" + "encoding/binary" + "fmt" + "image/color" + "io" + "math" + + "net" + + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btcutil" + "github.com/go-errors/errors" + "github.com/lightningnetwork/lnd/tor" +) + +// MaxSliceLength is the maximum allowed length for any opaque byte slices in +// the wire protocol. +const MaxSliceLength = 65535 + +// PkScript is simple type definition which represents a raw serialized public +// key script. +type PkScript []byte + +// addressType specifies the network protocol and version that should be used +// when connecting to a node at a particular address. +type addressType uint8 + +const ( + // noAddr denotes a blank address. An address of this type indicates + // that a node doesn't have any advertised addresses. + noAddr addressType = 0 + + // tcp4Addr denotes an IPv4 TCP address. + tcp4Addr addressType = 1 + + // tcp6Addr denotes an IPv6 TCP address. + tcp6Addr addressType = 2 + + // v2OnionAddr denotes a version 2 Tor onion service address. + v2OnionAddr addressType = 3 + + // v3OnionAddr denotes a version 3 Tor (prop224) onion service address. + v3OnionAddr addressType = 4 +) + +// AddrLen returns the number of bytes that it takes to encode the target +// address. +func (a addressType) AddrLen() uint16 { + switch a { + case noAddr: + return 0 + case tcp4Addr: + return 6 + case tcp6Addr: + return 18 + case v2OnionAddr: + return 12 + case v3OnionAddr: + return 37 + default: + return 0 + } +} + +// WriteElement is a one-stop shop to write the big endian representation of +// any element which is to be serialized for the wire protocol. The passed +// io.Writer should be backed by an appropriately sized byte slice, or be able +// to dynamically expand to accommodate additional data. +// +// TODO(roasbeef): this should eventually draw from a buffer pool for +// serialization. +func WriteElement(w io.Writer, element interface{}) error { + switch e := element.(type) { + case NodeAlias: + if _, err := w.Write(e[:]); err != nil { + return err + } + + case ShortChanIDEncoding: + var b [1]byte + b[0] = uint8(e) + if _, err := w.Write(b[:]); err != nil { + return err + } + case uint8: + var b [1]byte + b[0] = e + if _, err := w.Write(b[:]); err != nil { + return err + } + case FundingFlag: + var b [1]byte + b[0] = uint8(e) + if _, err := w.Write(b[:]); err != nil { + return err + } + case uint16: + var b [2]byte + binary.BigEndian.PutUint16(b[:], e) + if _, err := w.Write(b[:]); err != nil { + return err + } + case ChanUpdateMsgFlags: + var b [1]byte + b[0] = uint8(e) + if _, err := w.Write(b[:]); err != nil { + return err + } + case ChanUpdateChanFlags: + var b [1]byte + b[0] = uint8(e) + if _, err := w.Write(b[:]); err != nil { + return err + } + case MilliSatoshi: + var b [8]byte + binary.BigEndian.PutUint64(b[:], uint64(e)) + if _, err := w.Write(b[:]); err != nil { + return err + } + case btcutil.Amount: + var b [8]byte + binary.BigEndian.PutUint64(b[:], uint64(e)) + if _, err := w.Write(b[:]); err != nil { + return err + } + case uint32: + var b [4]byte + binary.BigEndian.PutUint32(b[:], e) + if _, err := w.Write(b[:]); err != nil { + return err + } + case uint64: + var b [8]byte + binary.BigEndian.PutUint64(b[:], e) + if _, err := w.Write(b[:]); err != nil { + return err + } + case *btcec.PublicKey: + if e == nil { + return fmt.Errorf("cannot write nil pubkey") + } + + var b [33]byte + serializedPubkey := e.SerializeCompressed() + copy(b[:], serializedPubkey) + if _, err := w.Write(b[:]); err != nil { + return err + } + case []Sig: + var b [2]byte + numSigs := uint16(len(e)) + binary.BigEndian.PutUint16(b[:], numSigs) + if _, err := w.Write(b[:]); err != nil { + return err + } + + for _, sig := range e { + if err := WriteElement(w, sig); err != nil { + return err + } + } + case Sig: + // Write buffer + if _, err := w.Write(e[:]); err != nil { + return err + } + case PingPayload: + var l [2]byte + binary.BigEndian.PutUint16(l[:], uint16(len(e))) + if _, err := w.Write(l[:]); err != nil { + return err + } + + if _, err := w.Write(e[:]); err != nil { + return err + } + case PongPayload: + var l [2]byte + binary.BigEndian.PutUint16(l[:], uint16(len(e))) + if _, err := w.Write(l[:]); err != nil { + return err + } + + if _, err := w.Write(e[:]); err != nil { + return err + } + case ErrorData: + var l [2]byte + binary.BigEndian.PutUint16(l[:], uint16(len(e))) + if _, err := w.Write(l[:]); err != nil { + return err + } + + if _, err := w.Write(e[:]); err != nil { + return err + } + case OpaqueReason: + var l [2]byte + binary.BigEndian.PutUint16(l[:], uint16(len(e))) + if _, err := w.Write(l[:]); err != nil { + return err + } + + if _, err := w.Write(e[:]); err != nil { + return err + } + case [33]byte: + if _, err := w.Write(e[:]); err != nil { + return err + } + case []byte: + if _, err := w.Write(e[:]); err != nil { + return err + } + case PkScript: + // The largest script we'll accept is a p2wsh which is exactly + // 34 bytes long. + scriptLength := len(e) + if scriptLength > 34 { + return fmt.Errorf("'PkScript' too long") + } + + if err := wire.WriteVarBytes(w, 0, e); err != nil { + return err + } + case *RawFeatureVector: + if e == nil { + return fmt.Errorf("cannot write nil feature vector") + } + + if err := e.Encode(w); err != nil { + return err + } + + case wire.OutPoint: + var h [32]byte + copy(h[:], e.Hash[:]) + if _, err := w.Write(h[:]); err != nil { + return err + } + + if e.Index > math.MaxUint16 { + return fmt.Errorf("index for outpoint (%v) is "+ + "greater than max index of %v", e.Index, + math.MaxUint16) + } + + var idx [2]byte + binary.BigEndian.PutUint16(idx[:], uint16(e.Index)) + if _, err := w.Write(idx[:]); err != nil { + return err + } + + case ChannelID: + if _, err := w.Write(e[:]); err != nil { + return err + } + case FailCode: + if err := WriteElement(w, uint16(e)); err != nil { + return err + } + case ShortChannelID: + // Check that field fit in 3 bytes and write the blockHeight + if e.BlockHeight > ((1 << 24) - 1) { + return errors.New("block height should fit in 3 bytes") + } + + var blockHeight [4]byte + binary.BigEndian.PutUint32(blockHeight[:], e.BlockHeight) + + if _, err := w.Write(blockHeight[1:]); err != nil { + return err + } + + // Check that field fit in 3 bytes and write the txIndex + if e.TxIndex > ((1 << 24) - 1) { + return errors.New("tx index should fit in 3 bytes") + } + + var txIndex [4]byte + binary.BigEndian.PutUint32(txIndex[:], e.TxIndex) + if _, err := w.Write(txIndex[1:]); err != nil { + return err + } + + // Write the txPosition + var txPosition [2]byte + binary.BigEndian.PutUint16(txPosition[:], e.TxPosition) + if _, err := w.Write(txPosition[:]); err != nil { + return err + } + + case *net.TCPAddr: + if e == nil { + return fmt.Errorf("cannot write nil TCPAddr") + } + + if e.IP.To4() != nil { + var descriptor [1]byte + descriptor[0] = uint8(tcp4Addr) + if _, err := w.Write(descriptor[:]); err != nil { + return err + } + + var ip [4]byte + copy(ip[:], e.IP.To4()) + if _, err := w.Write(ip[:]); err != nil { + return err + } + } else { + var descriptor [1]byte + descriptor[0] = uint8(tcp6Addr) + if _, err := w.Write(descriptor[:]); err != nil { + return err + } + var ip [16]byte + copy(ip[:], e.IP.To16()) + if _, err := w.Write(ip[:]); err != nil { + return err + } + } + var port [2]byte + binary.BigEndian.PutUint16(port[:], uint16(e.Port)) + if _, err := w.Write(port[:]); err != nil { + return err + } + + case *tor.OnionAddr: + if e == nil { + return errors.New("cannot write nil onion address") + } + + var suffixIndex int + switch len(e.OnionService) { + case tor.V2Len: + descriptor := []byte{byte(v2OnionAddr)} + if _, err := w.Write(descriptor); err != nil { + return err + } + suffixIndex = tor.V2Len - tor.OnionSuffixLen + case tor.V3Len: + descriptor := []byte{byte(v3OnionAddr)} + if _, err := w.Write(descriptor); err != nil { + return err + } + suffixIndex = tor.V3Len - tor.OnionSuffixLen + default: + return errors.New("unknown onion service length") + } + + host, err := tor.Base32Encoding.DecodeString( + e.OnionService[:suffixIndex], + ) + if err != nil { + return err + } + if _, err := w.Write(host); err != nil { + return err + } + + var port [2]byte + binary.BigEndian.PutUint16(port[:], uint16(e.Port)) + if _, err := w.Write(port[:]); err != nil { + return err + } + + case []net.Addr: + // First, we'll encode all the addresses into an intermediate + // buffer. We need to do this in order to compute the total + // length of the addresses. + var addrBuf bytes.Buffer + for _, address := range e { + if err := WriteElement(&addrBuf, address); err != nil { + return err + } + } + + // With the addresses fully encoded, we can now write out the + // number of bytes needed to encode them. + addrLen := addrBuf.Len() + if err := WriteElement(w, uint16(addrLen)); err != nil { + return err + } + + // Finally, we'll write out the raw addresses themselves, but + // only if we have any bytes to write. + if addrLen > 0 { + if _, err := w.Write(addrBuf.Bytes()); err != nil { + return err + } + } + case color.RGBA: + if err := WriteElements(w, e.R, e.G, e.B); err != nil { + return err + } + + case DeliveryAddress: + var length [2]byte + binary.BigEndian.PutUint16(length[:], uint16(len(e))) + if _, err := w.Write(length[:]); err != nil { + return err + } + if _, err := w.Write(e[:]); err != nil { + return err + } + + case bool: + var b [1]byte + if e { + b[0] = 1 + } + if _, err := w.Write(b[:]); err != nil { + return err + } + default: + return fmt.Errorf("unknown type in WriteElement: %T", e) + } + + return nil +} + +// WriteElements is writes each element in the elements slice to the passed +// io.Writer using WriteElement. +func WriteElements(w io.Writer, elements ...interface{}) error { + for _, element := range elements { + err := WriteElement(w, element) + if err != nil { + return err + } + } + return nil +} + +// ReadElement is a one-stop utility function to deserialize any datastructure +// encoded using the serialization format of lnwire. +func ReadElement(r io.Reader, element interface{}) error { + var err error + switch e := element.(type) { + case *bool: + var b [1]byte + if _, err := io.ReadFull(r, b[:]); err != nil { + return err + } + + if b[0] == 1 { + *e = true + } + + case *NodeAlias: + var a [32]byte + if _, err := io.ReadFull(r, a[:]); err != nil { + return err + } + + alias, err := NewNodeAlias(string(a[:])) + if err != nil { + return err + } + + *e = alias + case *ShortChanIDEncoding: + var b [1]uint8 + if _, err := r.Read(b[:]); err != nil { + return err + } + *e = ShortChanIDEncoding(b[0]) + case *uint8: + var b [1]uint8 + if _, err := r.Read(b[:]); err != nil { + return err + } + *e = b[0] + case *FundingFlag: + var b [1]uint8 + if _, err := r.Read(b[:]); err != nil { + return err + } + *e = FundingFlag(b[0]) + case *uint16: + var b [2]byte + if _, err := io.ReadFull(r, b[:]); err != nil { + return err + } + *e = binary.BigEndian.Uint16(b[:]) + case *ChanUpdateMsgFlags: + var b [1]uint8 + if _, err := r.Read(b[:]); err != nil { + return err + } + *e = ChanUpdateMsgFlags(b[0]) + case *ChanUpdateChanFlags: + var b [1]uint8 + if _, err := r.Read(b[:]); err != nil { + return err + } + *e = ChanUpdateChanFlags(b[0]) + case *uint32: + var b [4]byte + if _, err := io.ReadFull(r, b[:]); err != nil { + return err + } + *e = binary.BigEndian.Uint32(b[:]) + case *uint64: + var b [8]byte + if _, err := io.ReadFull(r, b[:]); err != nil { + return err + } + *e = binary.BigEndian.Uint64(b[:]) + case *MilliSatoshi: + var b [8]byte + if _, err := io.ReadFull(r, b[:]); err != nil { + return err + } + *e = MilliSatoshi(int64(binary.BigEndian.Uint64(b[:]))) + case *btcutil.Amount: + var b [8]byte + if _, err := io.ReadFull(r, b[:]); err != nil { + return err + } + *e = btcutil.Amount(int64(binary.BigEndian.Uint64(b[:]))) + case **btcec.PublicKey: + var b [btcec.PubKeyBytesLenCompressed]byte + if _, err = io.ReadFull(r, b[:]); err != nil { + return err + } + + pubKey, err := btcec.ParsePubKey(b[:], btcec.S256()) + if err != nil { + return err + } + *e = pubKey + case **RawFeatureVector: + f := NewRawFeatureVector() + err = f.Decode(r) + if err != nil { + return err + } + + *e = f + + case *[]Sig: + var l [2]byte + if _, err := io.ReadFull(r, l[:]); err != nil { + return err + } + numSigs := binary.BigEndian.Uint16(l[:]) + + var sigs []Sig + if numSigs > 0 { + sigs = make([]Sig, numSigs) + for i := 0; i < int(numSigs); i++ { + if err := ReadElement(r, &sigs[i]); err != nil { + return err + } + } + } + + *e = sigs + + case *Sig: + if _, err := io.ReadFull(r, e[:]); err != nil { + return err + } + case *OpaqueReason: + var l [2]byte + if _, err := io.ReadFull(r, l[:]); err != nil { + return err + } + reasonLen := binary.BigEndian.Uint16(l[:]) + + *e = OpaqueReason(make([]byte, reasonLen)) + if _, err := io.ReadFull(r, *e); err != nil { + return err + } + case *ErrorData: + var l [2]byte + if _, err := io.ReadFull(r, l[:]); err != nil { + return err + } + errorLen := binary.BigEndian.Uint16(l[:]) + + *e = ErrorData(make([]byte, errorLen)) + if _, err := io.ReadFull(r, *e); err != nil { + return err + } + case *PingPayload: + var l [2]byte + if _, err := io.ReadFull(r, l[:]); err != nil { + return err + } + pingLen := binary.BigEndian.Uint16(l[:]) + + *e = PingPayload(make([]byte, pingLen)) + if _, err := io.ReadFull(r, *e); err != nil { + return err + } + case *PongPayload: + var l [2]byte + if _, err := io.ReadFull(r, l[:]); err != nil { + return err + } + pongLen := binary.BigEndian.Uint16(l[:]) + + *e = PongPayload(make([]byte, pongLen)) + if _, err := io.ReadFull(r, *e); err != nil { + return err + } + case *[33]byte: + if _, err := io.ReadFull(r, e[:]); err != nil { + return err + } + case []byte: + if _, err := io.ReadFull(r, e); err != nil { + return err + } + case *PkScript: + pkScript, err := wire.ReadVarBytes(r, 0, 34, "pkscript") + if err != nil { + return err + } + *e = pkScript + case *wire.OutPoint: + var h [32]byte + if _, err = io.ReadFull(r, h[:]); err != nil { + return err + } + hash, err := chainhash.NewHash(h[:]) + if err != nil { + return err + } + + var idxBytes [2]byte + _, err = io.ReadFull(r, idxBytes[:]) + if err != nil { + return err + } + index := binary.BigEndian.Uint16(idxBytes[:]) + + *e = wire.OutPoint{ + Hash: *hash, + Index: uint32(index), + } + case *FailCode: + if err := ReadElement(r, (*uint16)(e)); err != nil { + return err + } + case *ChannelID: + if _, err := io.ReadFull(r, e[:]); err != nil { + return err + } + + case *ShortChannelID: + var blockHeight [4]byte + if _, err = io.ReadFull(r, blockHeight[1:]); err != nil { + return err + } + + var txIndex [4]byte + if _, err = io.ReadFull(r, txIndex[1:]); err != nil { + return err + } + + var txPosition [2]byte + if _, err = io.ReadFull(r, txPosition[:]); err != nil { + return err + } + + *e = ShortChannelID{ + BlockHeight: binary.BigEndian.Uint32(blockHeight[:]), + TxIndex: binary.BigEndian.Uint32(txIndex[:]), + TxPosition: binary.BigEndian.Uint16(txPosition[:]), + } + + case *[]net.Addr: + // First, we'll read the number of total bytes that have been + // used to encode the set of addresses. + var numAddrsBytes [2]byte + if _, err = io.ReadFull(r, numAddrsBytes[:]); err != nil { + return err + } + addrsLen := binary.BigEndian.Uint16(numAddrsBytes[:]) + + // With the number of addresses, read, we'll now pull in the + // buffer of the encoded addresses into memory. + addrs := make([]byte, addrsLen) + if _, err := io.ReadFull(r, addrs[:]); err != nil { + return err + } + addrBuf := bytes.NewReader(addrs) + + // Finally, we'll parse the remaining address payload in + // series, using the first byte to denote how to decode the + // address itself. + var ( + addresses []net.Addr + addrBytesRead uint16 + ) + + for addrBytesRead < addrsLen { + var descriptor [1]byte + if _, err = io.ReadFull(addrBuf, descriptor[:]); err != nil { + return err + } + + addrBytesRead++ + + var address net.Addr + switch aType := addressType(descriptor[0]); aType { + case noAddr: + addrBytesRead += aType.AddrLen() + continue + + case tcp4Addr: + var ip [4]byte + if _, err := io.ReadFull(addrBuf, ip[:]); err != nil { + return err + } + + var port [2]byte + if _, err := io.ReadFull(addrBuf, port[:]); err != nil { + return err + } + + address = &net.TCPAddr{ + IP: net.IP(ip[:]), + Port: int(binary.BigEndian.Uint16(port[:])), + } + addrBytesRead += aType.AddrLen() + + case tcp6Addr: + var ip [16]byte + if _, err := io.ReadFull(addrBuf, ip[:]); err != nil { + return err + } + + var port [2]byte + if _, err := io.ReadFull(addrBuf, port[:]); err != nil { + return err + } + + address = &net.TCPAddr{ + IP: net.IP(ip[:]), + Port: int(binary.BigEndian.Uint16(port[:])), + } + addrBytesRead += aType.AddrLen() + + case v2OnionAddr: + var h [tor.V2DecodedLen]byte + if _, err := io.ReadFull(addrBuf, h[:]); err != nil { + return err + } + + var p [2]byte + if _, err := io.ReadFull(addrBuf, p[:]); err != nil { + return err + } + + onionService := tor.Base32Encoding.EncodeToString(h[:]) + onionService += tor.OnionSuffix + port := int(binary.BigEndian.Uint16(p[:])) + + address = &tor.OnionAddr{ + OnionService: onionService, + Port: port, + } + addrBytesRead += aType.AddrLen() + + case v3OnionAddr: + var h [tor.V3DecodedLen]byte + if _, err := io.ReadFull(addrBuf, h[:]); err != nil { + return err + } + + var p [2]byte + if _, err := io.ReadFull(addrBuf, p[:]); err != nil { + return err + } + + onionService := tor.Base32Encoding.EncodeToString(h[:]) + onionService += tor.OnionSuffix + port := int(binary.BigEndian.Uint16(p[:])) + + address = &tor.OnionAddr{ + OnionService: onionService, + Port: port, + } + addrBytesRead += aType.AddrLen() + + default: + return &ErrUnknownAddrType{aType} + } + + addresses = append(addresses, address) + } + + *e = addresses + case *color.RGBA: + err := ReadElements(r, + &e.R, + &e.G, + &e.B, + ) + if err != nil { + return err + } + case *DeliveryAddress: + var addrLen [2]byte + if _, err = io.ReadFull(r, addrLen[:]); err != nil { + return err + } + length := binary.BigEndian.Uint16(addrLen[:]) + + var addrBytes [deliveryAddressMaxSize]byte + if length > deliveryAddressMaxSize { + return fmt.Errorf("cannot read %d bytes into addrBytes", length) + } + if _, err = io.ReadFull(r, addrBytes[:length]); err != nil { + return err + } + *e = addrBytes[:length] + default: + return fmt.Errorf("unknown type in ReadElement: %T", e) + } + + return nil +} + +// ReadElements deserializes a variable number of elements into the passed +// io.Reader, with each element being deserialized according to the ReadElement +// function. +func ReadElements(r io.Reader, elements ...interface{}) error { + for _, element := range elements { + err := ReadElement(r, element) + if err != nil { + return err + } + } + return nil +} diff --git a/channeldb/migration/lnwire21/message.go b/channeldb/migration/lnwire21/message.go new file mode 100644 index 00000000..b5c27339 --- /dev/null +++ b/channeldb/migration/lnwire21/message.go @@ -0,0 +1,296 @@ +// Copyright (c) 2013-2017 The btcsuite developers +// Copyright (c) 2015-2016 The Decred developers +// code derived from https://github .com/btcsuite/btcd/blob/master/wire/message.go +// Copyright (C) 2015-2017 The Lightning Network Developers + +package lnwire + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" +) + +// MaxMessagePayload is the maximum bytes a message can be regardless of other +// individual limits imposed by messages themselves. +const MaxMessagePayload = 65535 // 65KB + +// MessageType is the unique 2 byte big-endian integer that indicates the type +// of message on the wire. All messages have a very simple header which +// consists simply of 2-byte message type. We omit a length field, and checksum +// as the Lightning Protocol is intended to be encapsulated within a +// confidential+authenticated cryptographic messaging protocol. +type MessageType uint16 + +// The currently defined message types within this current version of the +// Lightning protocol. +const ( + MsgInit MessageType = 16 + MsgError = 17 + MsgPing = 18 + MsgPong = 19 + MsgOpenChannel = 32 + MsgAcceptChannel = 33 + MsgFundingCreated = 34 + MsgFundingSigned = 35 + MsgFundingLocked = 36 + MsgShutdown = 38 + MsgClosingSigned = 39 + MsgUpdateAddHTLC = 128 + MsgUpdateFulfillHTLC = 130 + MsgUpdateFailHTLC = 131 + MsgCommitSig = 132 + MsgRevokeAndAck = 133 + MsgUpdateFee = 134 + MsgUpdateFailMalformedHTLC = 135 + MsgChannelReestablish = 136 + MsgChannelAnnouncement = 256 + MsgNodeAnnouncement = 257 + MsgChannelUpdate = 258 + MsgAnnounceSignatures = 259 + MsgQueryShortChanIDs = 261 + MsgReplyShortChanIDsEnd = 262 + MsgQueryChannelRange = 263 + MsgReplyChannelRange = 264 + MsgGossipTimestampRange = 265 +) + +// String return the string representation of message type. +func (t MessageType) String() string { + switch t { + case MsgInit: + return "Init" + case MsgOpenChannel: + return "MsgOpenChannel" + case MsgAcceptChannel: + return "MsgAcceptChannel" + case MsgFundingCreated: + return "MsgFundingCreated" + case MsgFundingSigned: + return "MsgFundingSigned" + case MsgFundingLocked: + return "FundingLocked" + case MsgShutdown: + return "Shutdown" + case MsgClosingSigned: + return "ClosingSigned" + case MsgUpdateAddHTLC: + return "UpdateAddHTLC" + case MsgUpdateFailHTLC: + return "UpdateFailHTLC" + case MsgUpdateFulfillHTLC: + return "UpdateFulfillHTLC" + case MsgCommitSig: + return "CommitSig" + case MsgRevokeAndAck: + return "RevokeAndAck" + case MsgUpdateFailMalformedHTLC: + return "UpdateFailMalformedHTLC" + case MsgChannelReestablish: + return "ChannelReestablish" + case MsgError: + return "Error" + case MsgChannelAnnouncement: + return "ChannelAnnouncement" + case MsgChannelUpdate: + return "ChannelUpdate" + case MsgNodeAnnouncement: + return "NodeAnnouncement" + case MsgPing: + return "Ping" + case MsgAnnounceSignatures: + return "AnnounceSignatures" + case MsgPong: + return "Pong" + case MsgUpdateFee: + return "UpdateFee" + case MsgQueryShortChanIDs: + return "QueryShortChanIDs" + case MsgReplyShortChanIDsEnd: + return "ReplyShortChanIDsEnd" + case MsgQueryChannelRange: + return "QueryChannelRange" + case MsgReplyChannelRange: + return "ReplyChannelRange" + case MsgGossipTimestampRange: + return "GossipTimestampRange" + default: + return "" + } +} + +// UnknownMessage is an implementation of the error interface that allows the +// creation of an error in response to an unknown message. +type UnknownMessage struct { + messageType MessageType +} + +// Error returns a human readable string describing the error. +// +// This is part of the error interface. +func (u *UnknownMessage) Error() string { + return fmt.Sprintf("unable to parse message of unknown type: %v", + u.messageType) +} + +// Serializable is an interface which defines a lightning wire serializable +// object. +type Serializable interface { + // Decode reads the bytes stream and converts it to the object. + Decode(io.Reader, uint32) error + + // Encode converts object to the bytes stream and write it into the + // writer. + Encode(io.Writer, uint32) error +} + +// Message is an interface that defines a lightning wire protocol message. The +// interface is general in order to allow implementing types full control over +// the representation of its data. +type Message interface { + Serializable + MsgType() MessageType + MaxPayloadLength(uint32) uint32 +} + +// makeEmptyMessage creates a new empty message of the proper concrete type +// based on the passed message type. +func makeEmptyMessage(msgType MessageType) (Message, error) { + var msg Message + + switch msgType { + case MsgInit: + msg = &Init{} + case MsgOpenChannel: + msg = &OpenChannel{} + case MsgAcceptChannel: + msg = &AcceptChannel{} + case MsgFundingCreated: + msg = &FundingCreated{} + case MsgFundingSigned: + msg = &FundingSigned{} + case MsgFundingLocked: + msg = &FundingLocked{} + case MsgShutdown: + msg = &Shutdown{} + case MsgClosingSigned: + msg = &ClosingSigned{} + case MsgUpdateAddHTLC: + msg = &UpdateAddHTLC{} + case MsgUpdateFailHTLC: + msg = &UpdateFailHTLC{} + case MsgUpdateFulfillHTLC: + msg = &UpdateFulfillHTLC{} + case MsgCommitSig: + msg = &CommitSig{} + case MsgRevokeAndAck: + msg = &RevokeAndAck{} + case MsgUpdateFee: + msg = &UpdateFee{} + case MsgUpdateFailMalformedHTLC: + msg = &UpdateFailMalformedHTLC{} + case MsgChannelReestablish: + msg = &ChannelReestablish{} + case MsgError: + msg = &Error{} + case MsgChannelAnnouncement: + msg = &ChannelAnnouncement{} + case MsgChannelUpdate: + msg = &ChannelUpdate{} + case MsgNodeAnnouncement: + msg = &NodeAnnouncement{} + case MsgPing: + msg = &Ping{} + case MsgAnnounceSignatures: + msg = &AnnounceSignatures{} + case MsgPong: + msg = &Pong{} + case MsgQueryShortChanIDs: + msg = &QueryShortChanIDs{} + case MsgReplyShortChanIDsEnd: + msg = &ReplyShortChanIDsEnd{} + case MsgQueryChannelRange: + msg = &QueryChannelRange{} + case MsgReplyChannelRange: + msg = &ReplyChannelRange{} + case MsgGossipTimestampRange: + msg = &GossipTimestampRange{} + default: + return nil, &UnknownMessage{msgType} + } + + return msg, nil +} + +// WriteMessage writes a lightning Message to w including the necessary header +// information and returns the number of bytes written. +func WriteMessage(w io.Writer, msg Message, pver uint32) (int, error) { + totalBytes := 0 + + // Encode the message payload itself into a temporary buffer. + // TODO(roasbeef): create buffer pool + var bw bytes.Buffer + if err := msg.Encode(&bw, pver); err != nil { + return totalBytes, err + } + payload := bw.Bytes() + lenp := len(payload) + + // Enforce maximum overall message payload. + if lenp > MaxMessagePayload { + return totalBytes, fmt.Errorf("message payload is too large - "+ + "encoded %d bytes, but maximum message payload is %d bytes", + lenp, MaxMessagePayload) + } + + // Enforce maximum message payload on the message type. + mpl := msg.MaxPayloadLength(pver) + if uint32(lenp) > mpl { + return totalBytes, fmt.Errorf("message payload is too large - "+ + "encoded %d bytes, but maximum message payload of "+ + "type %v is %d bytes", lenp, msg.MsgType(), mpl) + } + + // With the initial sanity checks complete, we'll now write out the + // message type itself. + var mType [2]byte + binary.BigEndian.PutUint16(mType[:], uint16(msg.MsgType())) + n, err := w.Write(mType[:]) + totalBytes += n + if err != nil { + return totalBytes, err + } + + // With the message type written, we'll now write out the raw payload + // itself. + n, err = w.Write(payload) + totalBytes += n + + return totalBytes, err +} + +// ReadMessage reads, validates, and parses the next Lightning message from r +// for the provided protocol version. +func ReadMessage(r io.Reader, pver uint32) (Message, error) { + // First, we'll read out the first two bytes of the message so we can + // create the proper empty message. + var mType [2]byte + if _, err := io.ReadFull(r, mType[:]); err != nil { + return nil, err + } + + msgType := MessageType(binary.BigEndian.Uint16(mType[:])) + + // Now that we know the target message type, we can create the proper + // empty message type and decode the message into it. + msg, err := makeEmptyMessage(msgType) + if err != nil { + return nil, err + } + if err := msg.Decode(r, pver); err != nil { + return nil, err + } + + return msg, nil +} diff --git a/channeldb/migration/lnwire21/msat.go b/channeldb/migration/lnwire21/msat.go new file mode 100644 index 00000000..d3789dfa --- /dev/null +++ b/channeldb/migration/lnwire21/msat.go @@ -0,0 +1,51 @@ +package lnwire + +import ( + "fmt" + + "github.com/btcsuite/btcutil" +) + +const ( + // mSatScale is a value that's used to scale satoshis to milli-satoshis, and + // the other way around. + mSatScale uint64 = 1000 + + // MaxMilliSatoshi is the maximum number of msats that can be expressed + // in this data type. + MaxMilliSatoshi = ^MilliSatoshi(0) +) + +// MilliSatoshi are the native unit of the Lightning Network. A milli-satoshi +// is simply 1/1000th of a satoshi. There are 1000 milli-satoshis in a single +// satoshi. Within the network, all HTLC payments are denominated in +// milli-satoshis. As milli-satoshis aren't deliverable on the native +// blockchain, before settling to broadcasting, the values are rounded down to +// the nearest satoshi. +type MilliSatoshi uint64 + +// NewMSatFromSatoshis creates a new MilliSatoshi instance from a target amount +// of satoshis. +func NewMSatFromSatoshis(sat btcutil.Amount) MilliSatoshi { + return MilliSatoshi(uint64(sat) * mSatScale) +} + +// ToBTC converts the target MilliSatoshi amount to its corresponding value +// when expressed in BTC. +func (m MilliSatoshi) ToBTC() float64 { + sat := m.ToSatoshis() + return sat.ToBTC() +} + +// ToSatoshis converts the target MilliSatoshi amount to satoshis. Simply, this +// sheds a factor of 1000 from the mSAT amount in order to convert it to SAT. +func (m MilliSatoshi) ToSatoshis() btcutil.Amount { + return btcutil.Amount(uint64(m) / mSatScale) +} + +// String returns the string representation of the mSAT amount. +func (m MilliSatoshi) String() string { + return fmt.Sprintf("%v mSAT", uint64(m)) +} + +// TODO(roasbeef): extend with arithmetic operations? diff --git a/channeldb/migration/lnwire21/netaddress.go b/channeldb/migration/lnwire21/netaddress.go new file mode 100644 index 00000000..f31ac1f9 --- /dev/null +++ b/channeldb/migration/lnwire21/netaddress.go @@ -0,0 +1,54 @@ +package lnwire + +import ( + "fmt" + "net" + + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/wire" +) + +// NetAddress represents information pertaining to the identity and network +// reachability of a peer. Information stored includes the node's identity +// public key for establishing a confidential+authenticated connection, the +// service bits it supports, and a TCP address the node is reachable at. +// +// TODO(roasbeef): merge with LinkNode in some fashion +type NetAddress struct { + // IdentityKey is the long-term static public key for a node. This node is + // used throughout the network as a node's identity key. It is used to + // authenticate any data sent to the network on behalf of the node, and + // additionally to establish a confidential+authenticated connection with + // the node. + IdentityKey *btcec.PublicKey + + // Address is the IP address and port of the node. This is left + // general so that multiple implementations can be used. + Address net.Addr + + // ChainNet is the Bitcoin network this node is associated with. + // TODO(roasbeef): make a slice in the future for multi-chain + ChainNet wire.BitcoinNet +} + +// A compile time assertion to ensure that NetAddress meets the net.Addr +// interface. +var _ net.Addr = (*NetAddress)(nil) + +// String returns a human readable string describing the target NetAddress. The +// current string format is: @host. +// +// This part of the net.Addr interface. +func (n *NetAddress) String() string { + // TODO(roasbeef): use base58? + pubkey := n.IdentityKey.SerializeCompressed() + + return fmt.Sprintf("%x@%v", pubkey, n.Address) +} + +// Network returns the name of the network this address is bound to. +// +// This part of the net.Addr interface. +func (n *NetAddress) Network() string { + return n.Address.Network() +} diff --git a/channeldb/migration/lnwire21/node_announcement.go b/channeldb/migration/lnwire21/node_announcement.go new file mode 100644 index 00000000..35534352 --- /dev/null +++ b/channeldb/migration/lnwire21/node_announcement.go @@ -0,0 +1,192 @@ +package lnwire + +import ( + "bytes" + "fmt" + "image/color" + "io" + "io/ioutil" + "net" + "unicode/utf8" +) + +// ErrUnknownAddrType is an error returned if we encounter an unknown address type +// when parsing addresses. +type ErrUnknownAddrType struct { + addrType addressType +} + +// Error returns a human readable string describing the error. +// +// NOTE: implements the error interface. +func (e ErrUnknownAddrType) Error() string { + return fmt.Sprintf("unknown address type: %v", e.addrType) +} + +// ErrInvalidNodeAlias is an error returned if a node alias we parse on the +// wire is invalid, as in it has non UTF-8 characters. +type ErrInvalidNodeAlias struct{} + +// Error returns a human readable string describing the error. +// +// NOTE: implements the error interface. +func (e ErrInvalidNodeAlias) Error() string { + return "node alias has non-utf8 characters" +} + +// NodeAlias is a hex encoded UTF-8 string that may be displayed as an +// alternative to the node's ID. Notice that aliases are not unique and may be +// freely chosen by the node operators. +type NodeAlias [32]byte + +// NewNodeAlias creates a new instance of a NodeAlias. Verification is +// performed on the passed string to ensure it meets the alias requirements. +func NewNodeAlias(s string) (NodeAlias, error) { + var n NodeAlias + + if len(s) > 32 { + return n, fmt.Errorf("alias too large: max is %v, got %v", 32, + len(s)) + } + + if !utf8.ValidString(s) { + return n, &ErrInvalidNodeAlias{} + } + + copy(n[:], []byte(s)) + return n, nil +} + +// String returns a utf8 string representation of the alias bytes. +func (n NodeAlias) String() string { + // Trim trailing zero-bytes for presentation + return string(bytes.Trim(n[:], "\x00")) +} + +// NodeAnnouncement message is used to announce the presence of a Lightning +// node and also to signal that the node is accepting incoming connections. +// Each NodeAnnouncement authenticating the advertised information within the +// announcement via a signature using the advertised node pubkey. +type NodeAnnouncement struct { + // Signature is used to prove the ownership of node id. + Signature Sig + + // Features is the list of protocol features this node supports. + Features *RawFeatureVector + + // Timestamp allows ordering in the case of multiple announcements. + Timestamp uint32 + + // NodeID is a public key which is used as node identification. + NodeID [33]byte + + // RGBColor is used to customize their node's appearance in maps and + // graphs + RGBColor color.RGBA + + // Alias is used to customize their node's appearance in maps and + // graphs + Alias NodeAlias + + // Address includes two specification fields: 'ipv6' and 'port' on + // which the node is accepting incoming connections. + Addresses []net.Addr + + // ExtraOpaqueData is the set of data that was appended to this + // message, some of which we may not actually know how to iterate or + // parse. By holding onto this data, we ensure that we're able to + // properly validate the set of signatures that cover these new fields, + // and ensure we're able to make upgrades to the network in a forwards + // compatible manner. + ExtraOpaqueData []byte +} + +// A compile time check to ensure NodeAnnouncement implements the +// lnwire.Message interface. +var _ Message = (*NodeAnnouncement)(nil) + +// Decode deserializes a serialized NodeAnnouncement stored in the passed +// io.Reader observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (a *NodeAnnouncement) Decode(r io.Reader, pver uint32) error { + err := ReadElements(r, + &a.Signature, + &a.Features, + &a.Timestamp, + &a.NodeID, + &a.RGBColor, + &a.Alias, + &a.Addresses, + ) + if err != nil { + return err + } + + // Now that we've read out all the fields that we explicitly know of, + // we'll collect the remainder into the ExtraOpaqueData field. If there + // aren't any bytes, then we'll snip off the slice to avoid carrying + // around excess capacity. + a.ExtraOpaqueData, err = ioutil.ReadAll(r) + if err != nil { + return err + } + if len(a.ExtraOpaqueData) == 0 { + a.ExtraOpaqueData = nil + } + + return nil +} + +// Encode serializes the target NodeAnnouncement into the passed io.Writer +// observing the protocol version specified. +// +func (a *NodeAnnouncement) Encode(w io.Writer, pver uint32) error { + return WriteElements(w, + a.Signature, + a.Features, + a.Timestamp, + a.NodeID, + a.RGBColor, + a.Alias, + a.Addresses, + a.ExtraOpaqueData, + ) +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (a *NodeAnnouncement) MsgType() MessageType { + return MsgNodeAnnouncement +} + +// MaxPayloadLength returns the maximum allowed payload size for this message +// observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (a *NodeAnnouncement) MaxPayloadLength(pver uint32) uint32 { + return 65533 +} + +// DataToSign returns the part of the message that should be signed. +func (a *NodeAnnouncement) DataToSign() ([]byte, error) { + + // We should not include the signatures itself. + var w bytes.Buffer + err := WriteElements(&w, + a.Features, + a.Timestamp, + a.NodeID, + a.RGBColor, + a.Alias[:], + a.Addresses, + a.ExtraOpaqueData, + ) + if err != nil { + return nil, err + } + + return w.Bytes(), nil +} diff --git a/channeldb/migration/lnwire21/onion_error.go b/channeldb/migration/lnwire21/onion_error.go new file mode 100644 index 00000000..35555e26 --- /dev/null +++ b/channeldb/migration/lnwire21/onion_error.go @@ -0,0 +1,1428 @@ +package lnwire + +import ( + "bufio" + "crypto/sha256" + "encoding/binary" + "fmt" + "io" + + "bytes" + + "github.com/davecgh/go-spew/spew" + "github.com/go-errors/errors" + "github.com/lightningnetwork/lnd/tlv" +) + +// FailureMessage represents the onion failure object identified by its unique +// failure code. +type FailureMessage interface { + // Code returns a failure code describing the exact nature of the + // error. + Code() FailCode + + // Error returns a human readable string describing the error. With + // this method, the FailureMessage interface meets the built-in error + // interface. + Error() string +} + +// FailureMessageLength is the size of the failure message plus the size of +// padding. The FailureMessage message should always be EXACTLY this size. +const FailureMessageLength = 256 + +const ( + // FlagBadOnion error flag describes an unparsable, encrypted by + // previous node. + FlagBadOnion FailCode = 0x8000 + + // FlagPerm error flag indicates a permanent failure. + FlagPerm FailCode = 0x4000 + + // FlagNode error flag indicates a node failure. + FlagNode FailCode = 0x2000 + + // FlagUpdate error flag indicates a new channel update is enclosed + // within the error. + FlagUpdate FailCode = 0x1000 +) + +// FailCode specifies the precise reason that an upstream HTLC was canceled. +// Each UpdateFailHTLC message carries a FailCode which is to be passed +// backwards, encrypted at each step back to the source of the HTLC within the +// route. +type FailCode uint16 + +// The currently defined onion failure types within this current version of the +// Lightning protocol. +const ( + CodeNone FailCode = 0 + CodeInvalidRealm = FlagBadOnion | 1 + CodeTemporaryNodeFailure = FlagNode | 2 + CodePermanentNodeFailure = FlagPerm | FlagNode | 2 + CodeRequiredNodeFeatureMissing = FlagPerm | FlagNode | 3 + CodeInvalidOnionVersion = FlagBadOnion | FlagPerm | 4 + CodeInvalidOnionHmac = FlagBadOnion | FlagPerm | 5 + CodeInvalidOnionKey = FlagBadOnion | FlagPerm | 6 + CodeTemporaryChannelFailure = FlagUpdate | 7 + CodePermanentChannelFailure = FlagPerm | 8 + CodeRequiredChannelFeatureMissing = FlagPerm | 9 + CodeUnknownNextPeer = FlagPerm | 10 + CodeAmountBelowMinimum = FlagUpdate | 11 + CodeFeeInsufficient = FlagUpdate | 12 + CodeIncorrectCltvExpiry = FlagUpdate | 13 + CodeExpiryTooSoon = FlagUpdate | 14 + CodeChannelDisabled = FlagUpdate | 20 + CodeIncorrectOrUnknownPaymentDetails = FlagPerm | 15 + CodeIncorrectPaymentAmount = FlagPerm | 16 + CodeFinalExpiryTooSoon FailCode = 17 + CodeFinalIncorrectCltvExpiry FailCode = 18 + CodeFinalIncorrectHtlcAmount FailCode = 19 + CodeExpiryTooFar FailCode = 21 + CodeInvalidOnionPayload = FlagPerm | 22 + CodeMPPTimeout FailCode = 23 +) + +// String returns the string representation of the failure code. +func (c FailCode) String() string { + switch c { + case CodeInvalidRealm: + return "InvalidRealm" + + case CodeTemporaryNodeFailure: + return "TemporaryNodeFailure" + + case CodePermanentNodeFailure: + return "PermanentNodeFailure" + + case CodeRequiredNodeFeatureMissing: + return "RequiredNodeFeatureMissing" + + case CodeInvalidOnionVersion: + return "InvalidOnionVersion" + + case CodeInvalidOnionHmac: + return "InvalidOnionHmac" + + case CodeInvalidOnionKey: + return "InvalidOnionKey" + + case CodeTemporaryChannelFailure: + return "TemporaryChannelFailure" + + case CodePermanentChannelFailure: + return "PermanentChannelFailure" + + case CodeRequiredChannelFeatureMissing: + return "RequiredChannelFeatureMissing" + + case CodeUnknownNextPeer: + return "UnknownNextPeer" + + case CodeAmountBelowMinimum: + return "AmountBelowMinimum" + + case CodeFeeInsufficient: + return "FeeInsufficient" + + case CodeIncorrectCltvExpiry: + return "IncorrectCltvExpiry" + + case CodeIncorrectPaymentAmount: + return "IncorrectPaymentAmount" + + case CodeExpiryTooSoon: + return "ExpiryTooSoon" + + case CodeChannelDisabled: + return "ChannelDisabled" + + case CodeIncorrectOrUnknownPaymentDetails: + return "IncorrectOrUnknownPaymentDetails" + + case CodeFinalExpiryTooSoon: + return "FinalExpiryTooSoon" + + case CodeFinalIncorrectCltvExpiry: + return "FinalIncorrectCltvExpiry" + + case CodeFinalIncorrectHtlcAmount: + return "FinalIncorrectHtlcAmount" + + case CodeExpiryTooFar: + return "ExpiryTooFar" + + case CodeInvalidOnionPayload: + return "InvalidOnionPayload" + + case CodeMPPTimeout: + return "MPPTimeout" + + default: + return "" + } +} + +// FailInvalidRealm is returned if the realm byte is unknown. +// +// NOTE: May be returned by any node in the payment route. +type FailInvalidRealm struct{} + +// Returns a human readable string describing the target FailureMessage. +// +// NOTE: Implements the error interface. +func (f *FailInvalidRealm) Error() string { + return f.Code().String() +} + +// Code returns the failure unique code. +// +// NOTE: Part of the FailureMessage interface. +func (f *FailInvalidRealm) Code() FailCode { + return CodeInvalidRealm +} + +// FailTemporaryNodeFailure is returned if an otherwise unspecified transient +// error occurs for the entire node. +// +// NOTE: May be returned by any node in the payment route. +type FailTemporaryNodeFailure struct{} + +// Code returns the failure unique code. +// NOTE: Part of the FailureMessage interface. +func (f *FailTemporaryNodeFailure) Code() FailCode { + return CodeTemporaryNodeFailure +} + +// Returns a human readable string describing the target FailureMessage. +// +// NOTE: Implements the error interface. +func (f *FailTemporaryNodeFailure) Error() string { + return f.Code().String() +} + +// FailPermanentNodeFailure is returned if an otherwise unspecified permanent +// error occurs for the entire node. +// +// NOTE: May be returned by any node in the payment route. +type FailPermanentNodeFailure struct{} + +// Code returns the failure unique code. +// +// NOTE: Part of the FailureMessage interface. +func (f *FailPermanentNodeFailure) Code() FailCode { + return CodePermanentNodeFailure +} + +// Returns a human readable string describing the target FailureMessage. +// +// NOTE: Implements the error interface. +func (f *FailPermanentNodeFailure) Error() string { + return f.Code().String() +} + +// FailRequiredNodeFeatureMissing is returned if a node has requirement +// advertised in its node_announcement features which were not present in the +// onion. +// +// NOTE: May be returned by any node in the payment route. +type FailRequiredNodeFeatureMissing struct{} + +// Code returns the failure unique code. +// +// NOTE: Part of the FailureMessage interface. +func (f *FailRequiredNodeFeatureMissing) Code() FailCode { + return CodeRequiredNodeFeatureMissing +} + +// Returns a human readable string describing the target FailureMessage. +// +// NOTE: Implements the error interface. +func (f *FailRequiredNodeFeatureMissing) Error() string { + return f.Code().String() +} + +// FailPermanentChannelFailure is return if an otherwise unspecified permanent +// error occurs for the outgoing channel (eg. channel (recently). +// +// NOTE: May be returned by any node in the payment route. +type FailPermanentChannelFailure struct{} + +// Code returns the failure unique code. +// +// NOTE: Part of the FailureMessage interface. +func (f *FailPermanentChannelFailure) Code() FailCode { + return CodePermanentChannelFailure +} + +// Returns a human readable string describing the target FailureMessage. +// +// NOTE: Implements the error interface. +func (f *FailPermanentChannelFailure) Error() string { + return f.Code().String() +} + +// FailRequiredChannelFeatureMissing is returned if the outgoing channel has a +// requirement advertised in its channel announcement features which were not +// present in the onion. +// +// NOTE: May only be returned by intermediate nodes. +type FailRequiredChannelFeatureMissing struct{} + +// Code returns the failure unique code. +// +// NOTE: Part of the FailureMessage interface. +func (f *FailRequiredChannelFeatureMissing) Code() FailCode { + return CodeRequiredChannelFeatureMissing +} + +// Returns a human readable string describing the target FailureMessage. +// +// NOTE: Implements the error interface. +func (f *FailRequiredChannelFeatureMissing) Error() string { + return f.Code().String() +} + +// FailUnknownNextPeer is returned if the next peer specified by the onion is +// not known. +// +// NOTE: May only be returned by intermediate nodes. +type FailUnknownNextPeer struct{} + +// Code returns the failure unique code. +// +// NOTE: Part of the FailureMessage interface. +func (f *FailUnknownNextPeer) Code() FailCode { + return CodeUnknownNextPeer +} + +// Returns a human readable string describing the target FailureMessage. +// +// NOTE: Implements the error interface. +func (f *FailUnknownNextPeer) Error() string { + return f.Code().String() +} + +// FailIncorrectPaymentAmount is returned if the amount paid is less than the +// amount expected, the final node MUST fail the HTLC. If the amount paid is +// more than twice the amount expected, the final node SHOULD fail the HTLC. +// This allows the sender to reduce information leakage by altering the amount, +// without allowing accidental gross overpayment. +// +// NOTE: May only be returned by the final node in the path. +type FailIncorrectPaymentAmount struct{} + +// Code returns the failure unique code. +// +// NOTE: Part of the FailureMessage interface. +func (f *FailIncorrectPaymentAmount) Code() FailCode { + return CodeIncorrectPaymentAmount +} + +// Returns a human readable string describing the target FailureMessage. +// +// NOTE: Implements the error interface. +func (f *FailIncorrectPaymentAmount) Error() string { + return f.Code().String() +} + +// FailIncorrectDetails is returned for two reasons: +// +// 1) if the payment hash has already been paid, the final node MAY treat the +// payment hash as unknown, or may succeed in accepting the HTLC. If the +// payment hash is unknown, the final node MUST fail the HTLC. +// +// 2) if the amount paid is less than the amount expected, the final node MUST +// fail the HTLC. If the amount paid is more than twice the amount expected, +// the final node SHOULD fail the HTLC. This allows the sender to reduce +// information leakage by altering the amount, without allowing accidental +// gross overpayment. +// +// NOTE: May only be returned by the final node in the path. +type FailIncorrectDetails struct { + // amount is the value of the extended HTLC. + amount MilliSatoshi + + // height is the block height when the htlc was received. + height uint32 +} + +// NewFailIncorrectDetails makes a new instance of the FailIncorrectDetails +// error bound to the specified HTLC amount and acceptance height. +func NewFailIncorrectDetails(amt MilliSatoshi, + height uint32) *FailIncorrectDetails { + + return &FailIncorrectDetails{ + amount: amt, + height: height, + } +} + +// Amount is the value of the extended HTLC. +func (f *FailIncorrectDetails) Amount() MilliSatoshi { + return f.amount +} + +// Height is the block height when the htlc was received. +func (f *FailIncorrectDetails) Height() uint32 { + return f.height +} + +// Code returns the failure unique code. +// +// NOTE: Part of the FailureMessage interface. +func (f *FailIncorrectDetails) Code() FailCode { + return CodeIncorrectOrUnknownPaymentDetails +} + +// Returns a human readable string describing the target FailureMessage. +// +// NOTE: Implements the error interface. +func (f *FailIncorrectDetails) Error() string { + return fmt.Sprintf( + "%v(amt=%v, height=%v)", CodeIncorrectOrUnknownPaymentDetails, + f.amount, f.height, + ) +} + +// Decode decodes the failure from bytes stream. +// +// NOTE: Part of the Serializable interface. +func (f *FailIncorrectDetails) Decode(r io.Reader, pver uint32) error { + err := ReadElement(r, &f.amount) + switch { + // This is an optional tack on that was added later in the protocol. As + // a result, older nodes may not include this value. We'll account for + // this by checking for io.EOF here which means that no bytes were read + // at all. + case err == io.EOF: + return nil + + case err != nil: + return err + } + + // At a later stage, the height field was also tacked on. We need to + // check for io.EOF here as well. + err = ReadElement(r, &f.height) + switch { + case err == io.EOF: + return nil + + case err != nil: + return err + } + + return nil +} + +// Encode writes the failure in bytes stream. +// +// NOTE: Part of the Serializable interface. +func (f *FailIncorrectDetails) Encode(w io.Writer, pver uint32) error { + return WriteElements(w, f.amount, f.height) +} + +// FailFinalExpiryTooSoon is returned if the cltv_expiry is too low, the final +// node MUST fail the HTLC. +// +// NOTE: May only be returned by the final node in the path. +type FailFinalExpiryTooSoon struct{} + +// Code returns the failure unique code. +// +// NOTE: Part of the FailureMessage interface. +func (f *FailFinalExpiryTooSoon) Code() FailCode { + return CodeFinalExpiryTooSoon +} + +// Returns a human readable string describing the target FailureMessage. +// +// NOTE: Implements the error interface. +func (f *FailFinalExpiryTooSoon) Error() string { + return f.Code().String() +} + +// NewFinalExpiryTooSoon creates new instance of the FailFinalExpiryTooSoon. +func NewFinalExpiryTooSoon() *FailFinalExpiryTooSoon { + return &FailFinalExpiryTooSoon{} +} + +// FailInvalidOnionVersion is returned if the onion version byte is unknown. +// +// NOTE: May be returned only by intermediate nodes. +type FailInvalidOnionVersion struct { + // OnionSHA256 hash of the onion blob which haven't been proceeded. + OnionSHA256 [sha256.Size]byte +} + +// Returns a human readable string describing the target FailureMessage. +// +// NOTE: Implements the error interface. +func (f *FailInvalidOnionVersion) Error() string { + return fmt.Sprintf("InvalidOnionVersion(onion_sha=%x)", f.OnionSHA256[:]) +} + +// NewInvalidOnionVersion creates new instance of the FailInvalidOnionVersion. +func NewInvalidOnionVersion(onion []byte) *FailInvalidOnionVersion { + return &FailInvalidOnionVersion{OnionSHA256: sha256.Sum256(onion)} +} + +// Code returns the failure unique code. +// +// NOTE: Part of the FailureMessage interface. +func (f *FailInvalidOnionVersion) Code() FailCode { + return CodeInvalidOnionVersion +} + +// Decode decodes the failure from bytes stream. +// +// NOTE: Part of the Serializable interface. +func (f *FailInvalidOnionVersion) Decode(r io.Reader, pver uint32) error { + return ReadElement(r, f.OnionSHA256[:]) +} + +// Encode writes the failure in bytes stream. +// +// NOTE: Part of the Serializable interface. +func (f *FailInvalidOnionVersion) Encode(w io.Writer, pver uint32) error { + return WriteElement(w, f.OnionSHA256[:]) +} + +// FailInvalidOnionHmac is return if the onion HMAC is incorrect. +// +// NOTE: May only be returned by intermediate nodes. +type FailInvalidOnionHmac struct { + // OnionSHA256 hash of the onion blob which haven't been proceeded. + OnionSHA256 [sha256.Size]byte +} + +// NewInvalidOnionHmac creates new instance of the FailInvalidOnionHmac. +func NewInvalidOnionHmac(onion []byte) *FailInvalidOnionHmac { + return &FailInvalidOnionHmac{OnionSHA256: sha256.Sum256(onion)} +} + +// Code returns the failure unique code. +// +// NOTE: Part of the FailureMessage interface. +func (f *FailInvalidOnionHmac) Code() FailCode { + return CodeInvalidOnionHmac +} + +// Decode decodes the failure from bytes stream. +// +// NOTE: Part of the Serializable interface. +func (f *FailInvalidOnionHmac) Decode(r io.Reader, pver uint32) error { + return ReadElement(r, f.OnionSHA256[:]) +} + +// Encode writes the failure in bytes stream. +// +// NOTE: Part of the Serializable interface. +func (f *FailInvalidOnionHmac) Encode(w io.Writer, pver uint32) error { + return WriteElement(w, f.OnionSHA256[:]) +} + +// Returns a human readable string describing the target FailureMessage. +// +// NOTE: Implements the error interface. +func (f *FailInvalidOnionHmac) Error() string { + return fmt.Sprintf("InvalidOnionHMAC(onion_sha=%x)", f.OnionSHA256[:]) +} + +// FailInvalidOnionKey is return if the ephemeral key in the onion is +// unparsable. +// +// NOTE: May only be returned by intermediate nodes. +type FailInvalidOnionKey struct { + // OnionSHA256 hash of the onion blob which haven't been proceeded. + OnionSHA256 [sha256.Size]byte +} + +// NewInvalidOnionKey creates new instance of the FailInvalidOnionKey. +func NewInvalidOnionKey(onion []byte) *FailInvalidOnionKey { + return &FailInvalidOnionKey{OnionSHA256: sha256.Sum256(onion)} +} + +// Code returns the failure unique code. +// +// NOTE: Part of the FailureMessage interface. +func (f *FailInvalidOnionKey) Code() FailCode { + return CodeInvalidOnionKey +} + +// Decode decodes the failure from bytes stream. +// +// NOTE: Part of the Serializable interface. +func (f *FailInvalidOnionKey) Decode(r io.Reader, pver uint32) error { + return ReadElement(r, f.OnionSHA256[:]) +} + +// Encode writes the failure in bytes stream. +// +// NOTE: Part of the Serializable interface. +func (f *FailInvalidOnionKey) Encode(w io.Writer, pver uint32) error { + return WriteElement(w, f.OnionSHA256[:]) +} + +// Returns a human readable string describing the target FailureMessage. +// +// NOTE: Implements the error interface. +func (f *FailInvalidOnionKey) Error() string { + return fmt.Sprintf("InvalidOnionKey(onion_sha=%x)", f.OnionSHA256[:]) +} + +// parseChannelUpdateCompatabilityMode will attempt to parse a channel updated +// encoded into an onion error payload in two ways. First, we'll try the +// compatibility oriented version wherein we'll _skip_ the length prefixing on +// the channel update message. Older versions of c-lighting do this so we'll +// attempt to parse these messages in order to retain compatibility. If we're +// unable to pull out a fully valid version, then we'll fall back to the +// regular parsing mechanism which includes the length prefix an NO type byte. +func parseChannelUpdateCompatabilityMode(r *bufio.Reader, + chanUpdate *ChannelUpdate, pver uint32) error { + + // We'll peek out two bytes from the buffer without advancing the + // buffer so we can decide how to parse the remainder of it. + maybeTypeBytes, err := r.Peek(2) + if err != nil { + return err + } + + // Some nodes well prefix an additional set of bytes in front of their + // channel updates. These bytes will _almost_ always be 258 or the type + // of the ChannelUpdate message. + typeInt := binary.BigEndian.Uint16(maybeTypeBytes) + if typeInt == MsgChannelUpdate { + // At this point it's likely the case that this is a channel + // update message with its type prefixed, so we'll snip off the + // first two bytes and parse it as normal. + var throwAwayTypeBytes [2]byte + _, err := r.Read(throwAwayTypeBytes[:]) + if err != nil { + return err + } + } + + // At this pint, we've either decided to keep the entire thing, or snip + // off the first two bytes. In either case, we can just read it as + // normal. + return chanUpdate.Decode(r, pver) +} + +// FailTemporaryChannelFailure is if an otherwise unspecified transient error +// occurs for the outgoing channel (eg. channel capacity reached, too many +// in-flight htlcs) +// +// NOTE: May only be returned by intermediate nodes. +type FailTemporaryChannelFailure struct { + // Update is used to update information about state of the channel + // which caused the failure. + // + // NOTE: This field is optional. + Update *ChannelUpdate +} + +// NewTemporaryChannelFailure creates new instance of the FailTemporaryChannelFailure. +func NewTemporaryChannelFailure(update *ChannelUpdate) *FailTemporaryChannelFailure { + return &FailTemporaryChannelFailure{Update: update} +} + +// Code returns the failure unique code. +// +// NOTE: Part of the FailureMessage interface. +func (f *FailTemporaryChannelFailure) Code() FailCode { + return CodeTemporaryChannelFailure +} + +// Returns a human readable string describing the target FailureMessage. +// +// NOTE: Implements the error interface. +func (f *FailTemporaryChannelFailure) Error() string { + if f.Update == nil { + return f.Code().String() + } + + return fmt.Sprintf("TemporaryChannelFailure(update=%v)", + spew.Sdump(f.Update)) +} + +// Decode decodes the failure from bytes stream. +// +// NOTE: Part of the Serializable interface. +func (f *FailTemporaryChannelFailure) Decode(r io.Reader, pver uint32) error { + var length uint16 + err := ReadElement(r, &length) + if err != nil { + return err + } + + if length != 0 { + f.Update = &ChannelUpdate{} + return parseChannelUpdateCompatabilityMode( + bufio.NewReader(r), f.Update, pver, + ) + } + + return nil +} + +// Encode writes the failure in bytes stream. +// +// NOTE: Part of the Serializable interface. +func (f *FailTemporaryChannelFailure) Encode(w io.Writer, pver uint32) error { + var payload []byte + if f.Update != nil { + var bw bytes.Buffer + if err := f.Update.Encode(&bw, pver); err != nil { + return err + } + payload = bw.Bytes() + } + + if err := WriteElement(w, uint16(len(payload))); err != nil { + return err + } + + _, err := w.Write(payload) + return err +} + +// FailAmountBelowMinimum is returned if the HTLC does not reach the current +// minimum amount, we tell them the amount of the incoming HTLC and the current +// channel setting for the outgoing channel. +// +// NOTE: May only be returned by the intermediate nodes in the path. +type FailAmountBelowMinimum struct { + // HtlcMsat is the wrong amount of the incoming HTLC. + HtlcMsat MilliSatoshi + + // Update is used to update information about state of the channel + // which caused the failure. + Update ChannelUpdate +} + +// NewAmountBelowMinimum creates new instance of the FailAmountBelowMinimum. +func NewAmountBelowMinimum(htlcMsat MilliSatoshi, + update ChannelUpdate) *FailAmountBelowMinimum { + + return &FailAmountBelowMinimum{ + HtlcMsat: htlcMsat, + Update: update, + } +} + +// Code returns the failure unique code. +// +// NOTE: Part of the FailureMessage interface. +func (f *FailAmountBelowMinimum) Code() FailCode { + return CodeAmountBelowMinimum +} + +// Returns a human readable string describing the target FailureMessage. +// +// NOTE: Implements the error interface. +func (f *FailAmountBelowMinimum) Error() string { + return fmt.Sprintf("AmountBelowMinimum(amt=%v, update=%v", f.HtlcMsat, + spew.Sdump(f.Update)) +} + +// Decode decodes the failure from bytes stream. +// +// NOTE: Part of the Serializable interface. +func (f *FailAmountBelowMinimum) Decode(r io.Reader, pver uint32) error { + if err := ReadElement(r, &f.HtlcMsat); err != nil { + return err + } + + var length uint16 + if err := ReadElement(r, &length); err != nil { + return err + } + + f.Update = ChannelUpdate{} + return parseChannelUpdateCompatabilityMode( + bufio.NewReader(r), &f.Update, pver, + ) +} + +// Encode writes the failure in bytes stream. +// +// NOTE: Part of the Serializable interface. +func (f *FailAmountBelowMinimum) Encode(w io.Writer, pver uint32) error { + if err := WriteElement(w, f.HtlcMsat); err != nil { + return err + } + + return writeOnionErrorChanUpdate(w, &f.Update, pver) +} + +// FailFeeInsufficient is returned if the HTLC does not pay sufficient fee, we +// tell them the amount of the incoming HTLC and the current channel setting +// for the outgoing channel. +// +// NOTE: May only be returned by intermediate nodes. +type FailFeeInsufficient struct { + // HtlcMsat is the wrong amount of the incoming HTLC. + HtlcMsat MilliSatoshi + + // Update is used to update information about state of the channel + // which caused the failure. + Update ChannelUpdate +} + +// NewFeeInsufficient creates new instance of the FailFeeInsufficient. +func NewFeeInsufficient(htlcMsat MilliSatoshi, + update ChannelUpdate) *FailFeeInsufficient { + return &FailFeeInsufficient{ + HtlcMsat: htlcMsat, + Update: update, + } +} + +// Code returns the failure unique code. +// +// NOTE: Part of the FailureMessage interface. +func (f *FailFeeInsufficient) Code() FailCode { + return CodeFeeInsufficient +} + +// Returns a human readable string describing the target FailureMessage. +// +// NOTE: Implements the error interface. +func (f *FailFeeInsufficient) Error() string { + return fmt.Sprintf("FeeInsufficient(htlc_amt==%v, update=%v", f.HtlcMsat, + spew.Sdump(f.Update)) +} + +// Decode decodes the failure from bytes stream. +// +// NOTE: Part of the Serializable interface. +func (f *FailFeeInsufficient) Decode(r io.Reader, pver uint32) error { + if err := ReadElement(r, &f.HtlcMsat); err != nil { + return err + } + + var length uint16 + if err := ReadElement(r, &length); err != nil { + return err + } + + f.Update = ChannelUpdate{} + return parseChannelUpdateCompatabilityMode( + bufio.NewReader(r), &f.Update, pver, + ) +} + +// Encode writes the failure in bytes stream. +// +// NOTE: Part of the Serializable interface. +func (f *FailFeeInsufficient) Encode(w io.Writer, pver uint32) error { + if err := WriteElement(w, f.HtlcMsat); err != nil { + return err + } + + return writeOnionErrorChanUpdate(w, &f.Update, pver) +} + +// FailIncorrectCltvExpiry is returned if outgoing cltv value does not match +// the update add htlc's cltv expiry minus cltv expiry delta for the outgoing +// channel, we tell them the cltv expiry and the current channel setting for +// the outgoing channel. +// +// NOTE: May only be returned by intermediate nodes. +type FailIncorrectCltvExpiry struct { + // CltvExpiry is the wrong absolute timeout in blocks, after which + // outgoing HTLC expires. + CltvExpiry uint32 + + // Update is used to update information about state of the channel + // which caused the failure. + Update ChannelUpdate +} + +// NewIncorrectCltvExpiry creates new instance of the FailIncorrectCltvExpiry. +func NewIncorrectCltvExpiry(cltvExpiry uint32, + update ChannelUpdate) *FailIncorrectCltvExpiry { + + return &FailIncorrectCltvExpiry{ + CltvExpiry: cltvExpiry, + Update: update, + } +} + +// Code returns the failure unique code. +// +// NOTE: Part of the FailureMessage interface. +func (f *FailIncorrectCltvExpiry) Code() FailCode { + return CodeIncorrectCltvExpiry +} + +func (f *FailIncorrectCltvExpiry) Error() string { + return fmt.Sprintf("IncorrectCltvExpiry(expiry=%v, update=%v", + f.CltvExpiry, spew.Sdump(f.Update)) +} + +// Decode decodes the failure from bytes stream. +// +// NOTE: Part of the Serializable interface. +func (f *FailIncorrectCltvExpiry) Decode(r io.Reader, pver uint32) error { + if err := ReadElement(r, &f.CltvExpiry); err != nil { + return err + } + + var length uint16 + if err := ReadElement(r, &length); err != nil { + return err + } + + f.Update = ChannelUpdate{} + return parseChannelUpdateCompatabilityMode( + bufio.NewReader(r), &f.Update, pver, + ) +} + +// Encode writes the failure in bytes stream. +// +// NOTE: Part of the Serializable interface. +func (f *FailIncorrectCltvExpiry) Encode(w io.Writer, pver uint32) error { + if err := WriteElement(w, f.CltvExpiry); err != nil { + return err + } + + return writeOnionErrorChanUpdate(w, &f.Update, pver) +} + +// FailExpiryTooSoon is returned if the ctlv-expiry is too near, we tell them +// the current channel setting for the outgoing channel. +// +// NOTE: May only be returned by intermediate nodes. +type FailExpiryTooSoon struct { + // Update is used to update information about state of the channel + // which caused the failure. + Update ChannelUpdate +} + +// NewExpiryTooSoon creates new instance of the FailExpiryTooSoon. +func NewExpiryTooSoon(update ChannelUpdate) *FailExpiryTooSoon { + return &FailExpiryTooSoon{ + Update: update, + } +} + +// Code returns the failure unique code. +// +// NOTE: Part of the FailureMessage interface. +func (f *FailExpiryTooSoon) Code() FailCode { + return CodeExpiryTooSoon +} + +// Returns a human readable string describing the target FailureMessage. +// +// NOTE: Implements the error interface. +func (f *FailExpiryTooSoon) Error() string { + return fmt.Sprintf("ExpiryTooSoon(update=%v", spew.Sdump(f.Update)) +} + +// Decode decodes the failure from l stream. +// +// NOTE: Part of the Serializable interface. +func (f *FailExpiryTooSoon) Decode(r io.Reader, pver uint32) error { + var length uint16 + if err := ReadElement(r, &length); err != nil { + return err + } + + f.Update = ChannelUpdate{} + return parseChannelUpdateCompatabilityMode( + bufio.NewReader(r), &f.Update, pver, + ) +} + +// Encode writes the failure in bytes stream. +// +// NOTE: Part of the Serializable interface. +func (f *FailExpiryTooSoon) Encode(w io.Writer, pver uint32) error { + return writeOnionErrorChanUpdate(w, &f.Update, pver) +} + +// FailChannelDisabled is returned if the channel is disabled, we tell them the +// current channel setting for the outgoing channel. +// +// NOTE: May only be returned by intermediate nodes. +type FailChannelDisabled struct { + // Flags least-significant bit must be set to 0 if the creating node + // corresponds to the first node in the previously sent channel + // announcement and 1 otherwise. + Flags uint16 + + // Update is used to update information about state of the channel + // which caused the failure. + Update ChannelUpdate +} + +// NewChannelDisabled creates new instance of the FailChannelDisabled. +func NewChannelDisabled(flags uint16, update ChannelUpdate) *FailChannelDisabled { + return &FailChannelDisabled{ + Flags: flags, + Update: update, + } +} + +// Code returns the failure unique code. +// +// NOTE: Part of the FailureMessage interface. +func (f *FailChannelDisabled) Code() FailCode { + return CodeChannelDisabled +} + +// Returns a human readable string describing the target FailureMessage. +// +// NOTE: Implements the error interface. +func (f *FailChannelDisabled) Error() string { + return fmt.Sprintf("ChannelDisabled(flags=%v, update=%v", f.Flags, + spew.Sdump(f.Update)) +} + +// Decode decodes the failure from bytes stream. +// +// NOTE: Part of the Serializable interface. +func (f *FailChannelDisabled) Decode(r io.Reader, pver uint32) error { + if err := ReadElement(r, &f.Flags); err != nil { + return err + } + + var length uint16 + if err := ReadElement(r, &length); err != nil { + return err + } + + f.Update = ChannelUpdate{} + return parseChannelUpdateCompatabilityMode( + bufio.NewReader(r), &f.Update, pver, + ) +} + +// Encode writes the failure in bytes stream. +// +// NOTE: Part of the Serializable interface. +func (f *FailChannelDisabled) Encode(w io.Writer, pver uint32) error { + if err := WriteElement(w, f.Flags); err != nil { + return err + } + + return writeOnionErrorChanUpdate(w, &f.Update, pver) +} + +// FailFinalIncorrectCltvExpiry is returned if the outgoing_cltv_value does not +// match the ctlv_expiry of the HTLC at the final hop. +// +// NOTE: might be returned by final node only. +type FailFinalIncorrectCltvExpiry struct { + // CltvExpiry is the wrong absolute timeout in blocks, after which + // outgoing HTLC expires. + CltvExpiry uint32 +} + +// Returns a human readable string describing the target FailureMessage. +// +// NOTE: Implements the error interface. +func (f *FailFinalIncorrectCltvExpiry) Error() string { + return fmt.Sprintf("FinalIncorrectCltvExpiry(expiry=%v)", f.CltvExpiry) +} + +// NewFinalIncorrectCltvExpiry creates new instance of the +// FailFinalIncorrectCltvExpiry. +func NewFinalIncorrectCltvExpiry(cltvExpiry uint32) *FailFinalIncorrectCltvExpiry { + return &FailFinalIncorrectCltvExpiry{ + CltvExpiry: cltvExpiry, + } +} + +// Code returns the failure unique code. +// +// NOTE: Part of the FailureMessage interface. +func (f *FailFinalIncorrectCltvExpiry) Code() FailCode { + return CodeFinalIncorrectCltvExpiry +} + +// Decode decodes the failure from bytes stream. +// +// NOTE: Part of the Serializable interface. +func (f *FailFinalIncorrectCltvExpiry) Decode(r io.Reader, pver uint32) error { + return ReadElement(r, &f.CltvExpiry) +} + +// Encode writes the failure in bytes stream. +// +// NOTE: Part of the Serializable interface. +func (f *FailFinalIncorrectCltvExpiry) Encode(w io.Writer, pver uint32) error { + return WriteElement(w, f.CltvExpiry) +} + +// FailFinalIncorrectHtlcAmount is returned if the amt_to_forward is higher +// than incoming_htlc_amt of the HTLC at the final hop. +// +// NOTE: May only be returned by the final node. +type FailFinalIncorrectHtlcAmount struct { + // IncomingHTLCAmount is the wrong forwarded htlc amount. + IncomingHTLCAmount MilliSatoshi +} + +// Returns a human readable string describing the target FailureMessage. +// +// NOTE: Implements the error interface. +func (f *FailFinalIncorrectHtlcAmount) Error() string { + return fmt.Sprintf("FinalIncorrectHtlcAmount(amt=%v)", + f.IncomingHTLCAmount) +} + +// NewFinalIncorrectHtlcAmount creates new instance of the +// FailFinalIncorrectHtlcAmount. +func NewFinalIncorrectHtlcAmount(amount MilliSatoshi) *FailFinalIncorrectHtlcAmount { + return &FailFinalIncorrectHtlcAmount{ + IncomingHTLCAmount: amount, + } +} + +// Code returns the failure unique code. +// +// NOTE: Part of the FailureMessage interface. +func (f *FailFinalIncorrectHtlcAmount) Code() FailCode { + return CodeFinalIncorrectHtlcAmount +} + +// Decode decodes the failure from bytes stream. +// +// NOTE: Part of the Serializable interface. +func (f *FailFinalIncorrectHtlcAmount) Decode(r io.Reader, pver uint32) error { + return ReadElement(r, &f.IncomingHTLCAmount) +} + +// Encode writes the failure in bytes stream. +// +// NOTE: Part of the Serializable interface. +func (f *FailFinalIncorrectHtlcAmount) Encode(w io.Writer, pver uint32) error { + return WriteElement(w, f.IncomingHTLCAmount) +} + +// FailExpiryTooFar is returned if the CLTV expiry in the HTLC is too far in the +// future. +// +// NOTE: May be returned by any node in the payment route. +type FailExpiryTooFar struct{} + +// Code returns the failure unique code. +// +// NOTE: Part of the FailureMessage interface. +func (f *FailExpiryTooFar) Code() FailCode { + return CodeExpiryTooFar +} + +// Returns a human readable string describing the target FailureMessage. +// +// NOTE: Implements the error interface. +func (f *FailExpiryTooFar) Error() string { + return f.Code().String() +} + +// InvalidOnionPayload is returned if the hop could not process the TLV payload +// enclosed in the onion. +type InvalidOnionPayload struct { + // Type is the TLV type that caused the specific failure. + Type uint64 + + // Offset is the byte offset within the payload where the failure + // occurred. + Offset uint16 +} + +// NewInvalidOnionPayload initializes a new InvalidOnionPayload failure. +func NewInvalidOnionPayload(typ uint64, offset uint16) *InvalidOnionPayload { + return &InvalidOnionPayload{ + Type: typ, + Offset: offset, + } +} + +// Code returns the failure unique code. +// +// NOTE: Part of the FailureMessage interface. +func (f *InvalidOnionPayload) Code() FailCode { + return CodeInvalidOnionPayload +} + +// Returns a human readable string describing the target FailureMessage. +// +// NOTE: Implements the error interface. +func (f *InvalidOnionPayload) Error() string { + return fmt.Sprintf("%v(type=%v, offset=%d)", + f.Code(), f.Type, f.Offset) +} + +// Decode decodes the failure from bytes stream. +// +// NOTE: Part of the Serializable interface. +func (f *InvalidOnionPayload) Decode(r io.Reader, pver uint32) error { + var buf [8]byte + typ, err := tlv.ReadVarInt(r, &buf) + if err != nil { + return err + } + f.Type = typ + + return ReadElements(r, &f.Offset) +} + +// Encode writes the failure in bytes stream. +// +// NOTE: Part of the Serializable interface. +func (f *InvalidOnionPayload) Encode(w io.Writer, pver uint32) error { + var buf [8]byte + if err := tlv.WriteVarInt(w, f.Type, &buf); err != nil { + return err + } + + return WriteElements(w, f.Offset) +} + +// FailMPPTimeout is returned if the complete amount for a multi part payment +// was not received within a reasonable time. +// +// NOTE: May only be returned by the final node in the path. +type FailMPPTimeout struct{} + +// Code returns the failure unique code. +// +// NOTE: Part of the FailureMessage interface. +func (f *FailMPPTimeout) Code() FailCode { + return CodeMPPTimeout +} + +// Returns a human readable string describing the target FailureMessage. +// +// NOTE: Implements the error interface. +func (f *FailMPPTimeout) Error() string { + return f.Code().String() +} + +// DecodeFailure decodes, validates, and parses the lnwire onion failure, for +// the provided protocol version. +func DecodeFailure(r io.Reader, pver uint32) (FailureMessage, error) { + // First, we'll parse out the encapsulated failure message itself. This + // is a 2 byte length followed by the payload itself. + var failureLength uint16 + if err := ReadElement(r, &failureLength); err != nil { + return nil, fmt.Errorf("unable to read error len: %v", err) + } + if failureLength > FailureMessageLength { + return nil, fmt.Errorf("failure message is too "+ + "long: %v", failureLength) + } + failureData := make([]byte, failureLength) + if _, err := io.ReadFull(r, failureData); err != nil { + return nil, fmt.Errorf("unable to full read payload of "+ + "%v: %v", failureLength, err) + } + + dataReader := bytes.NewReader(failureData) + + return DecodeFailureMessage(dataReader, pver) +} + +// DecodeFailureMessage decodes just the failure message, ignoring any padding +// that may be present at the end. +func DecodeFailureMessage(r io.Reader, pver uint32) (FailureMessage, error) { + // Once we have the failure data, we can obtain the failure code from + // the first two bytes of the buffer. + var codeBytes [2]byte + if _, err := io.ReadFull(r, codeBytes[:]); err != nil { + return nil, fmt.Errorf("unable to read failure code: %v", err) + } + failCode := FailCode(binary.BigEndian.Uint16(codeBytes[:])) + + // Create the empty failure by given code and populate the failure with + // additional data if needed. + failure, err := makeEmptyOnionError(failCode) + if err != nil { + return nil, fmt.Errorf("unable to make empty error: %v", err) + } + + // Finally, if this failure has a payload, then we'll read that now as + // well. + switch f := failure.(type) { + case Serializable: + if err := f.Decode(r, pver); err != nil { + return nil, fmt.Errorf("unable to decode error "+ + "update (type=%T): %v", failure, err) + } + } + + return failure, nil +} + +// EncodeFailure encodes, including the necessary onion failure header +// information. +func EncodeFailure(w io.Writer, failure FailureMessage, pver uint32) error { + var failureMessageBuffer bytes.Buffer + + err := EncodeFailureMessage(&failureMessageBuffer, failure, pver) + if err != nil { + return err + } + + // The combined size of this message must be below the max allowed + // failure message length. + failureMessage := failureMessageBuffer.Bytes() + if len(failureMessage) > FailureMessageLength { + return fmt.Errorf("failure message exceed max "+ + "available size: %v", len(failureMessage)) + } + + // Finally, we'll add some padding in order to ensure that all failure + // messages are fixed size. + pad := make([]byte, FailureMessageLength-len(failureMessage)) + + return WriteElements(w, + uint16(len(failureMessage)), + failureMessage, + uint16(len(pad)), + pad, + ) +} + +// EncodeFailureMessage encodes just the failure message without adding a length +// and padding the message for the onion protocol. +func EncodeFailureMessage(w io.Writer, failure FailureMessage, pver uint32) error { + // First, we'll write out the error code itself into the failure + // buffer. + var codeBytes [2]byte + code := uint16(failure.Code()) + binary.BigEndian.PutUint16(codeBytes[:], code) + _, err := w.Write(codeBytes[:]) + if err != nil { + return err + } + + // Next, some message have an additional message payload, if this is + // one of those types, then we'll also encode the error payload as + // well. + switch failure := failure.(type) { + case Serializable: + if err := failure.Encode(w, pver); err != nil { + return err + } + } + + return nil +} + +// makeEmptyOnionError creates a new empty onion error of the proper concrete +// type based on the passed failure code. +func makeEmptyOnionError(code FailCode) (FailureMessage, error) { + switch code { + case CodeInvalidRealm: + return &FailInvalidRealm{}, nil + + case CodeTemporaryNodeFailure: + return &FailTemporaryNodeFailure{}, nil + + case CodePermanentNodeFailure: + return &FailPermanentNodeFailure{}, nil + + case CodeRequiredNodeFeatureMissing: + return &FailRequiredNodeFeatureMissing{}, nil + + case CodePermanentChannelFailure: + return &FailPermanentChannelFailure{}, nil + + case CodeRequiredChannelFeatureMissing: + return &FailRequiredChannelFeatureMissing{}, nil + + case CodeUnknownNextPeer: + return &FailUnknownNextPeer{}, nil + + case CodeIncorrectOrUnknownPaymentDetails: + return &FailIncorrectDetails{}, nil + + case CodeIncorrectPaymentAmount: + return &FailIncorrectPaymentAmount{}, nil + + case CodeFinalExpiryTooSoon: + return &FailFinalExpiryTooSoon{}, nil + + case CodeInvalidOnionVersion: + return &FailInvalidOnionVersion{}, nil + + case CodeInvalidOnionHmac: + return &FailInvalidOnionHmac{}, nil + + case CodeInvalidOnionKey: + return &FailInvalidOnionKey{}, nil + + case CodeTemporaryChannelFailure: + return &FailTemporaryChannelFailure{}, nil + + case CodeAmountBelowMinimum: + return &FailAmountBelowMinimum{}, nil + + case CodeFeeInsufficient: + return &FailFeeInsufficient{}, nil + + case CodeIncorrectCltvExpiry: + return &FailIncorrectCltvExpiry{}, nil + + case CodeExpiryTooSoon: + return &FailExpiryTooSoon{}, nil + + case CodeChannelDisabled: + return &FailChannelDisabled{}, nil + + case CodeFinalIncorrectCltvExpiry: + return &FailFinalIncorrectCltvExpiry{}, nil + + case CodeFinalIncorrectHtlcAmount: + return &FailFinalIncorrectHtlcAmount{}, nil + + case CodeExpiryTooFar: + return &FailExpiryTooFar{}, nil + + case CodeInvalidOnionPayload: + return &InvalidOnionPayload{}, nil + + case CodeMPPTimeout: + return &FailMPPTimeout{}, nil + + default: + return nil, errors.Errorf("unknown error code: %v", code) + } +} + +// writeOnionErrorChanUpdate writes out a ChannelUpdate using the onion error +// format. The format is that we first write out the true serialized length of +// the channel update, followed by the serialized channel update itself. +func writeOnionErrorChanUpdate(w io.Writer, chanUpdate *ChannelUpdate, + pver uint32) error { + + // First, we encode the channel update in a temporary buffer in order + // to get the exact serialized size. + var b bytes.Buffer + if err := chanUpdate.Encode(&b, pver); err != nil { + return err + } + + // Now that we know the size, we can write the length out in the main + // writer. + updateLen := b.Len() + if err := WriteElement(w, uint16(updateLen)); err != nil { + return err + } + + // With the length written, we'll then write out the serialized channel + // update. + if _, err := w.Write(b.Bytes()); err != nil { + return err + } + + return nil +} diff --git a/channeldb/migration/lnwire21/open_channel.go b/channeldb/migration/lnwire21/open_channel.go new file mode 100644 index 00000000..a165ef75 --- /dev/null +++ b/channeldb/migration/lnwire21/open_channel.go @@ -0,0 +1,225 @@ +package lnwire + +import ( + "io" + + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcutil" +) + +// FundingFlag represents the possible bit mask values for the ChannelFlags +// field within the OpenChannel struct. +type FundingFlag uint8 + +const ( + // FFAnnounceChannel is a FundingFlag that when set, indicates the + // initiator of a funding flow wishes to announce the channel to the + // greater network. + FFAnnounceChannel FundingFlag = 1 << iota +) + +// OpenChannel is the message Alice sends to Bob if we should like to create a +// channel with Bob where she's the sole provider of funds to the channel. +// Single funder channels simplify the initial funding workflow, are supported +// by nodes backed by SPV Bitcoin clients, and have a simpler security models +// than dual funded channels. +type OpenChannel struct { + // ChainHash is the target chain that the initiator wishes to open a + // channel within. + ChainHash chainhash.Hash + + // PendingChannelID serves to uniquely identify the future channel + // created by the initiated single funder workflow. + PendingChannelID [32]byte + + // FundingAmount is the amount of satoshis that the initiator of the + // channel wishes to use as the total capacity of the channel. The + // initial balance of the funding will be this value minus the push + // amount (if set). + FundingAmount btcutil.Amount + + // PushAmount is the value that the initiating party wishes to "push" + // to the responding as part of the first commitment state. If the + // responder accepts, then this will be their initial balance. + PushAmount MilliSatoshi + + // DustLimit is the specific dust limit the sender of this message + // would like enforced on their version of the commitment transaction. + // Any output below this value will be "trimmed" from the commitment + // transaction, with the amount of the HTLC going to dust. + DustLimit btcutil.Amount + + // MaxValueInFlight represents the maximum amount of coins that can be + // pending within the channel at any given time. If the amount of funds + // in limbo exceeds this amount, then the channel will be failed. + MaxValueInFlight MilliSatoshi + + // ChannelReserve is the amount of BTC that the receiving party MUST + // maintain a balance above at all times. This is a safety mechanism to + // ensure that both sides always have skin in the game during the + // channel's lifetime. + ChannelReserve btcutil.Amount + + // HtlcMinimum is the smallest HTLC that the sender of this message + // will accept. + HtlcMinimum MilliSatoshi + + // FeePerKiloWeight is the initial fee rate that the initiator suggests + // for both commitment transaction. This value is expressed in sat per + // kilo-weight. + // + // TODO(halseth): make SatPerKWeight when fee estimation is in own + // package. Currently this will cause an import cycle. + FeePerKiloWeight uint32 + + // CsvDelay is the number of blocks to use for the relative time lock + // in the pay-to-self output of both commitment transactions. + CsvDelay uint16 + + // MaxAcceptedHTLCs is the total number of incoming HTLC's that the + // sender of this channel will accept. + MaxAcceptedHTLCs uint16 + + // FundingKey is the key that should be used on behalf of the sender + // within the 2-of-2 multi-sig output that it contained within the + // funding transaction. + FundingKey *btcec.PublicKey + + // RevocationPoint is the base revocation point for the sending party. + // Any commitment transaction belonging to the receiver of this message + // should use this key and their per-commitment point to derive the + // revocation key for the commitment transaction. + RevocationPoint *btcec.PublicKey + + // PaymentPoint is the base payment point for the sending party. This + // key should be combined with the per commitment point for a + // particular commitment state in order to create the key that should + // be used in any output that pays directly to the sending party, and + // also within the HTLC covenant transactions. + PaymentPoint *btcec.PublicKey + + // DelayedPaymentPoint is the delay point for the sending party. This + // key should be combined with the per commitment point to derive the + // keys that are used in outputs of the sender's commitment transaction + // where they claim funds. + DelayedPaymentPoint *btcec.PublicKey + + // HtlcPoint is the base point used to derive the set of keys for this + // party that will be used within the HTLC public key scripts. This + // value is combined with the receiver's revocation base point in order + // to derive the keys that are used within HTLC scripts. + HtlcPoint *btcec.PublicKey + + // FirstCommitmentPoint is the first commitment point for the sending + // party. This value should be combined with the receiver's revocation + // base point in order to derive the revocation keys that are placed + // within the commitment transaction of the sender. + FirstCommitmentPoint *btcec.PublicKey + + // ChannelFlags is a bit-field which allows the initiator of the + // channel to specify further behavior surrounding the channel. + // Currently, the least significant bit of this bit field indicates the + // initiator of the channel wishes to advertise this channel publicly. + ChannelFlags FundingFlag + + // UpfrontShutdownScript is the script to which the channel funds should + // be paid when mutually closing the channel. This field is optional, and + // and has a length prefix, so a zero will be written if it is not set + // and its length followed by the script will be written if it is set. + UpfrontShutdownScript DeliveryAddress +} + +// A compile time check to ensure OpenChannel implements the lnwire.Message +// interface. +var _ Message = (*OpenChannel)(nil) + +// Encode serializes the target OpenChannel into the passed io.Writer +// implementation. Serialization will observe the rules defined by the passed +// protocol version. +// +// This is part of the lnwire.Message interface. +func (o *OpenChannel) Encode(w io.Writer, pver uint32) error { + return WriteElements(w, + o.ChainHash[:], + o.PendingChannelID[:], + o.FundingAmount, + o.PushAmount, + o.DustLimit, + o.MaxValueInFlight, + o.ChannelReserve, + o.HtlcMinimum, + o.FeePerKiloWeight, + o.CsvDelay, + o.MaxAcceptedHTLCs, + o.FundingKey, + o.RevocationPoint, + o.PaymentPoint, + o.DelayedPaymentPoint, + o.HtlcPoint, + o.FirstCommitmentPoint, + o.ChannelFlags, + o.UpfrontShutdownScript, + ) +} + +// Decode deserializes the serialized OpenChannel stored in the passed +// io.Reader into the target OpenChannel using the deserialization rules +// defined by the passed protocol version. +// +// This is part of the lnwire.Message interface. +func (o *OpenChannel) Decode(r io.Reader, pver uint32) error { + if err := ReadElements(r, + o.ChainHash[:], + o.PendingChannelID[:], + &o.FundingAmount, + &o.PushAmount, + &o.DustLimit, + &o.MaxValueInFlight, + &o.ChannelReserve, + &o.HtlcMinimum, + &o.FeePerKiloWeight, + &o.CsvDelay, + &o.MaxAcceptedHTLCs, + &o.FundingKey, + &o.RevocationPoint, + &o.PaymentPoint, + &o.DelayedPaymentPoint, + &o.HtlcPoint, + &o.FirstCommitmentPoint, + &o.ChannelFlags, + ); err != nil { + return err + } + + // Check for the optional upfront shutdown script field. If it is not there, + // silence the EOF error. + err := ReadElement(r, &o.UpfrontShutdownScript) + if err != nil && err != io.EOF { + return err + } + + return nil +} + +// MsgType returns the MessageType code which uniquely identifies this message +// as an OpenChannel on the wire. +// +// This is part of the lnwire.Message interface. +func (o *OpenChannel) MsgType() MessageType { + return MsgOpenChannel +} + +// MaxPayloadLength returns the maximum allowed payload length for a +// OpenChannel message. +// +// This is part of the lnwire.Message interface. +func (o *OpenChannel) MaxPayloadLength(uint32) uint32 { + // (32 * 2) + (8 * 6) + (4 * 1) + (2 * 2) + (33 * 6) + 1 + var length uint32 = 319 // base length + + // Upfront shutdown script max length. + length += 2 + deliveryAddressMaxSize + + return length +} diff --git a/channeldb/migration/lnwire21/ping.go b/channeldb/migration/lnwire21/ping.go new file mode 100644 index 00000000..cf9a83b7 --- /dev/null +++ b/channeldb/migration/lnwire21/ping.go @@ -0,0 +1,67 @@ +package lnwire + +import "io" + +// PingPayload is a set of opaque bytes used to pad out a ping message. +type PingPayload []byte + +// Ping defines a message which is sent by peers periodically to determine if +// the connection is still valid. Each ping message carries the number of bytes +// to pad the pong response with, and also a number of bytes to be ignored at +// the end of the ping message (which is padding). +type Ping struct { + // NumPongBytes is the number of bytes the pong response to this + // message should carry. + NumPongBytes uint16 + + // PaddingBytes is a set of opaque bytes used to pad out this ping + // message. Using this field in conjunction to the one above, it's + // possible for node to generate fake cover traffic. + PaddingBytes PingPayload +} + +// NewPing returns a new Ping message. +func NewPing(numBytes uint16) *Ping { + return &Ping{ + NumPongBytes: numBytes, + } +} + +// A compile time check to ensure Ping implements the lnwire.Message interface. +var _ Message = (*Ping)(nil) + +// Decode deserializes a serialized Ping message stored in the passed io.Reader +// observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (p *Ping) Decode(r io.Reader, pver uint32) error { + return ReadElements(r, + &p.NumPongBytes, + &p.PaddingBytes) +} + +// Encode serializes the target Ping into the passed io.Writer observing the +// protocol version specified. +// +// This is part of the lnwire.Message interface. +func (p *Ping) Encode(w io.Writer, pver uint32) error { + return WriteElements(w, + p.NumPongBytes, + p.PaddingBytes) +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (p *Ping) MsgType() MessageType { + return MsgPing +} + +// MaxPayloadLength returns the maximum allowed payload size for a Ping +// complete message observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (p Ping) MaxPayloadLength(uint32) uint32 { + return 65532 +} diff --git a/channeldb/migration/lnwire21/pong.go b/channeldb/migration/lnwire21/pong.go new file mode 100644 index 00000000..c3166aaf --- /dev/null +++ b/channeldb/migration/lnwire21/pong.go @@ -0,0 +1,63 @@ +package lnwire + +import "io" + +// PongPayload is a set of opaque bytes sent in response to a ping message. +type PongPayload []byte + +// Pong defines a message which is the direct response to a received Ping +// message. A Pong reply indicates that a connection is still active. The Pong +// reply to a Ping message should contain the nonce carried in the original +// Pong message. +type Pong struct { + // PongBytes is a set of opaque bytes that corresponds to the + // NumPongBytes defined in the ping message that this pong is + // replying to. + PongBytes PongPayload +} + +// NewPong returns a new Pong message. +func NewPong(pongBytes []byte) *Pong { + return &Pong{ + PongBytes: pongBytes, + } +} + +// A compile time check to ensure Pong implements the lnwire.Message interface. +var _ Message = (*Pong)(nil) + +// Decode deserializes a serialized Pong message stored in the passed io.Reader +// observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (p *Pong) Decode(r io.Reader, pver uint32) error { + return ReadElements(r, + &p.PongBytes, + ) +} + +// Encode serializes the target Pong into the passed io.Writer observing the +// protocol version specified. +// +// This is part of the lnwire.Message interface. +func (p *Pong) Encode(w io.Writer, pver uint32) error { + return WriteElements(w, + p.PongBytes, + ) +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (p *Pong) MsgType() MessageType { + return MsgPong +} + +// MaxPayloadLength returns the maximum allowed payload size for a Pong +// complete message observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (p *Pong) MaxPayloadLength(uint32) uint32 { + return 65532 +} diff --git a/channeldb/migration/lnwire21/query_channel_range.go b/channeldb/migration/lnwire21/query_channel_range.go new file mode 100644 index 00000000..9546fcd3 --- /dev/null +++ b/channeldb/migration/lnwire21/query_channel_range.go @@ -0,0 +1,89 @@ +package lnwire + +import ( + "io" + "math" + + "github.com/btcsuite/btcd/chaincfg/chainhash" +) + +// QueryChannelRange is a message sent by a node in order to query the +// receiving node of the set of open channel they know of with short channel +// ID's after the specified block height, capped at the number of blocks beyond +// that block height. This will be used by nodes upon initial connect to +// synchronize their views of the network. +type QueryChannelRange struct { + // ChainHash denotes the target chain that we're trying to synchronize + // channel graph state for. + ChainHash chainhash.Hash + + // FirstBlockHeight is the first block in the query range. The + // responder should send all new short channel IDs from this block + // until this block plus the specified number of blocks. + FirstBlockHeight uint32 + + // NumBlocks is the number of blocks beyond the first block that short + // channel ID's should be sent for. + NumBlocks uint32 +} + +// NewQueryChannelRange creates a new empty QueryChannelRange message. +func NewQueryChannelRange() *QueryChannelRange { + return &QueryChannelRange{} +} + +// A compile time check to ensure QueryChannelRange implements the +// lnwire.Message interface. +var _ Message = (*QueryChannelRange)(nil) + +// Decode deserializes a serialized QueryChannelRange message stored in the +// passed io.Reader observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (q *QueryChannelRange) Decode(r io.Reader, pver uint32) error { + return ReadElements(r, + q.ChainHash[:], + &q.FirstBlockHeight, + &q.NumBlocks, + ) +} + +// Encode serializes the target QueryChannelRange into the passed io.Writer +// observing the protocol version specified. +// +// This is part of the lnwire.Message interface. +func (q *QueryChannelRange) Encode(w io.Writer, pver uint32) error { + return WriteElements(w, + q.ChainHash[:], + q.FirstBlockHeight, + q.NumBlocks, + ) +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (q *QueryChannelRange) MsgType() MessageType { + return MsgQueryChannelRange +} + +// MaxPayloadLength returns the maximum allowed payload size for a +// QueryChannelRange complete message observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (q *QueryChannelRange) MaxPayloadLength(uint32) uint32 { + // 32 + 4 + 4 + return 40 +} + +// LastBlockHeight returns the last block height covered by the range of a +// QueryChannelRange message. +func (q *QueryChannelRange) LastBlockHeight() uint32 { + // Handle overflows by casting to uint64. + lastBlockHeight := uint64(q.FirstBlockHeight) + uint64(q.NumBlocks) - 1 + if lastBlockHeight > math.MaxUint32 { + return math.MaxUint32 + } + return uint32(lastBlockHeight) +} diff --git a/channeldb/migration/lnwire21/query_short_chan_ids.go b/channeldb/migration/lnwire21/query_short_chan_ids.go new file mode 100644 index 00000000..3c2b9948 --- /dev/null +++ b/channeldb/migration/lnwire21/query_short_chan_ids.go @@ -0,0 +1,429 @@ +package lnwire + +import ( + "bytes" + "compress/zlib" + "fmt" + "io" + "sort" + "sync" + + "github.com/btcsuite/btcd/chaincfg/chainhash" +) + +// ShortChanIDEncoding is an enum-like type that represents exactly how a set +// of short channel ID's is encoded on the wire. The set of encodings allows us +// to take advantage of the structure of a list of short channel ID's to +// achieving a high degree of compression. +type ShortChanIDEncoding uint8 + +const ( + // EncodingSortedPlain signals that the set of short channel ID's is + // encoded using the regular encoding, in a sorted order. + EncodingSortedPlain ShortChanIDEncoding = 0 + + // EncodingSortedZlib signals that the set of short channel ID's is + // encoded by first sorting the set of channel ID's, as then + // compressing them using zlib. + EncodingSortedZlib ShortChanIDEncoding = 1 +) + +const ( + // maxZlibBufSize is the max number of bytes that we'll accept from a + // zlib decoding instance. We do this in order to limit the total + // amount of memory allocated during a decoding instance. + maxZlibBufSize = 67413630 +) + +// ErrUnsortedSIDs is returned when decoding a QueryShortChannelID request whose +// items were not sorted. +type ErrUnsortedSIDs struct { + prevSID ShortChannelID + curSID ShortChannelID +} + +// Error returns a human-readable description of the error. +func (e ErrUnsortedSIDs) Error() string { + return fmt.Sprintf("current sid: %v isn't greater than last sid: %v", + e.curSID, e.prevSID) +} + +// zlibDecodeMtx is a package level mutex that we'll use in order to ensure +// that we'll only attempt a single zlib decoding instance at a time. This +// allows us to also further bound our memory usage. +var zlibDecodeMtx sync.Mutex + +// ErrUnknownShortChanIDEncoding is a parametrized error that indicates that we +// came across an unknown short channel ID encoding, and therefore were unable +// to continue parsing. +func ErrUnknownShortChanIDEncoding(encoding ShortChanIDEncoding) error { + return fmt.Errorf("unknown short chan id encoding: %v", encoding) +} + +// QueryShortChanIDs is a message that allows the sender to query a set of +// channel announcement and channel update messages that correspond to the set +// of encoded short channel ID's. The encoding of the short channel ID's is +// detailed in the query message ensuring that the receiver knows how to +// properly decode each encode short channel ID which may be encoded using a +// compression format. The receiver should respond with a series of channel +// announcement and channel updates, finally sending a ReplyShortChanIDsEnd +// message. +type QueryShortChanIDs struct { + // ChainHash denotes the target chain that we're querying for the + // channel ID's of. + ChainHash chainhash.Hash + + // EncodingType is a signal to the receiver of the message that + // indicates exactly how the set of short channel ID's that follow have + // been encoded. + EncodingType ShortChanIDEncoding + + // ShortChanIDs is a slice of decoded short channel ID's. + ShortChanIDs []ShortChannelID + + // noSort indicates whether or not to sort the short channel ids before + // writing them out. + // + // NOTE: This should only be used during testing. + noSort bool +} + +// NewQueryShortChanIDs creates a new QueryShortChanIDs message. +func NewQueryShortChanIDs(h chainhash.Hash, e ShortChanIDEncoding, + s []ShortChannelID) *QueryShortChanIDs { + + return &QueryShortChanIDs{ + ChainHash: h, + EncodingType: e, + ShortChanIDs: s, + } +} + +// A compile time check to ensure QueryShortChanIDs implements the +// lnwire.Message interface. +var _ Message = (*QueryShortChanIDs)(nil) + +// Decode deserializes a serialized QueryShortChanIDs message stored in the +// passed io.Reader observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (q *QueryShortChanIDs) Decode(r io.Reader, pver uint32) error { + err := ReadElements(r, q.ChainHash[:]) + if err != nil { + return err + } + + q.EncodingType, q.ShortChanIDs, err = decodeShortChanIDs(r) + + return err +} + +// decodeShortChanIDs decodes a set of short channel ID's that have been +// encoded. The first byte of the body details how the short chan ID's were +// encoded. We'll use this type to govern exactly how we go about encoding the +// set of short channel ID's. +func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, error) { + // First, we'll attempt to read the number of bytes in the body of the + // set of encoded short channel ID's. + var numBytesResp uint16 + err := ReadElements(r, &numBytesResp) + if err != nil { + return 0, nil, err + } + + if numBytesResp == 0 { + return 0, nil, nil + } + + queryBody := make([]byte, numBytesResp) + if _, err := io.ReadFull(r, queryBody); err != nil { + return 0, nil, err + } + + // The first byte is the encoding type, so we'll extract that so we can + // continue our parsing. + encodingType := ShortChanIDEncoding(queryBody[0]) + + // Before continuing, we'll snip off the first byte of the query body + // as that was just the encoding type. + queryBody = queryBody[1:] + + // Otherwise, depending on the encoding type, we'll decode the encode + // short channel ID's in a different manner. + switch encodingType { + + // In this encoding, we'll simply read a sort array of encoded short + // channel ID's from the buffer. + case EncodingSortedPlain: + // If after extracting the encoding type, the number of + // remaining bytes is not a whole multiple of the size of an + // encoded short channel ID (8 bytes), then we'll return a + // parsing error. + if len(queryBody)%8 != 0 { + return 0, nil, fmt.Errorf("whole number of short "+ + "chan ID's cannot be encoded in len=%v", + len(queryBody)) + } + + // As each short channel ID is encoded as 8 bytes, we can + // compute the number of bytes encoded based on the size of the + // query body. + numShortChanIDs := len(queryBody) / 8 + if numShortChanIDs == 0 { + return encodingType, nil, nil + } + + // Finally, we'll read out the exact number of short channel + // ID's to conclude our parsing. + shortChanIDs := make([]ShortChannelID, numShortChanIDs) + bodyReader := bytes.NewReader(queryBody) + var lastChanID ShortChannelID + for i := 0; i < numShortChanIDs; i++ { + if err := ReadElements(bodyReader, &shortChanIDs[i]); err != nil { + return 0, nil, fmt.Errorf("unable to parse "+ + "short chan ID: %v", err) + } + + // We'll ensure that this short chan ID is greater than + // the last one. This is a requirement within the + // encoding, and if violated can aide us in detecting + // malicious payloads. This can only be true starting + // at the second chanID. + cid := shortChanIDs[i] + if i > 0 && cid.ToUint64() <= lastChanID.ToUint64() { + return 0, nil, ErrUnsortedSIDs{lastChanID, cid} + } + lastChanID = cid + } + + return encodingType, shortChanIDs, nil + + // In this encoding, we'll use zlib to decode the compressed payload. + // However, we'll pay attention to ensure that we don't open our selves + // up to a memory exhaustion attack. + case EncodingSortedZlib: + // We'll obtain an ultimately release the zlib decode mutex. + // This guards us against allocating too much memory to decode + // each instance from concurrent peers. + zlibDecodeMtx.Lock() + defer zlibDecodeMtx.Unlock() + + // At this point, if there's no body remaining, then only the encoding + // type was specified, meaning that there're no further bytes to be + // parsed. + if len(queryBody) == 0 { + return encodingType, nil, nil + } + + // Before we start to decode, we'll create a limit reader over + // the current reader. This will ensure that we can control how + // much memory we're allocating during the decoding process. + limitedDecompressor, err := zlib.NewReader(&io.LimitedReader{ + R: bytes.NewReader(queryBody), + N: maxZlibBufSize, + }) + if err != nil { + return 0, nil, fmt.Errorf("unable to create zlib reader: %v", err) + } + + var ( + shortChanIDs []ShortChannelID + lastChanID ShortChannelID + i int + ) + for { + // We'll now attempt to read the next short channel ID + // encoded in the payload. + var cid ShortChannelID + err := ReadElements(limitedDecompressor, &cid) + + switch { + // If we get an EOF error, then that either means we've + // read all that's contained in the buffer, or have hit + // our limit on the number of bytes we'll read. In + // either case, we'll return what we have so far. + case err == io.ErrUnexpectedEOF || err == io.EOF: + return encodingType, shortChanIDs, nil + + // Otherwise, we hit some other sort of error, possibly + // an invalid payload, so we'll exit early with the + // error. + case err != nil: + return 0, nil, fmt.Errorf("unable to "+ + "deflate next short chan "+ + "ID: %v", err) + } + + // We successfully read the next ID, so we'll collect + // that in the set of final ID's to return. + shortChanIDs = append(shortChanIDs, cid) + + // Finally, we'll ensure that this short chan ID is + // greater than the last one. This is a requirement + // within the encoding, and if violated can aide us in + // detecting malicious payloads. This can only be true + // starting at the second chanID. + if i > 0 && cid.ToUint64() <= lastChanID.ToUint64() { + return 0, nil, ErrUnsortedSIDs{lastChanID, cid} + } + + lastChanID = cid + i++ + } + + default: + // If we've been sent an encoding type that we don't know of, + // then we'll return a parsing error as we can't continue if + // we're unable to encode them. + return 0, nil, ErrUnknownShortChanIDEncoding(encodingType) + } +} + +// Encode serializes the target QueryShortChanIDs into the passed io.Writer +// observing the protocol version specified. +// +// This is part of the lnwire.Message interface. +func (q *QueryShortChanIDs) Encode(w io.Writer, pver uint32) error { + // First, we'll write out the chain hash. + err := WriteElements(w, q.ChainHash[:]) + if err != nil { + return err + } + + // Base on our encoding type, we'll write out the set of short channel + // ID's. + return encodeShortChanIDs(w, q.EncodingType, q.ShortChanIDs, q.noSort) +} + +// encodeShortChanIDs encodes the passed short channel ID's into the passed +// io.Writer, respecting the specified encoding type. +func encodeShortChanIDs(w io.Writer, encodingType ShortChanIDEncoding, + shortChanIDs []ShortChannelID, noSort bool) error { + + // For both of the current encoding types, the channel ID's are to be + // sorted in place, so we'll do that now. The sorting is applied unless + // we were specifically requested not to for testing purposes. + if !noSort { + sort.Slice(shortChanIDs, func(i, j int) bool { + return shortChanIDs[i].ToUint64() < + shortChanIDs[j].ToUint64() + }) + } + + switch encodingType { + + // In this encoding, we'll simply write a sorted array of encoded short + // channel ID's from the buffer. + case EncodingSortedPlain: + // First, we'll write out the number of bytes of the query + // body. We add 1 as the response will have the encoding type + // prepended to it. + numBytesBody := uint16(len(shortChanIDs)*8) + 1 + if err := WriteElements(w, numBytesBody); err != nil { + return err + } + + // We'll then write out the encoding that that follows the + // actual encoded short channel ID's. + if err := WriteElements(w, encodingType); err != nil { + return err + } + + // Now that we know they're sorted, we can write out each short + // channel ID to the buffer. + for _, chanID := range shortChanIDs { + if err := WriteElements(w, chanID); err != nil { + return fmt.Errorf("unable to write short chan "+ + "ID: %v", err) + } + } + + return nil + + // For this encoding we'll first write out a serialized version of all + // the channel ID's into a buffer, then zlib encode that. The final + // payload is what we'll write out to the passed io.Writer. + // + // TODO(roasbeef): assumes the caller knows the proper chunk size to + // pass to avoid bin-packing here + case EncodingSortedZlib: + // We'll make a new buffer, then wrap that with a zlib writer + // so we can write directly to the buffer and encode in a + // streaming manner. + var buf bytes.Buffer + zlibWriter := zlib.NewWriter(&buf) + + // If we don't have anything at all to write, then we'll write + // an empty payload so we don't include things like the zlib + // header when the remote party is expecting no actual short + // channel IDs. + var compressedPayload []byte + if len(shortChanIDs) > 0 { + // Next, we'll write out all the channel ID's directly + // into the zlib writer, which will do compressing on + // the fly. + for _, chanID := range shortChanIDs { + err := WriteElements(zlibWriter, chanID) + if err != nil { + return fmt.Errorf("unable to write short chan "+ + "ID: %v", err) + } + } + + // Now that we've written all the elements, we'll + // ensure the compressed stream is written to the + // underlying buffer. + if err := zlibWriter.Close(); err != nil { + return fmt.Errorf("unable to finalize "+ + "compression: %v", err) + } + + compressedPayload = buf.Bytes() + } + + // Now that we have all the items compressed, we can compute + // what the total payload size will be. We add one to account + // for the byte to encode the type. + // + // If we don't have any actual bytes to write, then we'll end + // up emitting one byte for the length, followed by the + // encoding type, and nothing more. The spec isn't 100% clear + // in this area, but we do this as this is what most of the + // other implementations do. + numBytesBody := len(compressedPayload) + 1 + + // Finally, we can write out the number of bytes, the + // compression type, and finally the buffer itself. + if err := WriteElements(w, uint16(numBytesBody)); err != nil { + return err + } + if err := WriteElements(w, encodingType); err != nil { + return err + } + + _, err := w.Write(compressedPayload) + return err + + default: + // If we're trying to encode with an encoding type that we + // don't know of, then we'll return a parsing error as we can't + // continue if we're unable to encode them. + return ErrUnknownShortChanIDEncoding(encodingType) + } +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (q *QueryShortChanIDs) MsgType() MessageType { + return MsgQueryShortChanIDs +} + +// MaxPayloadLength returns the maximum allowed payload size for a +// QueryShortChanIDs complete message observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (q *QueryShortChanIDs) MaxPayloadLength(uint32) uint32 { + return MaxMessagePayload +} diff --git a/channeldb/migration/lnwire21/reply_channel_range.go b/channeldb/migration/lnwire21/reply_channel_range.go new file mode 100644 index 00000000..43060602 --- /dev/null +++ b/channeldb/migration/lnwire21/reply_channel_range.go @@ -0,0 +1,90 @@ +package lnwire + +import "io" + +// ReplyChannelRange is the response to the QueryChannelRange message. It +// includes the original query, and the next streaming chunk of encoded short +// channel ID's as the response. We'll also include a byte that indicates if +// this is the last query in the message. +type ReplyChannelRange struct { + // QueryChannelRange is the corresponding query to this response. + QueryChannelRange + + // Complete denotes if this is the conclusion of the set of streaming + // responses to the original query. + Complete uint8 + + // EncodingType is a signal to the receiver of the message that + // indicates exactly how the set of short channel ID's that follow have + // been encoded. + EncodingType ShortChanIDEncoding + + // ShortChanIDs is a slice of decoded short channel ID's. + ShortChanIDs []ShortChannelID + + // noSort indicates whether or not to sort the short channel ids before + // writing them out. + // + // NOTE: This should only be used for testing. + noSort bool +} + +// NewReplyChannelRange creates a new empty ReplyChannelRange message. +func NewReplyChannelRange() *ReplyChannelRange { + return &ReplyChannelRange{} +} + +// A compile time check to ensure ReplyChannelRange implements the +// lnwire.Message interface. +var _ Message = (*ReplyChannelRange)(nil) + +// Decode deserializes a serialized ReplyChannelRange message stored in the +// passed io.Reader observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (c *ReplyChannelRange) Decode(r io.Reader, pver uint32) error { + err := c.QueryChannelRange.Decode(r, pver) + if err != nil { + return err + } + + if err := ReadElements(r, &c.Complete); err != nil { + return err + } + + c.EncodingType, c.ShortChanIDs, err = decodeShortChanIDs(r) + + return err +} + +// Encode serializes the target ReplyChannelRange into the passed io.Writer +// observing the protocol version specified. +// +// This is part of the lnwire.Message interface. +func (c *ReplyChannelRange) Encode(w io.Writer, pver uint32) error { + if err := c.QueryChannelRange.Encode(w, pver); err != nil { + return err + } + + if err := WriteElements(w, c.Complete); err != nil { + return err + } + + return encodeShortChanIDs(w, c.EncodingType, c.ShortChanIDs, c.noSort) +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (c *ReplyChannelRange) MsgType() MessageType { + return MsgReplyChannelRange +} + +// MaxPayloadLength returns the maximum allowed payload size for a +// ReplyChannelRange complete message observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (c *ReplyChannelRange) MaxPayloadLength(uint32) uint32 { + return MaxMessagePayload +} diff --git a/channeldb/migration/lnwire21/reply_short_chan_ids_end.go b/channeldb/migration/lnwire21/reply_short_chan_ids_end.go new file mode 100644 index 00000000..d77aa0b5 --- /dev/null +++ b/channeldb/migration/lnwire21/reply_short_chan_ids_end.go @@ -0,0 +1,74 @@ +package lnwire + +import ( + "io" + + "github.com/btcsuite/btcd/chaincfg/chainhash" +) + +// ReplyShortChanIDsEnd is a message that marks the end of a streaming message +// response to an initial QueryShortChanIDs message. This marks that the +// receiver of the original QueryShortChanIDs for the target chain has either +// sent all adequate responses it knows of, or doesn't know of any short chan +// ID's for the target chain. +type ReplyShortChanIDsEnd struct { + // ChainHash denotes the target chain that we're respond to a short + // chan ID query for. + ChainHash chainhash.Hash + + // Complete will be set to 0 if we don't know of the chain that the + // remote peer sent their query for. Otherwise, we'll set this to 1 in + // order to indicate that we've sent all known responses for the prior + // set of short chan ID's in the corresponding QueryShortChanIDs + // message. + Complete uint8 +} + +// NewReplyShortChanIDsEnd creates a new empty ReplyShortChanIDsEnd message. +func NewReplyShortChanIDsEnd() *ReplyShortChanIDsEnd { + return &ReplyShortChanIDsEnd{} +} + +// A compile time check to ensure ReplyShortChanIDsEnd implements the +// lnwire.Message interface. +var _ Message = (*ReplyShortChanIDsEnd)(nil) + +// Decode deserializes a serialized ReplyShortChanIDsEnd message stored in the +// passed io.Reader observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (c *ReplyShortChanIDsEnd) Decode(r io.Reader, pver uint32) error { + return ReadElements(r, + c.ChainHash[:], + &c.Complete, + ) +} + +// Encode serializes the target ReplyShortChanIDsEnd into the passed io.Writer +// observing the protocol version specified. +// +// This is part of the lnwire.Message interface. +func (c *ReplyShortChanIDsEnd) Encode(w io.Writer, pver uint32) error { + return WriteElements(w, + c.ChainHash[:], + c.Complete, + ) +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (c *ReplyShortChanIDsEnd) MsgType() MessageType { + return MsgReplyShortChanIDsEnd +} + +// MaxPayloadLength returns the maximum allowed payload size for a +// ReplyShortChanIDsEnd complete message observing the specified protocol +// version. +// +// This is part of the lnwire.Message interface. +func (c *ReplyShortChanIDsEnd) MaxPayloadLength(uint32) uint32 { + // 32 (chain hash) + 1 (complete) + return 33 +} diff --git a/channeldb/migration/lnwire21/revoke_and_ack.go b/channeldb/migration/lnwire21/revoke_and_ack.go new file mode 100644 index 00000000..0cfa2bc2 --- /dev/null +++ b/channeldb/migration/lnwire21/revoke_and_ack.go @@ -0,0 +1,91 @@ +package lnwire + +import ( + "io" + + "github.com/btcsuite/btcd/btcec" +) + +// RevokeAndAck is sent by either side once a CommitSig message has been +// received, and validated. This message serves to revoke the prior commitment +// transaction, which was the most up to date version until a CommitSig message +// referencing the specified ChannelPoint was received. Additionally, this +// message also piggyback's the next revocation hash that Alice should use when +// constructing the Bob's version of the next commitment transaction (which +// would be done before sending a CommitSig message). This piggybacking allows +// Alice to send the next CommitSig message modifying Bob's commitment +// transaction without first asking for a revocation hash initially. +type RevokeAndAck struct { + // ChanID uniquely identifies to which currently active channel this + // RevokeAndAck applies to. + ChanID ChannelID + + // Revocation is the preimage to the revocation hash of the now prior + // commitment transaction. + Revocation [32]byte + + // NextRevocationKey is the next commitment point which should be used + // for the next commitment transaction the remote peer creates for us. + // This, in conjunction with revocation base point will be used to + // create the proper revocation key used within the commitment + // transaction. + NextRevocationKey *btcec.PublicKey +} + +// NewRevokeAndAck creates a new RevokeAndAck message. +func NewRevokeAndAck() *RevokeAndAck { + return &RevokeAndAck{} +} + +// A compile time check to ensure RevokeAndAck implements the lnwire.Message +// interface. +var _ Message = (*RevokeAndAck)(nil) + +// Decode deserializes a serialized RevokeAndAck message stored in the +// passed io.Reader observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (c *RevokeAndAck) Decode(r io.Reader, pver uint32) error { + return ReadElements(r, + &c.ChanID, + c.Revocation[:], + &c.NextRevocationKey, + ) +} + +// Encode serializes the target RevokeAndAck into the passed io.Writer +// observing the protocol version specified. +// +// This is part of the lnwire.Message interface. +func (c *RevokeAndAck) Encode(w io.Writer, pver uint32) error { + return WriteElements(w, + c.ChanID, + c.Revocation[:], + c.NextRevocationKey, + ) +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (c *RevokeAndAck) MsgType() MessageType { + return MsgRevokeAndAck +} + +// MaxPayloadLength returns the maximum allowed payload size for a RevokeAndAck +// complete message observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (c *RevokeAndAck) MaxPayloadLength(uint32) uint32 { + // 32 + 32 + 33 + return 97 +} + +// TargetChanID returns the channel id of the link for which this message is +// intended. +// +// NOTE: Part of peer.LinkUpdater interface. +func (c *RevokeAndAck) TargetChanID() ChannelID { + return c.ChanID +} diff --git a/channeldb/migration/lnwire21/short_channel_id.go b/channeldb/migration/lnwire21/short_channel_id.go new file mode 100644 index 00000000..b2b980aa --- /dev/null +++ b/channeldb/migration/lnwire21/short_channel_id.go @@ -0,0 +1,48 @@ +package lnwire + +import ( + "fmt" +) + +// ShortChannelID represents the set of data which is needed to retrieve all +// necessary data to validate the channel existence. +type ShortChannelID struct { + // BlockHeight is the height of the block where funding transaction + // located. + // + // NOTE: This field is limited to 3 bytes. + BlockHeight uint32 + + // TxIndex is a position of funding transaction within a block. + // + // NOTE: This field is limited to 3 bytes. + TxIndex uint32 + + // TxPosition indicating transaction output which pays to the channel. + TxPosition uint16 +} + +// NewShortChanIDFromInt returns a new ShortChannelID which is the decoded +// version of the compact channel ID encoded within the uint64. The format of +// the compact channel ID is as follows: 3 bytes for the block height, 3 bytes +// for the transaction index, and 2 bytes for the output index. +func NewShortChanIDFromInt(chanID uint64) ShortChannelID { + return ShortChannelID{ + BlockHeight: uint32(chanID >> 40), + TxIndex: uint32(chanID>>16) & 0xFFFFFF, + TxPosition: uint16(chanID), + } +} + +// ToUint64 converts the ShortChannelID into a compact format encoded within a +// uint64 (8 bytes). +func (c ShortChannelID) ToUint64() uint64 { + // TODO(roasbeef): explicit error on overflow? + return ((uint64(c.BlockHeight) << 40) | (uint64(c.TxIndex) << 16) | + (uint64(c.TxPosition))) +} + +// String generates a human-readable representation of the channel ID. +func (c ShortChannelID) String() string { + return fmt.Sprintf("%d:%d:%d", c.BlockHeight, c.TxIndex, c.TxPosition) +} diff --git a/channeldb/migration/lnwire21/shutdown.go b/channeldb/migration/lnwire21/shutdown.go new file mode 100644 index 00000000..94d10a90 --- /dev/null +++ b/channeldb/migration/lnwire21/shutdown.go @@ -0,0 +1,87 @@ +package lnwire + +import ( + "io" +) + +// Shutdown is sent by either side in order to initiate the cooperative closure +// of a channel. This message is sparse as both sides implicitly have the +// information necessary to construct a transaction that will send the settled +// funds of both parties to the final delivery addresses negotiated during the +// funding workflow. +type Shutdown struct { + // ChannelID serves to identify which channel is to be closed. + ChannelID ChannelID + + // Address is the script to which the channel funds will be paid. + Address DeliveryAddress +} + +// DeliveryAddress is used to communicate the address to which funds from a +// closed channel should be sent. The address can be a p2wsh, p2pkh, p2sh or +// p2wpkh. +type DeliveryAddress []byte + +// deliveryAddressMaxSize is the maximum expected size in bytes of a +// DeliveryAddress based on the types of scripts we know. +// Following are the known scripts and their sizes in bytes. +// - pay to witness script hash: 34 +// - pay to pubkey hash: 25 +// - pay to script hash: 22 +// - pay to witness pubkey hash: 22. +const deliveryAddressMaxSize = 34 + +// NewShutdown creates a new Shutdown message. +func NewShutdown(cid ChannelID, addr DeliveryAddress) *Shutdown { + return &Shutdown{ + ChannelID: cid, + Address: addr, + } +} + +// A compile-time check to ensure Shutdown implements the lnwire.Message +// interface. +var _ Message = (*Shutdown)(nil) + +// Decode deserializes a serialized Shutdown stored in the passed io.Reader +// observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (s *Shutdown) Decode(r io.Reader, pver uint32) error { + return ReadElements(r, &s.ChannelID, &s.Address) +} + +// Encode serializes the target Shutdown into the passed io.Writer observing +// the protocol version specified. +// +// This is part of the lnwire.Message interface. +func (s *Shutdown) Encode(w io.Writer, pver uint32) error { + return WriteElements(w, s.ChannelID, s.Address) +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (s *Shutdown) MsgType() MessageType { + return MsgShutdown +} + +// MaxPayloadLength returns the maximum allowed payload size for this message +// observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (s *Shutdown) MaxPayloadLength(pver uint32) uint32 { + var length uint32 + + // ChannelID - 32bytes + length += 32 + + // Len - 2 bytes + length += 2 + + // ScriptPubKey - maximum delivery address size. + length += deliveryAddressMaxSize + + return length +} diff --git a/channeldb/migration/lnwire21/signature.go b/channeldb/migration/lnwire21/signature.go new file mode 100644 index 00000000..13a2f25c --- /dev/null +++ b/channeldb/migration/lnwire21/signature.go @@ -0,0 +1,129 @@ +package lnwire + +import ( + "fmt" + + "github.com/btcsuite/btcd/btcec" + "github.com/lightningnetwork/lnd/input" +) + +// Sig is a fixed-sized ECDSA signature. Unlike Bitcoin, we use fixed sized +// signatures on the wire, instead of DER encoded signatures. This type +// provides several methods to convert to/from a regular Bitcoin DER encoded +// signature (raw bytes and *btcec.Signature). +type Sig [64]byte + +// NewSigFromRawSignature returns a Sig from a Bitcoin raw signature encoded in +// the canonical DER encoding. +func NewSigFromRawSignature(sig []byte) (Sig, error) { + var b Sig + + if len(sig) == 0 { + return b, fmt.Errorf("cannot decode empty signature") + } + + // Extract lengths of R and S. The DER representation is laid out as + // 0x30 0x02 r 0x02 s + // which means the length of R is the 4th byte and the length of S + // is the second byte after R ends. 0x02 signifies a length-prefixed, + // zero-padded, big-endian bigint. 0x30 signifies a DER signature. + // See the Serialize() method for btcec.Signature for details. + rLen := sig[3] + sLen := sig[5+rLen] + + // Check to make sure R and S can both fit into their intended buffers. + // We check S first because these code blocks decrement sLen and rLen + // in the case of a 33-byte 0-padded integer returned from Serialize() + // and rLen is used in calculating array indices for S. We can track + // this with additional variables, but it's more efficient to just + // check S first. + if sLen > 32 { + if (sLen > 33) || (sig[6+rLen] != 0x00) { + return b, fmt.Errorf("S is over 32 bytes long " + + "without padding") + } + sLen-- + copy(b[64-sLen:], sig[7+rLen:]) + } else { + copy(b[64-sLen:], sig[6+rLen:]) + } + + // Do the same for R as we did for S + if rLen > 32 { + if (rLen > 33) || (sig[4] != 0x00) { + return b, fmt.Errorf("R is over 32 bytes long " + + "without padding") + } + rLen-- + copy(b[32-rLen:], sig[5:5+rLen]) + } else { + copy(b[32-rLen:], sig[4:4+rLen]) + } + + return b, nil +} + +// NewSigFromSignature creates a new signature as used on the wire, from an +// existing btcec.Signature. +func NewSigFromSignature(e input.Signature) (Sig, error) { + if e == nil { + return Sig{}, fmt.Errorf("cannot decode empty signature") + } + + // Serialize the signature with all the checks that entails. + return NewSigFromRawSignature(e.Serialize()) +} + +// ToSignature converts the fixed-sized signature to a btcec.Signature objects +// which can be used for signature validation checks. +func (b *Sig) ToSignature() (*btcec.Signature, error) { + // Parse the signature with strict checks. + sigBytes := b.ToSignatureBytes() + sig, err := btcec.ParseDERSignature(sigBytes, btcec.S256()) + if err != nil { + return nil, err + } + + return sig, nil +} + +// ToSignatureBytes serializes the target fixed-sized signature into the raw +// bytes of a DER encoding. +func (b *Sig) ToSignatureBytes() []byte { + // Extract canonically-padded bigint representations from buffer + r := extractCanonicalPadding(b[0:32]) + s := extractCanonicalPadding(b[32:64]) + rLen := uint8(len(r)) + sLen := uint8(len(s)) + + // Create a canonical serialized signature. DER format is: + // 0x30 0x02 r 0x02 s + sigBytes := make([]byte, 6+rLen+sLen) + sigBytes[0] = 0x30 // DER signature magic value + sigBytes[1] = 4 + rLen + sLen // Length of rest of signature + sigBytes[2] = 0x02 // Big integer magic value + sigBytes[3] = rLen // Length of R + sigBytes[rLen+4] = 0x02 // Big integer magic value + sigBytes[rLen+5] = sLen // Length of S + copy(sigBytes[4:], r) // Copy R + copy(sigBytes[rLen+6:], s) // Copy S + + return sigBytes +} + +// extractCanonicalPadding is a utility function to extract the canonical +// padding of a big-endian integer from the wire encoding (a 0-padded +// big-endian integer) such that it passes btcec.canonicalPadding test. +func extractCanonicalPadding(b []byte) []byte { + for i := 0; i < len(b); i++ { + // Found first non-zero byte. + if b[i] > 0 { + // If the MSB is set, we need zero padding. + if b[i]&0x80 == 0x80 { + return append([]byte{0x00}, b[i:]...) + } + return b[i:] + } + } + return []byte{0x00} +} diff --git a/channeldb/migration/lnwire21/update_add_htlc.go b/channeldb/migration/lnwire21/update_add_htlc.go new file mode 100644 index 00000000..028c6320 --- /dev/null +++ b/channeldb/migration/lnwire21/update_add_htlc.go @@ -0,0 +1,119 @@ +package lnwire + +import ( + "io" +) + +// OnionPacketSize is the size of the serialized Sphinx onion packet included +// in each UpdateAddHTLC message. The breakdown of the onion packet is as +// follows: 1-byte version, 33-byte ephemeral public key (for ECDH), 1300-bytes +// of per-hop data, and a 32-byte HMAC over the entire packet. +const OnionPacketSize = 1366 + +// UpdateAddHTLC is the message sent by Alice to Bob when she wishes to add an +// HTLC to his remote commitment transaction. In addition to information +// detailing the value, the ID, expiry, and the onion blob is also included +// which allows Bob to derive the next hop in the route. The HTLC added by this +// message is to be added to the remote node's "pending" HTLC's. A subsequent +// CommitSig message will move the pending HTLC to the newly created commitment +// transaction, marking them as "staged". +type UpdateAddHTLC struct { + // ChanID is the particular active channel that this UpdateAddHTLC is + // bound to. + ChanID ChannelID + + // ID is the identification server for this HTLC. This value is + // explicitly included as it allows nodes to survive single-sided + // restarts. The ID value for this sides starts at zero, and increases + // with each offered HTLC. + ID uint64 + + // Amount is the amount of millisatoshis this HTLC is worth. + Amount MilliSatoshi + + // PaymentHash is the payment hash to be included in the HTLC this + // request creates. The pre-image to this HTLC must be revealed by the + // upstream peer in order to fully settle the HTLC. + PaymentHash [32]byte + + // Expiry is the number of blocks after which this HTLC should expire. + // It is the receiver's duty to ensure that the outgoing HTLC has a + // sufficient expiry value to allow her to redeem the incoming HTLC. + Expiry uint32 + + // OnionBlob is the raw serialized mix header used to route an HTLC in + // a privacy-preserving manner. The mix header is defined currently to + // be parsed as a 4-tuple: (groupElement, routingInfo, headerMAC, + // body). First the receiving node should use the groupElement, and + // its current onion key to derive a shared secret with the source. + // Once the shared secret has been derived, the headerMAC should be + // checked FIRST. Note that the MAC only covers the routingInfo field. + // If the MAC matches, and the shared secret is fresh, then the node + // should strip off a layer of encryption, exposing the next hop to be + // used in the subsequent UpdateAddHTLC message. + OnionBlob [OnionPacketSize]byte +} + +// NewUpdateAddHTLC returns a new empty UpdateAddHTLC message. +func NewUpdateAddHTLC() *UpdateAddHTLC { + return &UpdateAddHTLC{} +} + +// A compile time check to ensure UpdateAddHTLC implements the lnwire.Message +// interface. +var _ Message = (*UpdateAddHTLC)(nil) + +// Decode deserializes a serialized UpdateAddHTLC message stored in the passed +// io.Reader observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (c *UpdateAddHTLC) Decode(r io.Reader, pver uint32) error { + return ReadElements(r, + &c.ChanID, + &c.ID, + &c.Amount, + c.PaymentHash[:], + &c.Expiry, + c.OnionBlob[:], + ) +} + +// Encode serializes the target UpdateAddHTLC into the passed io.Writer observing +// the protocol version specified. +// +// This is part of the lnwire.Message interface. +func (c *UpdateAddHTLC) Encode(w io.Writer, pver uint32) error { + return WriteElements(w, + c.ChanID, + c.ID, + c.Amount, + c.PaymentHash[:], + c.Expiry, + c.OnionBlob[:], + ) +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (c *UpdateAddHTLC) MsgType() MessageType { + return MsgUpdateAddHTLC +} + +// MaxPayloadLength returns the maximum allowed payload size for an UpdateAddHTLC +// complete message observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (c *UpdateAddHTLC) MaxPayloadLength(uint32) uint32 { + // 1450 + return 32 + 8 + 4 + 8 + 32 + 1366 +} + +// TargetChanID returns the channel id of the link for which this message is +// intended. +// +// NOTE: Part of peer.LinkUpdater interface. +func (c *UpdateAddHTLC) TargetChanID() ChannelID { + return c.ChanID +} diff --git a/channeldb/migration/lnwire21/update_fail_htlc.go b/channeldb/migration/lnwire21/update_fail_htlc.go new file mode 100644 index 00000000..194f2ecd --- /dev/null +++ b/channeldb/migration/lnwire21/update_fail_htlc.go @@ -0,0 +1,95 @@ +package lnwire + +import ( + "io" +) + +// OpaqueReason is an opaque encrypted byte slice that encodes the exact +// failure reason and additional some supplemental data. The contents of this +// slice can only be decrypted by the sender of the original HTLC. +type OpaqueReason []byte + +// UpdateFailHTLC is sent by Alice to Bob in order to remove a previously added +// HTLC. Upon receipt of an UpdateFailHTLC the HTLC should be removed from the +// next commitment transaction, with the UpdateFailHTLC propagated backwards in +// the route to fully undo the HTLC. +type UpdateFailHTLC struct { + // ChanIDPoint is the particular active channel that this + // UpdateFailHTLC is bound to. + ChanID ChannelID + + // ID references which HTLC on the remote node's commitment transaction + // has timed out. + ID uint64 + + // Reason is an onion-encrypted blob that details why the HTLC was + // failed. This blob is only fully decryptable by the initiator of the + // HTLC message. + Reason OpaqueReason +} + +// A compile time check to ensure UpdateFailHTLC implements the lnwire.Message +// interface. +var _ Message = (*UpdateFailHTLC)(nil) + +// Decode deserializes a serialized UpdateFailHTLC message stored in the passed +// io.Reader observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (c *UpdateFailHTLC) Decode(r io.Reader, pver uint32) error { + return ReadElements(r, + &c.ChanID, + &c.ID, + &c.Reason, + ) +} + +// Encode serializes the target UpdateFailHTLC into the passed io.Writer observing +// the protocol version specified. +// +// This is part of the lnwire.Message interface. +func (c *UpdateFailHTLC) Encode(w io.Writer, pver uint32) error { + return WriteElements(w, + c.ChanID, + c.ID, + c.Reason, + ) +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (c *UpdateFailHTLC) MsgType() MessageType { + return MsgUpdateFailHTLC +} + +// MaxPayloadLength returns the maximum allowed payload size for an UpdateFailHTLC +// complete message observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (c *UpdateFailHTLC) MaxPayloadLength(uint32) uint32 { + var length uint32 + + // Length of the ChanID + length += 32 + + // Length of the ID + length += 8 + + // Length of the length opaque reason + length += 2 + + // Length of the Reason + length += 292 + + return length +} + +// TargetChanID returns the channel id of the link for which this message is +// intended. +// +// NOTE: Part of peer.LinkUpdater interface. +func (c *UpdateFailHTLC) TargetChanID() ChannelID { + return c.ChanID +} diff --git a/channeldb/migration/lnwire21/update_fail_malformed_htlc.go b/channeldb/migration/lnwire21/update_fail_malformed_htlc.go new file mode 100644 index 00000000..39d4b870 --- /dev/null +++ b/channeldb/migration/lnwire21/update_fail_malformed_htlc.go @@ -0,0 +1,83 @@ +package lnwire + +import ( + "crypto/sha256" + "io" +) + +// UpdateFailMalformedHTLC is sent by either the payment forwarder or by +// payment receiver to the payment sender in order to notify it that the onion +// blob can't be parsed. For that reason we send this message instead of +// obfuscate the onion failure. +type UpdateFailMalformedHTLC struct { + // ChanID is the particular active channel that this + // UpdateFailMalformedHTLC is bound to. + ChanID ChannelID + + // ID references which HTLC on the remote node's commitment transaction + // has timed out. + ID uint64 + + // ShaOnionBlob hash of the onion blob on which can't be parsed by the + // node in the payment path. + ShaOnionBlob [sha256.Size]byte + + // FailureCode the exact reason why onion blob haven't been parsed. + FailureCode FailCode +} + +// A compile time check to ensure UpdateFailMalformedHTLC implements the +// lnwire.Message interface. +var _ Message = (*UpdateFailMalformedHTLC)(nil) + +// Decode deserializes a serialized UpdateFailMalformedHTLC message stored in the passed +// io.Reader observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (c *UpdateFailMalformedHTLC) Decode(r io.Reader, pver uint32) error { + return ReadElements(r, + &c.ChanID, + &c.ID, + c.ShaOnionBlob[:], + &c.FailureCode, + ) +} + +// Encode serializes the target UpdateFailMalformedHTLC into the passed +// io.Writer observing the protocol version specified. +// +// This is part of the lnwire.Message interface. +func (c *UpdateFailMalformedHTLC) Encode(w io.Writer, pver uint32) error { + return WriteElements(w, + c.ChanID, + c.ID, + c.ShaOnionBlob[:], + c.FailureCode, + ) +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (c *UpdateFailMalformedHTLC) MsgType() MessageType { + return MsgUpdateFailMalformedHTLC +} + +// MaxPayloadLength returns the maximum allowed payload size for a +// UpdateFailMalformedHTLC complete message observing the specified protocol +// version. +// +// This is part of the lnwire.Message interface. +func (c *UpdateFailMalformedHTLC) MaxPayloadLength(uint32) uint32 { + // 32 + 8 + 32 + 2 + return 74 +} + +// TargetChanID returns the channel id of the link for which this message is +// intended. +// +// NOTE: Part of peer.LinkUpdater interface. +func (c *UpdateFailMalformedHTLC) TargetChanID() ChannelID { + return c.ChanID +} diff --git a/channeldb/migration/lnwire21/update_fee.go b/channeldb/migration/lnwire21/update_fee.go new file mode 100644 index 00000000..2d27c377 --- /dev/null +++ b/channeldb/migration/lnwire21/update_fee.go @@ -0,0 +1,78 @@ +package lnwire + +import ( + "io" +) + +// UpdateFee is the message the channel initiator sends to the other peer if +// the channel commitment fee needs to be updated. +type UpdateFee struct { + // ChanID is the channel that this UpdateFee is meant for. + ChanID ChannelID + + // FeePerKw is the fee-per-kw on commit transactions that the sender of + // this message wants to use for this channel. + // + // TODO(halseth): make SatPerKWeight when fee estimation is moved to + // own package. Currently this will cause an import cycle. + FeePerKw uint32 +} + +// NewUpdateFee creates a new UpdateFee message. +func NewUpdateFee(chanID ChannelID, feePerKw uint32) *UpdateFee { + return &UpdateFee{ + ChanID: chanID, + FeePerKw: feePerKw, + } +} + +// A compile time check to ensure UpdateFee implements the lnwire.Message +// interface. +var _ Message = (*UpdateFee)(nil) + +// Decode deserializes a serialized UpdateFee message stored in the passed +// io.Reader observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (c *UpdateFee) Decode(r io.Reader, pver uint32) error { + return ReadElements(r, + &c.ChanID, + &c.FeePerKw, + ) +} + +// Encode serializes the target UpdateFee into the passed io.Writer +// observing the protocol version specified. +// +// This is part of the lnwire.Message interface. +func (c *UpdateFee) Encode(w io.Writer, pver uint32) error { + return WriteElements(w, + c.ChanID, + c.FeePerKw, + ) +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (c *UpdateFee) MsgType() MessageType { + return MsgUpdateFee +} + +// MaxPayloadLength returns the maximum allowed payload size for an UpdateFee +// complete message observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (c *UpdateFee) MaxPayloadLength(uint32) uint32 { + // 32 + 4 + return 36 +} + +// TargetChanID returns the channel id of the link for which this message is +// intended. +// +// NOTE: Part of peer.LinkUpdater interface. +func (c *UpdateFee) TargetChanID() ChannelID { + return c.ChanID +} diff --git a/channeldb/migration/lnwire21/update_fulfill_htlc.go b/channeldb/migration/lnwire21/update_fulfill_htlc.go new file mode 100644 index 00000000..6c0e6339 --- /dev/null +++ b/channeldb/migration/lnwire21/update_fulfill_htlc.go @@ -0,0 +1,88 @@ +package lnwire + +import ( + "io" +) + +// UpdateFulfillHTLC is sent by Alice to Bob when she wishes to settle a +// particular HTLC referenced by its HTLCKey within a specific active channel +// referenced by ChannelPoint. A subsequent CommitSig message will be sent by +// Alice to "lock-in" the removal of the specified HTLC, possible containing a +// batch signature covering several settled HTLC's. +type UpdateFulfillHTLC struct { + // ChanID references an active channel which holds the HTLC to be + // settled. + ChanID ChannelID + + // ID denotes the exact HTLC stage within the receiving node's + // commitment transaction to be removed. + ID uint64 + + // PaymentPreimage is the R-value preimage required to fully settle an + // HTLC. + PaymentPreimage [32]byte +} + +// NewUpdateFulfillHTLC returns a new empty UpdateFulfillHTLC. +func NewUpdateFulfillHTLC(chanID ChannelID, id uint64, + preimage [32]byte) *UpdateFulfillHTLC { + + return &UpdateFulfillHTLC{ + ChanID: chanID, + ID: id, + PaymentPreimage: preimage, + } +} + +// A compile time check to ensure UpdateFulfillHTLC implements the lnwire.Message +// interface. +var _ Message = (*UpdateFulfillHTLC)(nil) + +// Decode deserializes a serialized UpdateFulfillHTLC message stored in the passed +// io.Reader observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (c *UpdateFulfillHTLC) Decode(r io.Reader, pver uint32) error { + return ReadElements(r, + &c.ChanID, + &c.ID, + c.PaymentPreimage[:], + ) +} + +// Encode serializes the target UpdateFulfillHTLC into the passed io.Writer +// observing the protocol version specified. +// +// This is part of the lnwire.Message interface. +func (c *UpdateFulfillHTLC) Encode(w io.Writer, pver uint32) error { + return WriteElements(w, + c.ChanID, + c.ID, + c.PaymentPreimage[:], + ) +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (c *UpdateFulfillHTLC) MsgType() MessageType { + return MsgUpdateFulfillHTLC +} + +// MaxPayloadLength returns the maximum allowed payload size for an UpdateFulfillHTLC +// complete message observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (c *UpdateFulfillHTLC) MaxPayloadLength(uint32) uint32 { + // 32 + 8 + 32 + return 72 +} + +// TargetChanID returns the channel id of the link for which this message is +// intended. +// +// NOTE: Part of peer.LinkUpdater interface. +func (c *UpdateFulfillHTLC) TargetChanID() ChannelID { + return c.ChanID +} diff --git a/channeldb/migration12/invoices.go b/channeldb/migration12/invoices.go index 0b83fe1f..6b83518f 100644 --- a/channeldb/migration12/invoices.go +++ b/channeldb/migration12/invoices.go @@ -7,8 +7,8 @@ import ( "time" "github.com/btcsuite/btcd/wire" + lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21" "github.com/lightningnetwork/lnd/lntypes" - "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/tlv" ) diff --git a/channeldb/migration12/migration.go b/channeldb/migration12/migration.go index 66f988de..2ec9e354 100644 --- a/channeldb/migration12/migration.go +++ b/channeldb/migration12/migration.go @@ -4,7 +4,7 @@ import ( "bytes" "github.com/lightningnetwork/lnd/channeldb/kvdb" - "github.com/lightningnetwork/lnd/lnwire" + lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21" ) var emptyFeatures = lnwire.NewFeatureVector(nil, nil) diff --git a/channeldb/migration21/common/enclosed_types.go b/channeldb/migration21/common/enclosed_types.go new file mode 100644 index 00000000..86dc9f3b --- /dev/null +++ b/channeldb/migration21/common/enclosed_types.go @@ -0,0 +1,658 @@ +package common + +import ( + "bytes" + "encoding/binary" + "io" + + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btcutil" + lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21" + "github.com/lightningnetwork/lnd/keychain" +) + +// CircuitKey is used by a channel to uniquely identify the HTLCs it receives +// from the switch, and is used to purge our in-memory state of HTLCs that have +// already been processed by a link. Two list of CircuitKeys are included in +// each CommitDiff to allow a link to determine which in-memory htlcs directed +// the opening and closing of circuits in the switch's circuit map. +type CircuitKey struct { + // ChanID is the short chanid indicating the HTLC's origin. + // + // NOTE: It is fine for this value to be blank, as this indicates a + // locally-sourced payment. + ChanID lnwire.ShortChannelID + + // HtlcID is the unique htlc index predominately assigned by links, + // though can also be assigned by switch in the case of locally-sourced + // payments. + HtlcID uint64 +} + +// HTLC is the on-disk representation of a hash time-locked contract. HTLCs are +// contained within ChannelDeltas which encode the current state of the +// commitment between state updates. +// +// TODO(roasbeef): save space by using smaller ints at tail end? +type HTLC struct { + // Signature is the signature for the second level covenant transaction + // for this HTLC. The second level transaction is a timeout tx in the + // case that this is an outgoing HTLC, and a success tx in the case + // that this is an incoming HTLC. + // + // TODO(roasbeef): make [64]byte instead? + Signature []byte + + // RHash is the payment hash of the HTLC. + RHash [32]byte + + // Amt is the amount of milli-satoshis this HTLC escrows. + Amt lnwire.MilliSatoshi + + // RefundTimeout is the absolute timeout on the HTLC that the sender + // must wait before reclaiming the funds in limbo. + RefundTimeout uint32 + + // OutputIndex is the output index for this particular HTLC output + // within the commitment transaction. + OutputIndex int32 + + // Incoming denotes whether we're the receiver or the sender of this + // HTLC. + Incoming bool + + // OnionBlob is an opaque blob which is used to complete multi-hop + // routing. + OnionBlob []byte + + // HtlcIndex is the HTLC counter index of this active, outstanding + // HTLC. This differs from the LogIndex, as the HtlcIndex is only + // incremented for each offered HTLC, while they LogIndex is + // incremented for each update (includes settle+fail). + HtlcIndex uint64 + + // LogIndex is the cumulative log index of this HTLC. This differs + // from the HtlcIndex as this will be incremented for each new log + // update added. + LogIndex uint64 +} + +// ChannelCommitment is a snapshot of the commitment state at a particular +// point in the commitment chain. With each state transition, a snapshot of the +// current state along with all non-settled HTLCs are recorded. These snapshots +// detail the state of the _remote_ party's commitment at a particular state +// number. For ourselves (the local node) we ONLY store our most recent +// (unrevoked) state for safety purposes. +type ChannelCommitment struct { + // CommitHeight is the update number that this ChannelDelta represents + // the total number of commitment updates to this point. This can be + // viewed as sort of a "commitment height" as this number is + // monotonically increasing. + CommitHeight uint64 + + // LocalLogIndex is the cumulative log index index of the local node at + // this point in the commitment chain. This value will be incremented + // for each _update_ added to the local update log. + LocalLogIndex uint64 + + // LocalHtlcIndex is the current local running HTLC index. This value + // will be incremented for each outgoing HTLC the local node offers. + LocalHtlcIndex uint64 + + // RemoteLogIndex is the cumulative log index index of the remote node + // at this point in the commitment chain. This value will be + // incremented for each _update_ added to the remote update log. + RemoteLogIndex uint64 + + // RemoteHtlcIndex is the current remote running HTLC index. This value + // will be incremented for each outgoing HTLC the remote node offers. + RemoteHtlcIndex uint64 + + // LocalBalance is the current available settled balance within the + // channel directly spendable by us. + // + // NOTE: This is the balance *after* subtracting any commitment fee, + // AND anchor output values. + LocalBalance lnwire.MilliSatoshi + + // RemoteBalance is the current available settled balance within the + // channel directly spendable by the remote node. + // + // NOTE: This is the balance *after* subtracting any commitment fee, + // AND anchor output values. + RemoteBalance lnwire.MilliSatoshi + + // CommitFee is the amount calculated to be paid in fees for the + // current set of commitment transactions. The fee amount is persisted + // with the channel in order to allow the fee amount to be removed and + // recalculated with each channel state update, including updates that + // happen after a system restart. + CommitFee btcutil.Amount + + // FeePerKw is the min satoshis/kilo-weight that should be paid within + // the commitment transaction for the entire duration of the channel's + // lifetime. This field may be updated during normal operation of the + // channel as on-chain conditions change. + // + // TODO(halseth): make this SatPerKWeight. Cannot be done atm because + // this will cause the import cycle lnwallet<->channeldb. Fee + // estimation stuff should be in its own package. + FeePerKw btcutil.Amount + + // CommitTx is the latest version of the commitment state, broadcast + // able by us. + CommitTx *wire.MsgTx + + // CommitSig is one half of the signature required to fully complete + // the script for the commitment transaction above. This is the + // signature signed by the remote party for our version of the + // commitment transactions. + CommitSig []byte + + // Htlcs is the set of HTLC's that are pending at this particular + // commitment height. + Htlcs []HTLC + + // TODO(roasbeef): pending commit pointer? + // * lets just walk through +} + +// LogUpdate represents a pending update to the remote commitment chain. The +// log update may be an add, fail, or settle entry. We maintain this data in +// order to be able to properly retransmit our proposed +// state if necessary. +type LogUpdate struct { + // LogIndex is the log index of this proposed commitment update entry. + LogIndex uint64 + + // UpdateMsg is the update message that was included within the our + // local update log. The LogIndex value denotes the log index of this + // update which will be used when restoring our local update log if + // we're left with a dangling update on restart. + UpdateMsg lnwire.Message +} + +// AddRef is used to identify a particular Add in a FwdPkg. The short channel ID +// is assumed to be that of the packager. +type AddRef struct { + // Height is the remote commitment height that locked in the Add. + Height uint64 + + // Index is the index of the Add within the fwd pkg's Adds. + // + // NOTE: This index is static over the lifetime of a forwarding package. + Index uint16 +} + +// SettleFailRef is used to locate a Settle/Fail in another channel's FwdPkg. A +// channel does not remove its own Settle/Fail htlcs, so the source is provided +// to locate a db bucket belonging to another channel. +type SettleFailRef struct { + // Source identifies the outgoing link that locked in the settle or + // fail. This is then used by the *incoming* link to find the settle + // fail in another link's forwarding packages. + Source lnwire.ShortChannelID + + // Height is the remote commitment height that locked in this + // Settle/Fail. + Height uint64 + + // Index is the index of the Add with the fwd pkg's SettleFails. + // + // NOTE: This index is static over the lifetime of a forwarding package. + Index uint16 +} + +// CommitDiff represents the delta needed to apply the state transition between +// two subsequent commitment states. Given state N and state N+1, one is able +// to apply the set of messages contained within the CommitDiff to N to arrive +// at state N+1. Each time a new commitment is extended, we'll write a new +// commitment (along with the full commitment state) to disk so we can +// re-transmit the state in the case of a connection loss or message drop. +type CommitDiff struct { + // ChannelCommitment is the full commitment state that one would arrive + // at by applying the set of messages contained in the UpdateDiff to + // the prior accepted commitment. + Commitment ChannelCommitment + + // LogUpdates is the set of messages sent prior to the commitment state + // transition in question. Upon reconnection, if we detect that they + // don't have the commitment, then we re-send this along with the + // proper signature. + LogUpdates []LogUpdate + + // CommitSig is the exact CommitSig message that should be sent after + // the set of LogUpdates above has been retransmitted. The signatures + // within this message should properly cover the new commitment state + // and also the HTLC's within the new commitment state. + CommitSig *lnwire.CommitSig + + // OpenedCircuitKeys is a set of unique identifiers for any downstream + // Add packets included in this commitment txn. After a restart, this + // set of htlcs is acked from the link's incoming mailbox to ensure + // there isn't an attempt to re-add them to this commitment txn. + OpenedCircuitKeys []CircuitKey + + // ClosedCircuitKeys records the unique identifiers for any settle/fail + // packets that were resolved by this commitment txn. After a restart, + // this is used to ensure those circuits are removed from the circuit + // map, and the downstream packets in the link's mailbox are removed. + ClosedCircuitKeys []CircuitKey + + // AddAcks specifies the locations (commit height, pkg index) of any + // Adds that were failed/settled in this commit diff. This will ack + // entries in *this* channel's forwarding packages. + // + // NOTE: This value is not serialized, it is used to atomically mark the + // resolution of adds, such that they will not be reprocessed after a + // restart. + AddAcks []AddRef + + // SettleFailAcks specifies the locations (chan id, commit height, pkg + // index) of any Settles or Fails that were locked into this commit + // diff, and originate from *another* channel, i.e. the outgoing link. + // + // NOTE: This value is not serialized, it is used to atomically acks + // settles and fails from the forwarding packages of other channels, + // such that they will not be reforwarded internally after a restart. + SettleFailAcks []SettleFailRef +} + +// NetworkResult is the raw result received from the network after a payment +// attempt has been made. Since the switch doesn't always have the necessary +// data to decode the raw message, we store it together with some meta data, +// and decode it when the router query for the final result. +type NetworkResult struct { + // Msg is the received result. This should be of type UpdateFulfillHTLC + // or UpdateFailHTLC. + Msg lnwire.Message + + // unencrypted indicates whether the failure encoded in the message is + // unencrypted, and hence doesn't need to be decrypted. + Unencrypted bool + + // IsResolution indicates whether this is a resolution message, in + // which the failure reason might not be included. + IsResolution bool +} + +// ClosureType is an enum like structure that details exactly _how_ a channel +// was closed. Three closure types are currently possible: none, cooperative, +// local force close, remote force close, and (remote) breach. +type ClosureType uint8 + +// ChannelConstraints represents a set of constraints meant to allow a node to +// limit their exposure, enact flow control and ensure that all HTLCs are +// economically relevant. This struct will be mirrored for both sides of the +// channel, as each side will enforce various constraints that MUST be adhered +// to for the life time of the channel. The parameters for each of these +// constraints are static for the duration of the channel, meaning the channel +// must be torn down for them to change. +type ChannelConstraints struct { + // DustLimit is the threshold (in satoshis) below which any outputs + // should be trimmed. When an output is trimmed, it isn't materialized + // as an actual output, but is instead burned to miner's fees. + DustLimit btcutil.Amount + + // ChanReserve is an absolute reservation on the channel for the + // owner of this set of constraints. This means that the current + // settled balance for this node CANNOT dip below the reservation + // amount. This acts as a defense against costless attacks when + // either side no longer has any skin in the game. + ChanReserve btcutil.Amount + + // MaxPendingAmount is the maximum pending HTLC value that the + // owner of these constraints can offer the remote node at a + // particular time. + MaxPendingAmount lnwire.MilliSatoshi + + // MinHTLC is the minimum HTLC value that the owner of these + // constraints can offer the remote node. If any HTLCs below this + // amount are offered, then the HTLC will be rejected. This, in + // tandem with the dust limit allows a node to regulate the + // smallest HTLC that it deems economically relevant. + MinHTLC lnwire.MilliSatoshi + + // MaxAcceptedHtlcs is the maximum number of HTLCs that the owner of + // this set of constraints can offer the remote node. This allows each + // node to limit their over all exposure to HTLCs that may need to be + // acted upon in the case of a unilateral channel closure or a contract + // breach. + MaxAcceptedHtlcs uint16 + + // CsvDelay is the relative time lock delay expressed in blocks. Any + // settled outputs that pay to the owner of this channel configuration + // MUST ensure that the delay branch uses this value as the relative + // time lock. Similarly, any HTLC's offered by this node should use + // this value as well. + CsvDelay uint16 +} + +// ChannelConfig is a struct that houses the various configuration opens for +// channels. Each side maintains an instance of this configuration file as it +// governs: how the funding and commitment transaction to be created, the +// nature of HTLC's allotted, the keys to be used for delivery, and relative +// time lock parameters. +type ChannelConfig struct { + // ChannelConstraints is the set of constraints that must be upheld for + // the duration of the channel for the owner of this channel + // configuration. Constraints govern a number of flow control related + // parameters, also including the smallest HTLC that will be accepted + // by a participant. + ChannelConstraints + + // MultiSigKey is the key to be used within the 2-of-2 output script + // for the owner of this channel config. + MultiSigKey keychain.KeyDescriptor + + // RevocationBasePoint is the base public key to be used when deriving + // revocation keys for the remote node's commitment transaction. This + // will be combined along with a per commitment secret to derive a + // unique revocation key for each state. + RevocationBasePoint keychain.KeyDescriptor + + // PaymentBasePoint is the base public key to be used when deriving + // the key used within the non-delayed pay-to-self output on the + // commitment transaction for a node. This will be combined with a + // tweak derived from the per-commitment point to ensure unique keys + // for each commitment transaction. + PaymentBasePoint keychain.KeyDescriptor + + // DelayBasePoint is the base public key to be used when deriving the + // key used within the delayed pay-to-self output on the commitment + // transaction for a node. This will be combined with a tweak derived + // from the per-commitment point to ensure unique keys for each + // commitment transaction. + DelayBasePoint keychain.KeyDescriptor + + // HtlcBasePoint is the base public key to be used when deriving the + // local HTLC key. The derived key (combined with the tweak derived + // from the per-commitment point) is used within the "to self" clause + // within any HTLC output scripts. + HtlcBasePoint keychain.KeyDescriptor +} + +// ChannelCloseSummary contains the final state of a channel at the point it +// was closed. Once a channel is closed, all the information pertaining to that +// channel within the openChannelBucket is deleted, and a compact summary is +// put in place instead. +type ChannelCloseSummary struct { + // ChanPoint is the outpoint for this channel's funding transaction, + // and is used as a unique identifier for the channel. + ChanPoint wire.OutPoint + + // ShortChanID encodes the exact location in the chain in which the + // channel was initially confirmed. This includes: the block height, + // transaction index, and the output within the target transaction. + ShortChanID lnwire.ShortChannelID + + // ChainHash is the hash of the genesis block that this channel resides + // within. + ChainHash chainhash.Hash + + // ClosingTXID is the txid of the transaction which ultimately closed + // this channel. + ClosingTXID chainhash.Hash + + // RemotePub is the public key of the remote peer that we formerly had + // a channel with. + RemotePub *btcec.PublicKey + + // Capacity was the total capacity of the channel. + Capacity btcutil.Amount + + // CloseHeight is the height at which the funding transaction was + // spent. + CloseHeight uint32 + + // SettledBalance is our total balance settled balance at the time of + // channel closure. This _does not_ include the sum of any outputs that + // have been time-locked as a result of the unilateral channel closure. + SettledBalance btcutil.Amount + + // TimeLockedBalance is the sum of all the time-locked outputs at the + // time of channel closure. If we triggered the force closure of this + // channel, then this value will be non-zero if our settled output is + // above the dust limit. If we were on the receiving side of a channel + // force closure, then this value will be non-zero if we had any + // outstanding outgoing HTLC's at the time of channel closure. + TimeLockedBalance btcutil.Amount + + // CloseType details exactly _how_ the channel was closed. Five closure + // types are possible: cooperative, local force, remote force, breach + // and funding canceled. + CloseType ClosureType + + // IsPending indicates whether this channel is in the 'pending close' + // state, which means the channel closing transaction has been + // confirmed, but not yet been fully resolved. In the case of a channel + // that has been cooperatively closed, it will go straight into the + // fully resolved state as soon as the closing transaction has been + // confirmed. However, for channels that have been force closed, they'll + // stay marked as "pending" until _all_ the pending funds have been + // swept. + IsPending bool + + // RemoteCurrentRevocation is the current revocation for their + // commitment transaction. However, since this is the derived public key, + // we don't yet have the private key so we aren't yet able to verify + // that it's actually in the hash chain. + RemoteCurrentRevocation *btcec.PublicKey + + // RemoteNextRevocation is the revocation key to be used for the *next* + // commitment transaction we create for the local node. Within the + // specification, this value is referred to as the + // per-commitment-point. + RemoteNextRevocation *btcec.PublicKey + + // LocalChanCfg is the channel configuration for the local node. + LocalChanConfig ChannelConfig + + // LastChanSyncMsg is the ChannelReestablish message for this channel + // for the state at the point where it was closed. + LastChanSyncMsg *lnwire.ChannelReestablish +} + +// FwdState is an enum used to describe the lifecycle of a FwdPkg. +type FwdState byte + +const ( + // FwdStateLockedIn is the starting state for all forwarding packages. + // Packages in this state have not yet committed to the exact set of + // Adds to forward to the switch. + FwdStateLockedIn FwdState = iota + + // FwdStateProcessed marks the state in which all Adds have been + // locally processed and the forwarding decision to the switch has been + // persisted. + FwdStateProcessed + + // FwdStateCompleted signals that all Adds have been acked, and that all + // settles and fails have been delivered to their sources. Packages in + // this state can be removed permanently. + FwdStateCompleted +) + +// PkgFilter is used to compactly represent a particular subset of the Adds in a +// forwarding package. Each filter is represented as a simple, statically-sized +// bitvector, where the elements are intended to be the indices of the Adds as +// they are written in the FwdPkg. +type PkgFilter struct { + count uint16 + filter []byte +} + +// NewPkgFilter initializes an empty PkgFilter supporting `count` elements. +func NewPkgFilter(count uint16) *PkgFilter { + // We add 7 to ensure that the integer division yields properly rounded + // values. + filterLen := (count + 7) / 8 + + return &PkgFilter{ + count: count, + filter: make([]byte, filterLen), + } +} + +// Count returns the number of elements represented by this PkgFilter. +func (f *PkgFilter) Count() uint16 { + return f.count +} + +// Set marks the `i`-th element as included by this filter. +// NOTE: It is assumed that i is always less than count. +func (f *PkgFilter) Set(i uint16) { + byt := i / 8 + bit := i % 8 + + // Set the i-th bit in the filter. + // TODO(conner): ignore if > count to prevent panic? + f.filter[byt] |= byte(1 << (7 - bit)) +} + +// Contains queries the filter for membership of index `i`. +// NOTE: It is assumed that i is always less than count. +func (f *PkgFilter) Contains(i uint16) bool { + byt := i / 8 + bit := i % 8 + + // Read the i-th bit in the filter. + // TODO(conner): ignore if > count to prevent panic? + return f.filter[byt]&(1<<(7-bit)) != 0 +} + +// Equal checks two PkgFilters for equality. +func (f *PkgFilter) Equal(f2 *PkgFilter) bool { + if f == f2 { + return true + } + if f.count != f2.count { + return false + } + + return bytes.Equal(f.filter, f2.filter) +} + +// IsFull returns true if every element in the filter has been Set, and false +// otherwise. +func (f *PkgFilter) IsFull() bool { + // Batch validate bytes that are fully used. + for i := uint16(0); i < f.count/8; i++ { + if f.filter[i] != 0xFF { + return false + } + } + + // If the count is not a multiple of 8, check that the filter contains + // all remaining bits. + rem := f.count % 8 + for idx := f.count - rem; idx < f.count; idx++ { + if !f.Contains(idx) { + return false + } + } + + return true +} + +// Size returns number of bytes produced when the PkgFilter is serialized. +func (f *PkgFilter) Size() uint16 { + // 2 bytes for uint16 `count`, then round up number of bytes required to + // represent `count` bits. + return 2 + (f.count+7)/8 +} + +// Encode writes the filter to the provided io.Writer. +func (f *PkgFilter) Encode(w io.Writer) error { + if err := binary.Write(w, binary.BigEndian, f.count); err != nil { + return err + } + + _, err := w.Write(f.filter) + + return err +} + +// Decode reads the filter from the provided io.Reader. +func (f *PkgFilter) Decode(r io.Reader) error { + if err := binary.Read(r, binary.BigEndian, &f.count); err != nil { + return err + } + + f.filter = make([]byte, f.Size()-2) + _, err := io.ReadFull(r, f.filter) + + return err +} + +// FwdPkg records all adds, settles, and fails that were locked in as a result +// of the remote peer sending us a revocation. Each package is identified by +// the short chanid and remote commitment height corresponding to the revocation +// that locked in the HTLCs. For everything except a locally initiated payment, +// settles and fails in a forwarding package must have a corresponding Add in +// another package, and can be removed individually once the source link has +// received the fail/settle. +// +// Adds cannot be removed, as we need to present the same batch of Adds to +// properly handle replay protection. Instead, we use a PkgFilter to mark that +// we have finished processing a particular Add. A FwdPkg should only be deleted +// after the AckFilter is full and all settles and fails have been persistently +// removed. +type FwdPkg struct { + // Source identifies the channel that wrote this forwarding package. + Source lnwire.ShortChannelID + + // Height is the height of the remote commitment chain that locked in + // this forwarding package. + Height uint64 + + // State signals the persistent condition of the package and directs how + // to reprocess the package in the event of failures. + State FwdState + + // Adds contains all add messages which need to be processed and + // forwarded to the switch. Adds does not change over the life of a + // forwarding package. + Adds []LogUpdate + + // FwdFilter is a filter containing the indices of all Adds that were + // forwarded to the switch. + FwdFilter *PkgFilter + + // AckFilter is a filter containing the indices of all Adds for which + // the source has received a settle or fail and is reflected in the next + // commitment txn. A package should not be removed until IsFull() + // returns true. + AckFilter *PkgFilter + + // SettleFails contains all settle and fail messages that should be + // forwarded to the switch. + SettleFails []LogUpdate + + // SettleFailFilter is a filter containing the indices of all Settle or + // Fails originating in this package that have been received and locked + // into the incoming link's commitment state. + SettleFailFilter *PkgFilter +} + +// NewFwdPkg initializes a new forwarding package in FwdStateLockedIn. This +// should be used to create a package at the time we receive a revocation. +func NewFwdPkg(source lnwire.ShortChannelID, height uint64, + addUpdates, settleFailUpdates []LogUpdate) *FwdPkg { + + nAddUpdates := uint16(len(addUpdates)) + nSettleFailUpdates := uint16(len(settleFailUpdates)) + + return &FwdPkg{ + Source: source, + Height: height, + State: FwdStateLockedIn, + Adds: addUpdates, + FwdFilter: NewPkgFilter(nAddUpdates), + AckFilter: NewPkgFilter(nAddUpdates), + SettleFails: settleFailUpdates, + SettleFailFilter: NewPkgFilter(nSettleFailUpdates), + } +} diff --git a/channeldb/migration21/current/current_codec.go b/channeldb/migration21/current/current_codec.go new file mode 100644 index 00000000..0257f40a --- /dev/null +++ b/channeldb/migration21/current/current_codec.go @@ -0,0 +1,420 @@ +package current + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btcutil" + lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21" + "github.com/lightningnetwork/lnd/channeldb/migration21/common" + "github.com/lightningnetwork/lnd/keychain" + "github.com/lightningnetwork/lnd/shachain" +) + +var ( + // Big endian is the preferred byte order, due to cursor scans over + // integer keys iterating in order. + byteOrder = binary.BigEndian +) + +// writeOutpoint writes an outpoint to the passed writer using the minimal +// amount of bytes possible. +func writeOutpoint(w io.Writer, o *wire.OutPoint) error { + if _, err := w.Write(o.Hash[:]); err != nil { + return err + } + if err := binary.Write(w, byteOrder, o.Index); err != nil { + return err + } + + return nil +} + +// readOutpoint reads an outpoint from the passed reader that was previously +// written using the writeOutpoint struct. +func readOutpoint(r io.Reader, o *wire.OutPoint) error { + if _, err := io.ReadFull(r, o.Hash[:]); err != nil { + return err + } + if err := binary.Read(r, byteOrder, &o.Index); err != nil { + return err + } + + return nil +} + +// UnknownElementType is an error returned when the codec is unable to encode or +// decode a particular type. +type UnknownElementType struct { + method string + element interface{} +} + +// NewUnknownElementType creates a new UnknownElementType error from the passed +// method name and element. +func NewUnknownElementType(method string, el interface{}) UnknownElementType { + return UnknownElementType{method: method, element: el} +} + +// Error returns the name of the method that encountered the error, as well as +// the type that was unsupported. +func (e UnknownElementType) Error() string { + return fmt.Sprintf("Unknown type in %s: %T", e.method, e.element) +} + +// WriteElement is a one-stop shop to write the big endian representation of +// any element which is to be serialized for storage on disk. The passed +// io.Writer should be backed by an appropriately sized byte slice, or be able +// to dynamically expand to accommodate additional data. +func WriteElement(w io.Writer, element interface{}) error { + switch e := element.(type) { + case keychain.KeyDescriptor: + if err := binary.Write(w, byteOrder, e.Family); err != nil { + return err + } + if err := binary.Write(w, byteOrder, e.Index); err != nil { + return err + } + + if e.PubKey != nil { + if err := binary.Write(w, byteOrder, true); err != nil { + return fmt.Errorf("error writing serialized element: %s", err) + } + + return WriteElement(w, e.PubKey) + } + + return binary.Write(w, byteOrder, false) + + case chainhash.Hash: + if _, err := w.Write(e[:]); err != nil { + return err + } + + case wire.OutPoint: + return writeOutpoint(w, &e) + + case lnwire.ShortChannelID: + if err := binary.Write(w, byteOrder, e.ToUint64()); err != nil { + return err + } + + case lnwire.ChannelID: + if _, err := w.Write(e[:]); err != nil { + return err + } + + case int64, uint64: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case uint32: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case int32: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case uint16: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case uint8: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case bool: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case btcutil.Amount: + if err := binary.Write(w, byteOrder, uint64(e)); err != nil { + return err + } + + case lnwire.MilliSatoshi: + if err := binary.Write(w, byteOrder, uint64(e)); err != nil { + return err + } + + case *btcec.PrivateKey: + b := e.Serialize() + if _, err := w.Write(b); err != nil { + return err + } + + case *btcec.PublicKey: + b := e.SerializeCompressed() + if _, err := w.Write(b); err != nil { + return err + } + + case shachain.Producer: + return e.Encode(w) + + case shachain.Store: + return e.Encode(w) + + case *wire.MsgTx: + return e.Serialize(w) + + case [32]byte: + if _, err := w.Write(e[:]); err != nil { + return err + } + + case []byte: + if err := wire.WriteVarBytes(w, 0, e); err != nil { + return err + } + + case lnwire.Message: + var msgBuf bytes.Buffer + if _, err := lnwire.WriteMessage(&msgBuf, e, 0); err != nil { + return err + } + + msgLen := uint16(len(msgBuf.Bytes())) + if err := WriteElements(w, msgLen); err != nil { + return err + } + + if _, err := w.Write(msgBuf.Bytes()); err != nil { + return err + } + + case common.ClosureType: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case lnwire.FundingFlag: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + default: + return UnknownElementType{"WriteElement", e} + } + + return nil +} + +// WriteElements is writes each element in the elements slice to the passed +// io.Writer using WriteElement. +func WriteElements(w io.Writer, elements ...interface{}) error { + for _, element := range elements { + err := WriteElement(w, element) + if err != nil { + return err + } + } + return nil +} + +// ReadElement is a one-stop utility function to deserialize any datastructure +// encoded using the serialization format of the database. +func ReadElement(r io.Reader, element interface{}) error { + switch e := element.(type) { + case *keychain.KeyDescriptor: + if err := binary.Read(r, byteOrder, &e.Family); err != nil { + return err + } + if err := binary.Read(r, byteOrder, &e.Index); err != nil { + return err + } + + var hasPubKey bool + if err := binary.Read(r, byteOrder, &hasPubKey); err != nil { + return err + } + + if hasPubKey { + return ReadElement(r, &e.PubKey) + } + + case *chainhash.Hash: + if _, err := io.ReadFull(r, e[:]); err != nil { + return err + } + + case *wire.OutPoint: + return readOutpoint(r, e) + + case *lnwire.ShortChannelID: + var a uint64 + if err := binary.Read(r, byteOrder, &a); err != nil { + return err + } + *e = lnwire.NewShortChanIDFromInt(a) + + case *lnwire.ChannelID: + if _, err := io.ReadFull(r, e[:]); err != nil { + return err + } + + case *int64, *uint64: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *uint32: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *int32: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *uint16: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *uint8: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *bool: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *btcutil.Amount: + var a uint64 + if err := binary.Read(r, byteOrder, &a); err != nil { + return err + } + + *e = btcutil.Amount(a) + + case *lnwire.MilliSatoshi: + var a uint64 + if err := binary.Read(r, byteOrder, &a); err != nil { + return err + } + + *e = lnwire.MilliSatoshi(a) + + case **btcec.PrivateKey: + var b [btcec.PrivKeyBytesLen]byte + if _, err := io.ReadFull(r, b[:]); err != nil { + return err + } + + priv, _ := btcec.PrivKeyFromBytes(btcec.S256(), b[:]) + *e = priv + + case **btcec.PublicKey: + var b [btcec.PubKeyBytesLenCompressed]byte + if _, err := io.ReadFull(r, b[:]); err != nil { + return err + } + + pubKey, err := btcec.ParsePubKey(b[:], btcec.S256()) + if err != nil { + return err + } + *e = pubKey + + case *shachain.Producer: + var root [32]byte + if _, err := io.ReadFull(r, root[:]); err != nil { + return err + } + + // TODO(roasbeef): remove + producer, err := shachain.NewRevocationProducerFromBytes(root[:]) + if err != nil { + return err + } + + *e = producer + + case *shachain.Store: + store, err := shachain.NewRevocationStoreFromBytes(r) + if err != nil { + return err + } + + *e = store + + case **wire.MsgTx: + tx := wire.NewMsgTx(2) + if err := tx.Deserialize(r); err != nil { + return err + } + + *e = tx + + case *[32]byte: + if _, err := io.ReadFull(r, e[:]); err != nil { + return err + } + + case *[]byte: + bytes, err := wire.ReadVarBytes(r, 0, 66000, "[]byte") + if err != nil { + return err + } + + *e = bytes + + case *lnwire.Message: + var msgLen uint16 + if err := ReadElement(r, &msgLen); err != nil { + return err + } + + msgReader := io.LimitReader(r, int64(msgLen)) + msg, err := lnwire.ReadMessage(msgReader, 0) + if err != nil { + return err + } + + *e = msg + + case *common.ClosureType: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *lnwire.FundingFlag: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + default: + return UnknownElementType{"ReadElement", e} + } + + return nil +} + +// ReadElements deserializes a variable number of elements into the passed +// io.Reader, with each element being deserialized according to the ReadElement +// function. +func ReadElements(r io.Reader, elements ...interface{}) error { + for _, element := range elements { + err := ReadElement(r, element) + if err != nil { + return err + } + } + return nil +} diff --git a/channeldb/migration21/current/current_encoding.go b/channeldb/migration21/current/current_encoding.go new file mode 100644 index 00000000..3aaf4415 --- /dev/null +++ b/channeldb/migration21/current/current_encoding.go @@ -0,0 +1,728 @@ +package current + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + + "github.com/lightningnetwork/lnd/channeldb/kvdb" + lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21" + "github.com/lightningnetwork/lnd/channeldb/migration21/common" +) + +func serializeChanCommit(w io.Writer, c *common.ChannelCommitment) error { // nolint: dupl + if err := WriteElements(w, + c.CommitHeight, c.LocalLogIndex, c.LocalHtlcIndex, + c.RemoteLogIndex, c.RemoteHtlcIndex, c.LocalBalance, + c.RemoteBalance, c.CommitFee, c.FeePerKw, c.CommitTx, + c.CommitSig, + ); err != nil { + return err + } + + return serializeHtlcs(w, c.Htlcs...) +} + +func SerializeLogUpdates(w io.Writer, logUpdates []common.LogUpdate) error { // nolint: dupl + numUpdates := uint16(len(logUpdates)) + if err := binary.Write(w, byteOrder, numUpdates); err != nil { + return err + } + + for _, diff := range logUpdates { + err := WriteElements(w, diff.LogIndex, diff.UpdateMsg) + if err != nil { + return err + } + } + + return nil +} + +func serializeHtlcs(b io.Writer, htlcs ...common.HTLC) error { // nolint: dupl + numHtlcs := uint16(len(htlcs)) + if err := WriteElement(b, numHtlcs); err != nil { + return err + } + + for _, htlc := range htlcs { + if err := WriteElements(b, + htlc.Signature, htlc.RHash, htlc.Amt, htlc.RefundTimeout, + htlc.OutputIndex, htlc.Incoming, htlc.OnionBlob, + htlc.HtlcIndex, htlc.LogIndex, + ); err != nil { + return err + } + } + + return nil +} + +func SerializeCommitDiff(w io.Writer, diff *common.CommitDiff) error { // nolint: dupl + if err := serializeChanCommit(w, &diff.Commitment); err != nil { + return err + } + + if err := WriteElements(w, diff.CommitSig); err != nil { + return err + } + + if err := SerializeLogUpdates(w, diff.LogUpdates); err != nil { + return err + } + + numOpenRefs := uint16(len(diff.OpenedCircuitKeys)) + if err := binary.Write(w, byteOrder, numOpenRefs); err != nil { + return err + } + + for _, openRef := range diff.OpenedCircuitKeys { + err := WriteElements(w, openRef.ChanID, openRef.HtlcID) + if err != nil { + return err + } + } + + numClosedRefs := uint16(len(diff.ClosedCircuitKeys)) + if err := binary.Write(w, byteOrder, numClosedRefs); err != nil { + return err + } + + for _, closedRef := range diff.ClosedCircuitKeys { + err := WriteElements(w, closedRef.ChanID, closedRef.HtlcID) + if err != nil { + return err + } + } + + return nil +} + +func deserializeHtlcs(r io.Reader) ([]common.HTLC, error) { // nolint: dupl + var numHtlcs uint16 + if err := ReadElement(r, &numHtlcs); err != nil { + return nil, err + } + + var htlcs []common.HTLC + if numHtlcs == 0 { + return htlcs, nil + } + + htlcs = make([]common.HTLC, numHtlcs) + for i := uint16(0); i < numHtlcs; i++ { + if err := ReadElements(r, + &htlcs[i].Signature, &htlcs[i].RHash, &htlcs[i].Amt, + &htlcs[i].RefundTimeout, &htlcs[i].OutputIndex, + &htlcs[i].Incoming, &htlcs[i].OnionBlob, + &htlcs[i].HtlcIndex, &htlcs[i].LogIndex, + ); err != nil { + return htlcs, err + } + } + + return htlcs, nil +} + +func deserializeChanCommit(r io.Reader) (common.ChannelCommitment, error) { // nolint: dupl + var c common.ChannelCommitment + + err := ReadElements(r, + &c.CommitHeight, &c.LocalLogIndex, &c.LocalHtlcIndex, &c.RemoteLogIndex, + &c.RemoteHtlcIndex, &c.LocalBalance, &c.RemoteBalance, + &c.CommitFee, &c.FeePerKw, &c.CommitTx, &c.CommitSig, + ) + if err != nil { + return c, err + } + + c.Htlcs, err = deserializeHtlcs(r) + if err != nil { + return c, err + } + + return c, nil +} + +func DeserializeLogUpdates(r io.Reader) ([]common.LogUpdate, error) { // nolint: dupl + var numUpdates uint16 + if err := binary.Read(r, byteOrder, &numUpdates); err != nil { + return nil, err + } + + logUpdates := make([]common.LogUpdate, numUpdates) + for i := 0; i < int(numUpdates); i++ { + err := ReadElements(r, + &logUpdates[i].LogIndex, &logUpdates[i].UpdateMsg, + ) + if err != nil { + return nil, err + } + } + return logUpdates, nil +} + +func DeserializeCommitDiff(r io.Reader) (*common.CommitDiff, error) { // nolint: dupl + var ( + d common.CommitDiff + err error + ) + + d.Commitment, err = deserializeChanCommit(r) + if err != nil { + return nil, err + } + + var msg lnwire.Message + if err := ReadElements(r, &msg); err != nil { + return nil, err + } + commitSig, ok := msg.(*lnwire.CommitSig) + if !ok { + return nil, fmt.Errorf("expected lnwire.CommitSig, instead "+ + "read: %T", msg) + } + d.CommitSig = commitSig + + d.LogUpdates, err = DeserializeLogUpdates(r) + if err != nil { + return nil, err + } + + var numOpenRefs uint16 + if err := binary.Read(r, byteOrder, &numOpenRefs); err != nil { + return nil, err + } + + d.OpenedCircuitKeys = make([]common.CircuitKey, numOpenRefs) + for i := 0; i < int(numOpenRefs); i++ { + err := ReadElements(r, + &d.OpenedCircuitKeys[i].ChanID, + &d.OpenedCircuitKeys[i].HtlcID) + if err != nil { + return nil, err + } + } + + var numClosedRefs uint16 + if err := binary.Read(r, byteOrder, &numClosedRefs); err != nil { + return nil, err + } + + d.ClosedCircuitKeys = make([]common.CircuitKey, numClosedRefs) + for i := 0; i < int(numClosedRefs); i++ { + err := ReadElements(r, + &d.ClosedCircuitKeys[i].ChanID, + &d.ClosedCircuitKeys[i].HtlcID) + if err != nil { + return nil, err + } + } + + return &d, nil +} + +func SerializeNetworkResult(w io.Writer, n *common.NetworkResult) error { // nolint: dupl + return WriteElements(w, n.Msg, n.Unencrypted, n.IsResolution) +} + +func DeserializeNetworkResult(r io.Reader) (*common.NetworkResult, error) { // nolint: dupl + n := &common.NetworkResult{} + + if err := ReadElements(r, + &n.Msg, &n.Unencrypted, &n.IsResolution, + ); err != nil { + return nil, err + } + + return n, nil +} + +func writeChanConfig(b io.Writer, c *common.ChannelConfig) error { // nolint: dupl + return WriteElements(b, + c.DustLimit, c.MaxPendingAmount, c.ChanReserve, c.MinHTLC, + c.MaxAcceptedHtlcs, c.CsvDelay, c.MultiSigKey, + c.RevocationBasePoint, c.PaymentBasePoint, c.DelayBasePoint, + c.HtlcBasePoint, + ) +} + +func SerializeChannelCloseSummary(w io.Writer, cs *common.ChannelCloseSummary) error { // nolint: dupl + err := WriteElements(w, + cs.ChanPoint, cs.ShortChanID, cs.ChainHash, cs.ClosingTXID, + cs.CloseHeight, cs.RemotePub, cs.Capacity, cs.SettledBalance, + cs.TimeLockedBalance, cs.CloseType, cs.IsPending, + ) + if err != nil { + return err + } + + // If this is a close channel summary created before the addition of + // the new fields, then we can exit here. + if cs.RemoteCurrentRevocation == nil { + return WriteElements(w, false) + } + + // If fields are present, write boolean to indicate this, and continue. + if err := WriteElements(w, true); err != nil { + return err + } + + if err := WriteElements(w, cs.RemoteCurrentRevocation); err != nil { + return err + } + + if err := writeChanConfig(w, &cs.LocalChanConfig); err != nil { + return err + } + + // The RemoteNextRevocation field is optional, as it's possible for a + // channel to be closed before we learn of the next unrevoked + // revocation point for the remote party. Write a boolen indicating + // whether this field is present or not. + if err := WriteElements(w, cs.RemoteNextRevocation != nil); err != nil { + return err + } + + // Write the field, if present. + if cs.RemoteNextRevocation != nil { + if err = WriteElements(w, cs.RemoteNextRevocation); err != nil { + return err + } + } + + // Write whether the channel sync message is present. + if err := WriteElements(w, cs.LastChanSyncMsg != nil); err != nil { + return err + } + + // Write the channel sync message, if present. + if cs.LastChanSyncMsg != nil { + if err := WriteElements(w, cs.LastChanSyncMsg); err != nil { + return err + } + } + + return nil +} + +func readChanConfig(b io.Reader, c *common.ChannelConfig) error { + return ReadElements(b, + &c.DustLimit, &c.MaxPendingAmount, &c.ChanReserve, + &c.MinHTLC, &c.MaxAcceptedHtlcs, &c.CsvDelay, + &c.MultiSigKey, &c.RevocationBasePoint, + &c.PaymentBasePoint, &c.DelayBasePoint, + &c.HtlcBasePoint, + ) +} + +func DeserializeCloseChannelSummary(r io.Reader) (*common.ChannelCloseSummary, error) { // nolint: dupl + c := &common.ChannelCloseSummary{} + + err := ReadElements(r, + &c.ChanPoint, &c.ShortChanID, &c.ChainHash, &c.ClosingTXID, + &c.CloseHeight, &c.RemotePub, &c.Capacity, &c.SettledBalance, + &c.TimeLockedBalance, &c.CloseType, &c.IsPending, + ) + if err != nil { + return nil, err + } + + // We'll now check to see if the channel close summary was encoded with + // any of the additional optional fields. + var hasNewFields bool + err = ReadElements(r, &hasNewFields) + if err != nil { + return nil, err + } + + // If fields are not present, we can return. + if !hasNewFields { + return c, nil + } + + // Otherwise read the new fields. + if err := ReadElements(r, &c.RemoteCurrentRevocation); err != nil { + return nil, err + } + + if err := readChanConfig(r, &c.LocalChanConfig); err != nil { + return nil, err + } + + // Finally, we'll attempt to read the next unrevoked commitment point + // for the remote party. If we closed the channel before receiving a + // funding locked message then this might not be present. A boolean + // indicating whether the field is present will come first. + var hasRemoteNextRevocation bool + err = ReadElements(r, &hasRemoteNextRevocation) + if err != nil { + return nil, err + } + + // If this field was written, read it. + if hasRemoteNextRevocation { + err = ReadElements(r, &c.RemoteNextRevocation) + if err != nil { + return nil, err + } + } + + // Check if we have a channel sync message to read. + var hasChanSyncMsg bool + err = ReadElements(r, &hasChanSyncMsg) + if err == io.EOF { + return c, nil + } else if err != nil { + return nil, err + } + + // If a chan sync message is present, read it. + if hasChanSyncMsg { + // We must pass in reference to a lnwire.Message for the codec + // to support it. + var msg lnwire.Message + if err := ReadElements(r, &msg); err != nil { + return nil, err + } + + chanSync, ok := msg.(*lnwire.ChannelReestablish) + if !ok { + return nil, errors.New("unable cast db Message to " + + "ChannelReestablish") + } + c.LastChanSyncMsg = chanSync + } + + return c, nil +} + +// ErrCorruptedFwdPkg signals that the on-disk structure of the forwarding +// package has potentially been mangled. +var ErrCorruptedFwdPkg = errors.New("fwding package db has been corrupted") + +var ( + // fwdPackagesKey is the root-level bucket that all forwarding packages + // are written. This bucket is further subdivided based on the short + // channel ID of each channel. + fwdPackagesKey = []byte("fwd-packages") + + // addBucketKey is the bucket to which all Add log updates are written. + addBucketKey = []byte("add-updates") + + // failSettleBucketKey is the bucket to which all Settle/Fail log + // updates are written. + failSettleBucketKey = []byte("fail-settle-updates") + + // fwdFilterKey is a key used to write the set of Adds that passed + // validation and are to be forwarded to the switch. + // NOTE: The presence of this key within a forwarding package indicates + // that the package has reached FwdStateProcessed. + fwdFilterKey = []byte("fwd-filter-key") + + // ackFilterKey is a key used to access the PkgFilter indicating which + // Adds have received a Settle/Fail. This response may come from a + // number of sources, including: exitHop settle/fails, switch failures, + // chain arbiter interjections, as well as settle/fails from the + // next hop in the route. + ackFilterKey = []byte("ack-filter-key") + + // settleFailFilterKey is a key used to access the PkgFilter indicating + // which Settles/Fails in have been received and processed by the link + // that originally received the Add. + settleFailFilterKey = []byte("settle-fail-filter-key") +) + +func makeLogKey(updateNum uint64) [8]byte { + var key [8]byte + byteOrder.PutUint64(key[:], updateNum) + return key +} + +// uint16Key writes the provided 16-bit unsigned integer to a 2-byte slice. +func uint16Key(i uint16) []byte { + key := make([]byte, 2) + byteOrder.PutUint16(key, i) + return key +} + +// ChannelPackager is used by a channel to manage the lifecycle of its forwarding +// packages. The packager is tied to a particular source channel ID, allowing it +// to create and edit its own packages. Each packager also has the ability to +// remove fail/settle htlcs that correspond to an add contained in one of +// source's packages. +type ChannelPackager struct { + source lnwire.ShortChannelID +} + +// NewChannelPackager creates a new packager for a single channel. +func NewChannelPackager(source lnwire.ShortChannelID) *ChannelPackager { + return &ChannelPackager{ + source: source, + } +} + +// AddFwdPkg writes a newly locked in forwarding package to disk. +func (*ChannelPackager) AddFwdPkg(tx kvdb.RwTx, fwdPkg *common.FwdPkg) error { // nolint: dupl + fwdPkgBkt, err := tx.CreateTopLevelBucket(fwdPackagesKey) + if err != nil { + return err + } + + source := makeLogKey(fwdPkg.Source.ToUint64()) + sourceBkt, err := fwdPkgBkt.CreateBucketIfNotExists(source[:]) + if err != nil { + return err + } + + heightKey := makeLogKey(fwdPkg.Height) + heightBkt, err := sourceBkt.CreateBucketIfNotExists(heightKey[:]) + if err != nil { + return err + } + + // Write ADD updates we received at this commit height. + addBkt, err := heightBkt.CreateBucketIfNotExists(addBucketKey) + if err != nil { + return err + } + + // Write SETTLE/FAIL updates we received at this commit height. + failSettleBkt, err := heightBkt.CreateBucketIfNotExists(failSettleBucketKey) + if err != nil { + return err + } + + for i := range fwdPkg.Adds { + err = putLogUpdate(addBkt, uint16(i), &fwdPkg.Adds[i]) + if err != nil { + return err + } + } + + // Persist the initialized pkg filter, which will be used to determine + // when we can remove this forwarding package from disk. + var ackFilterBuf bytes.Buffer + if err := fwdPkg.AckFilter.Encode(&ackFilterBuf); err != nil { + return err + } + + if err := heightBkt.Put(ackFilterKey, ackFilterBuf.Bytes()); err != nil { + return err + } + + for i := range fwdPkg.SettleFails { + err = putLogUpdate(failSettleBkt, uint16(i), &fwdPkg.SettleFails[i]) + if err != nil { + return err + } + } + + var settleFailFilterBuf bytes.Buffer + err = fwdPkg.SettleFailFilter.Encode(&settleFailFilterBuf) + if err != nil { + return err + } + + return heightBkt.Put(settleFailFilterKey, settleFailFilterBuf.Bytes()) +} + +// putLogUpdate writes an htlc to the provided `bkt`, using `index` as the key. +func putLogUpdate(bkt kvdb.RwBucket, idx uint16, htlc *common.LogUpdate) error { + var b bytes.Buffer + if err := serializeLogUpdate(&b, htlc); err != nil { + return err + } + + return bkt.Put(uint16Key(idx), b.Bytes()) +} + +// LoadFwdPkgs scans the forwarding log for any packages that haven't been +// processed, and returns their deserialized log updates in a map indexed by the +// remote commitment height at which the updates were locked in. +func (p *ChannelPackager) LoadFwdPkgs(tx kvdb.RTx) ([]*common.FwdPkg, error) { + return loadChannelFwdPkgs(tx, p.source) +} + +// loadChannelFwdPkgs loads all forwarding packages owned by `source`. +func loadChannelFwdPkgs(tx kvdb.RTx, source lnwire.ShortChannelID) ([]*common.FwdPkg, error) { // nolint: dupl + fwdPkgBkt := tx.ReadBucket(fwdPackagesKey) + if fwdPkgBkt == nil { + return nil, nil + } + + sourceKey := makeLogKey(source.ToUint64()) + sourceBkt := fwdPkgBkt.NestedReadBucket(sourceKey[:]) + if sourceBkt == nil { + return nil, nil + } + + var heights []uint64 + if err := sourceBkt.ForEach(func(k, _ []byte) error { + if len(k) != 8 { + return ErrCorruptedFwdPkg + } + + heights = append(heights, byteOrder.Uint64(k)) + + return nil + }); err != nil { + return nil, err + } + + // Load the forwarding package for each retrieved height. + fwdPkgs := make([]*common.FwdPkg, 0, len(heights)) + for _, height := range heights { + fwdPkg, err := loadFwdPkg(fwdPkgBkt, source, height) + if err != nil { + return nil, err + } + + fwdPkgs = append(fwdPkgs, fwdPkg) + } + + return fwdPkgs, nil +} + +// loadFwdPkg reads the packager's fwd pkg at a given height, and determines the +// appropriate FwdState. +func loadFwdPkg(fwdPkgBkt kvdb.RBucket, source lnwire.ShortChannelID, + height uint64) (*common.FwdPkg, error) { + + sourceKey := makeLogKey(source.ToUint64()) + sourceBkt := fwdPkgBkt.NestedReadBucket(sourceKey[:]) + if sourceBkt == nil { + return nil, ErrCorruptedFwdPkg + } + + heightKey := makeLogKey(height) + heightBkt := sourceBkt.NestedReadBucket(heightKey[:]) + if heightBkt == nil { + return nil, ErrCorruptedFwdPkg + } + + // Load ADDs from disk. + addBkt := heightBkt.NestedReadBucket(addBucketKey) + if addBkt == nil { + return nil, ErrCorruptedFwdPkg + } + + adds, err := loadHtlcs(addBkt) + if err != nil { + return nil, err + } + + // Load ack filter from disk. + ackFilterBytes := heightBkt.Get(ackFilterKey) + if ackFilterBytes == nil { + return nil, ErrCorruptedFwdPkg + } + ackFilterReader := bytes.NewReader(ackFilterBytes) + + ackFilter := &common.PkgFilter{} + if err := ackFilter.Decode(ackFilterReader); err != nil { + return nil, err + } + + // Load SETTLE/FAILs from disk. + failSettleBkt := heightBkt.NestedReadBucket(failSettleBucketKey) + if failSettleBkt == nil { + return nil, ErrCorruptedFwdPkg + } + + failSettles, err := loadHtlcs(failSettleBkt) + if err != nil { + return nil, err + } + + // Load settle fail filter from disk. + settleFailFilterBytes := heightBkt.Get(settleFailFilterKey) + if settleFailFilterBytes == nil { + return nil, ErrCorruptedFwdPkg + } + settleFailFilterReader := bytes.NewReader(settleFailFilterBytes) + + settleFailFilter := &common.PkgFilter{} + if err := settleFailFilter.Decode(settleFailFilterReader); err != nil { + return nil, err + } + + // Initialize the fwding package, which always starts in the + // FwdStateLockedIn. We can determine what state the package was left in + // by examining constraints on the information loaded from disk. + fwdPkg := &common.FwdPkg{ + Source: source, + State: common.FwdStateLockedIn, + Height: height, + Adds: adds, + AckFilter: ackFilter, + SettleFails: failSettles, + SettleFailFilter: settleFailFilter, + } + + // Check to see if we have written the set exported filter adds to + // disk. If we haven't, processing of this package was never started, or + // failed during the last attempt. + fwdFilterBytes := heightBkt.Get(fwdFilterKey) + if fwdFilterBytes == nil { + nAdds := uint16(len(adds)) + fwdPkg.FwdFilter = common.NewPkgFilter(nAdds) + return fwdPkg, nil + } + + fwdFilterReader := bytes.NewReader(fwdFilterBytes) + fwdPkg.FwdFilter = &common.PkgFilter{} + if err := fwdPkg.FwdFilter.Decode(fwdFilterReader); err != nil { + return nil, err + } + + // Otherwise, a complete round of processing was completed, and we + // advance the package to FwdStateProcessed. + fwdPkg.State = common.FwdStateProcessed + + // If every add, settle, and fail has been fully acknowledged, we can + // safely set the package's state to FwdStateCompleted, signalling that + // it can be garbage collected. + if fwdPkg.AckFilter.IsFull() && fwdPkg.SettleFailFilter.IsFull() { + fwdPkg.State = common.FwdStateCompleted + } + + return fwdPkg, nil +} + +// loadHtlcs retrieves all serialized htlcs in a bucket, returning +// them in order of the indexes they were written under. +func loadHtlcs(bkt kvdb.RBucket) ([]common.LogUpdate, error) { + var htlcs []common.LogUpdate + if err := bkt.ForEach(func(_, v []byte) error { + htlc, err := deserializeLogUpdate(bytes.NewReader(v)) + if err != nil { + return err + } + + htlcs = append(htlcs, *htlc) + + return nil + }); err != nil { + return nil, err + } + + return htlcs, nil +} + +// serializeLogUpdate writes a log update to the provided io.Writer. +func serializeLogUpdate(w io.Writer, l *common.LogUpdate) error { + return WriteElements(w, l.LogIndex, l.UpdateMsg) +} + +// deserializeLogUpdate reads a log update from the provided io.Reader. +func deserializeLogUpdate(r io.Reader) (*common.LogUpdate, error) { + l := &common.LogUpdate{} + if err := ReadElements(r, &l.LogIndex, &l.UpdateMsg); err != nil { + return nil, err + } + + return l, nil +} diff --git a/channeldb/migration21/legacy/legacy_codec.go b/channeldb/migration21/legacy/legacy_codec.go new file mode 100644 index 00000000..2bc6a94f --- /dev/null +++ b/channeldb/migration21/legacy/legacy_codec.go @@ -0,0 +1,359 @@ +package legacy + +import ( + "encoding/binary" + "fmt" + "io" + + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btcutil" + lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21" + "github.com/lightningnetwork/lnd/channeldb/migration21/common" + "github.com/lightningnetwork/lnd/keychain" +) + +var ( + // Big endian is the preferred byte order, due to cursor scans over + // integer keys iterating in order. + byteOrder = binary.BigEndian +) + +// writeOutpoint writes an outpoint to the passed writer using the minimal +// amount of bytes possible. +func writeOutpoint(w io.Writer, o *wire.OutPoint) error { + if _, err := w.Write(o.Hash[:]); err != nil { + return err + } + if err := binary.Write(w, byteOrder, o.Index); err != nil { + return err + } + + return nil +} + +// readOutpoint reads an outpoint from the passed reader that was previously +// written using the writeOutpoint struct. +func readOutpoint(r io.Reader, o *wire.OutPoint) error { + if _, err := io.ReadFull(r, o.Hash[:]); err != nil { + return err + } + if err := binary.Read(r, byteOrder, &o.Index); err != nil { + return err + } + + return nil +} + +// UnknownElementType is an error returned when the codec is unable to encode or +// decode a particular type. +type UnknownElementType struct { + method string + element interface{} +} + +// NewUnknownElementType creates a new UnknownElementType error from the passed +// method name and element. +func NewUnknownElementType(method string, el interface{}) UnknownElementType { + return UnknownElementType{method: method, element: el} +} + +// Error returns the name of the method that encountered the error, as well as +// the type that was unsupported. +func (e UnknownElementType) Error() string { + return fmt.Sprintf("Unknown type in %s: %T", e.method, e.element) +} + +// WriteElement is a one-stop shop to write the big endian representation of +// any element which is to be serialized for storage on disk. The passed +// io.Writer should be backed by an appropriately sized byte slice, or be able +// to dynamically expand to accommodate additional data. +func WriteElement(w io.Writer, element interface{}) error { + switch e := element.(type) { + case keychain.KeyDescriptor: + if err := binary.Write(w, byteOrder, e.Family); err != nil { + return err + } + if err := binary.Write(w, byteOrder, e.Index); err != nil { + return err + } + + if e.PubKey != nil { + if err := binary.Write(w, byteOrder, true); err != nil { + return fmt.Errorf("error writing serialized element: %s", err) + } + + return WriteElement(w, e.PubKey) + } + + return binary.Write(w, byteOrder, false) + + case chainhash.Hash: + if _, err := w.Write(e[:]); err != nil { + return err + } + + case common.ClosureType: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case wire.OutPoint: + return writeOutpoint(w, &e) + + case lnwire.ShortChannelID: + if err := binary.Write(w, byteOrder, e.ToUint64()); err != nil { + return err + } + + case lnwire.ChannelID: + if _, err := w.Write(e[:]); err != nil { + return err + } + + case int64, uint64: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case uint32: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case int32: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case uint16: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case uint8: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case bool: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case btcutil.Amount: + if err := binary.Write(w, byteOrder, uint64(e)); err != nil { + return err + } + + case lnwire.MilliSatoshi: + if err := binary.Write(w, byteOrder, uint64(e)); err != nil { + return err + } + + case *btcec.PublicKey: + b := e.SerializeCompressed() + if _, err := w.Write(b); err != nil { + return err + } + + case *wire.MsgTx: + return e.Serialize(w) + + case [32]byte: + if _, err := w.Write(e[:]); err != nil { + return err + } + + case []byte: + if err := wire.WriteVarBytes(w, 0, e); err != nil { + return err + } + + case lnwire.Message: + if _, err := lnwire.WriteMessage(w, e, 0); err != nil { + return err + } + + case lnwire.FundingFlag: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + default: + return UnknownElementType{"WriteElement", e} + } + + return nil +} + +// WriteElements is writes each element in the elements slice to the passed +// io.Writer using WriteElement. +func WriteElements(w io.Writer, elements ...interface{}) error { + for _, element := range elements { + err := WriteElement(w, element) + if err != nil { + return err + } + } + return nil +} + +// ReadElement is a one-stop utility function to deserialize any datastructure +// encoded using the serialization format of the database. +func ReadElement(r io.Reader, element interface{}) error { + switch e := element.(type) { + case *chainhash.Hash: + if _, err := io.ReadFull(r, e[:]); err != nil { + return err + } + + case *wire.OutPoint: + return readOutpoint(r, e) + + case *lnwire.ShortChannelID: + var a uint64 + if err := binary.Read(r, byteOrder, &a); err != nil { + return err + } + *e = lnwire.NewShortChanIDFromInt(a) + + case *lnwire.ChannelID: + if _, err := io.ReadFull(r, e[:]); err != nil { + return err + } + + case *int64, *uint64: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *uint32: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *int32: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *uint16: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *uint8: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *bool: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *btcutil.Amount: + var a uint64 + if err := binary.Read(r, byteOrder, &a); err != nil { + return err + } + + *e = btcutil.Amount(a) + + case *lnwire.MilliSatoshi: + var a uint64 + if err := binary.Read(r, byteOrder, &a); err != nil { + return err + } + + *e = lnwire.MilliSatoshi(a) + + case **btcec.PublicKey: + var b [btcec.PubKeyBytesLenCompressed]byte + if _, err := io.ReadFull(r, b[:]); err != nil { + return err + } + + pubKey, err := btcec.ParsePubKey(b[:], btcec.S256()) + if err != nil { + return err + } + *e = pubKey + + case **wire.MsgTx: + tx := wire.NewMsgTx(2) + if err := tx.Deserialize(r); err != nil { + return err + } + + *e = tx + + case *[32]byte: + if _, err := io.ReadFull(r, e[:]); err != nil { + return err + } + + case *[]byte: + bytes, err := wire.ReadVarBytes(r, 0, 66000, "[]byte") + if err != nil { + return err + } + + *e = bytes + + case *lnwire.Message: + msg, err := lnwire.ReadMessage(r, 0) + if err != nil { + return err + } + + *e = msg + + case *lnwire.FundingFlag: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *common.ClosureType: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *keychain.KeyDescriptor: + if err := binary.Read(r, byteOrder, &e.Family); err != nil { + return err + } + if err := binary.Read(r, byteOrder, &e.Index); err != nil { + return err + } + + var hasPubKey bool + if err := binary.Read(r, byteOrder, &hasPubKey); err != nil { + return err + } + + if hasPubKey { + return ReadElement(r, &e.PubKey) + } + + default: + return UnknownElementType{"ReadElement", e} + } + + return nil +} + +// ReadElements deserializes a variable number of elements into the passed +// io.Reader, with each element being deserialized according to the ReadElement +// function. +func ReadElements(r io.Reader, elements ...interface{}) error { + for _, element := range elements { + err := ReadElement(r, element) + if err != nil { + return err + } + } + return nil +} diff --git a/channeldb/migration21/legacy/legacy_decoding.go b/channeldb/migration21/legacy/legacy_decoding.go new file mode 100644 index 00000000..923cceda --- /dev/null +++ b/channeldb/migration21/legacy/legacy_decoding.go @@ -0,0 +1,737 @@ +package legacy + +import ( + "bytes" + "encoding/binary" + "errors" + "io" + + "github.com/lightningnetwork/lnd/channeldb/kvdb" + lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21" + "github.com/lightningnetwork/lnd/channeldb/migration21/common" +) + +func deserializeHtlcs(r io.Reader) ([]common.HTLC, error) { + var numHtlcs uint16 + if err := ReadElement(r, &numHtlcs); err != nil { + return nil, err + } + + var htlcs []common.HTLC + if numHtlcs == 0 { + return htlcs, nil + } + + htlcs = make([]common.HTLC, numHtlcs) + for i := uint16(0); i < numHtlcs; i++ { + if err := ReadElements(r, + &htlcs[i].Signature, &htlcs[i].RHash, &htlcs[i].Amt, + &htlcs[i].RefundTimeout, &htlcs[i].OutputIndex, + &htlcs[i].Incoming, &htlcs[i].OnionBlob, + &htlcs[i].HtlcIndex, &htlcs[i].LogIndex, + ); err != nil { + return htlcs, err + } + } + + return htlcs, nil +} + +func DeserializeLogUpdates(r io.Reader) ([]common.LogUpdate, error) { + var numUpdates uint16 + if err := binary.Read(r, byteOrder, &numUpdates); err != nil { + return nil, err + } + + logUpdates := make([]common.LogUpdate, numUpdates) + for i := 0; i < int(numUpdates); i++ { + err := ReadElements(r, + &logUpdates[i].LogIndex, &logUpdates[i].UpdateMsg, + ) + if err != nil { + return nil, err + } + } + + return logUpdates, nil +} + +func deserializeChanCommit(r io.Reader) (common.ChannelCommitment, error) { + var c common.ChannelCommitment + + err := ReadElements(r, + &c.CommitHeight, &c.LocalLogIndex, &c.LocalHtlcIndex, &c.RemoteLogIndex, + &c.RemoteHtlcIndex, &c.LocalBalance, &c.RemoteBalance, + &c.CommitFee, &c.FeePerKw, &c.CommitTx, &c.CommitSig, + ) + if err != nil { + return c, err + } + + c.Htlcs, err = deserializeHtlcs(r) + if err != nil { + return c, err + } + + return c, nil +} + +func DeserializeCommitDiff(r io.Reader) (*common.CommitDiff, error) { + var ( + d common.CommitDiff + err error + ) + + d.Commitment, err = deserializeChanCommit(r) + if err != nil { + return nil, err + } + + d.CommitSig = &lnwire.CommitSig{} + if err := d.CommitSig.Decode(r, 0); err != nil { + return nil, err + } + + d.LogUpdates, err = DeserializeLogUpdates(r) + if err != nil { + return nil, err + } + + var numOpenRefs uint16 + if err := binary.Read(r, byteOrder, &numOpenRefs); err != nil { + return nil, err + } + + d.OpenedCircuitKeys = make([]common.CircuitKey, numOpenRefs) + for i := 0; i < int(numOpenRefs); i++ { + err := ReadElements(r, + &d.OpenedCircuitKeys[i].ChanID, + &d.OpenedCircuitKeys[i].HtlcID) + if err != nil { + return nil, err + } + } + + var numClosedRefs uint16 + if err := binary.Read(r, byteOrder, &numClosedRefs); err != nil { + return nil, err + } + + d.ClosedCircuitKeys = make([]common.CircuitKey, numClosedRefs) + for i := 0; i < int(numClosedRefs); i++ { + err := ReadElements(r, + &d.ClosedCircuitKeys[i].ChanID, + &d.ClosedCircuitKeys[i].HtlcID) + if err != nil { + return nil, err + } + } + + return &d, nil +} + +func serializeHtlcs(b io.Writer, htlcs ...common.HTLC) error { + numHtlcs := uint16(len(htlcs)) + if err := WriteElement(b, numHtlcs); err != nil { + return err + } + + for _, htlc := range htlcs { + if err := WriteElements(b, + htlc.Signature, htlc.RHash, htlc.Amt, htlc.RefundTimeout, + htlc.OutputIndex, htlc.Incoming, htlc.OnionBlob, + htlc.HtlcIndex, htlc.LogIndex, + ); err != nil { + return err + } + } + + return nil +} + +func serializeChanCommit(w io.Writer, c *common.ChannelCommitment) error { + if err := WriteElements(w, + c.CommitHeight, c.LocalLogIndex, c.LocalHtlcIndex, + c.RemoteLogIndex, c.RemoteHtlcIndex, c.LocalBalance, + c.RemoteBalance, c.CommitFee, c.FeePerKw, c.CommitTx, + c.CommitSig, + ); err != nil { + return err + } + + return serializeHtlcs(w, c.Htlcs...) +} + +func SerializeLogUpdates(w io.Writer, logUpdates []common.LogUpdate) error { + numUpdates := uint16(len(logUpdates)) + if err := binary.Write(w, byteOrder, numUpdates); err != nil { + return err + } + + for _, diff := range logUpdates { + err := WriteElements(w, diff.LogIndex, diff.UpdateMsg) + if err != nil { + return err + } + } + + return nil +} + +func SerializeCommitDiff(w io.Writer, diff *common.CommitDiff) error { // nolint: dupl + if err := serializeChanCommit(w, &diff.Commitment); err != nil { + return err + } + + if err := diff.CommitSig.Encode(w, 0); err != nil { + return err + } + + if err := SerializeLogUpdates(w, diff.LogUpdates); err != nil { + return err + } + + numOpenRefs := uint16(len(diff.OpenedCircuitKeys)) + if err := binary.Write(w, byteOrder, numOpenRefs); err != nil { + return err + } + + for _, openRef := range diff.OpenedCircuitKeys { + err := WriteElements(w, openRef.ChanID, openRef.HtlcID) + if err != nil { + return err + } + } + + numClosedRefs := uint16(len(diff.ClosedCircuitKeys)) + if err := binary.Write(w, byteOrder, numClosedRefs); err != nil { + return err + } + + for _, closedRef := range diff.ClosedCircuitKeys { + err := WriteElements(w, closedRef.ChanID, closedRef.HtlcID) + if err != nil { + return err + } + } + + return nil +} + +func DeserializeNetworkResult(r io.Reader) (*common.NetworkResult, error) { + var ( + err error + ) + + n := &common.NetworkResult{} + + n.Msg, err = lnwire.ReadMessage(r, 0) + if err != nil { + return nil, err + } + + if err := ReadElements(r, + &n.Unencrypted, &n.IsResolution, + ); err != nil { + return nil, err + } + + return n, nil +} + +func SerializeNetworkResult(w io.Writer, n *common.NetworkResult) error { + if _, err := lnwire.WriteMessage(w, n.Msg, 0); err != nil { + return err + } + + return WriteElements(w, n.Unencrypted, n.IsResolution) +} + +func readChanConfig(b io.Reader, c *common.ChannelConfig) error { // nolint: dupl + return ReadElements(b, + &c.DustLimit, &c.MaxPendingAmount, &c.ChanReserve, + &c.MinHTLC, &c.MaxAcceptedHtlcs, &c.CsvDelay, + &c.MultiSigKey, &c.RevocationBasePoint, + &c.PaymentBasePoint, &c.DelayBasePoint, + &c.HtlcBasePoint, + ) +} + +func DeserializeCloseChannelSummary(r io.Reader) (*common.ChannelCloseSummary, error) { // nolint: dupl + + c := &common.ChannelCloseSummary{} + + err := ReadElements(r, + &c.ChanPoint, &c.ShortChanID, &c.ChainHash, &c.ClosingTXID, + &c.CloseHeight, &c.RemotePub, &c.Capacity, &c.SettledBalance, + &c.TimeLockedBalance, &c.CloseType, &c.IsPending, + ) + if err != nil { + return nil, err + } + + // We'll now check to see if the channel close summary was encoded with + // any of the additional optional fields. + var hasNewFields bool + err = ReadElements(r, &hasNewFields) + if err != nil { + return nil, err + } + + // If fields are not present, we can return. + if !hasNewFields { + return c, nil + } + + // Otherwise read the new fields. + if err := ReadElements(r, &c.RemoteCurrentRevocation); err != nil { + return nil, err + } + + if err := readChanConfig(r, &c.LocalChanConfig); err != nil { + return nil, err + } + + // Finally, we'll attempt to read the next unrevoked commitment point + // for the remote party. If we closed the channel before receiving a + // funding locked message then this might not be present. A boolean + // indicating whether the field is present will come first. + var hasRemoteNextRevocation bool + err = ReadElements(r, &hasRemoteNextRevocation) + if err != nil { + return nil, err + } + + // If this field was written, read it. + if hasRemoteNextRevocation { + err = ReadElements(r, &c.RemoteNextRevocation) + if err != nil { + return nil, err + } + } + + // Check if we have a channel sync message to read. + var hasChanSyncMsg bool + err = ReadElements(r, &hasChanSyncMsg) + if err == io.EOF { + return c, nil + } else if err != nil { + return nil, err + } + + // If a chan sync message is present, read it. + if hasChanSyncMsg { + // We must pass in reference to a lnwire.Message for the codec + // to support it. + msg, err := lnwire.ReadMessage(r, 0) + if err != nil { + return nil, err + } + + chanSync, ok := msg.(*lnwire.ChannelReestablish) + if !ok { + return nil, errors.New("unable cast db Message to " + + "ChannelReestablish") + } + c.LastChanSyncMsg = chanSync + } + + return c, nil +} + +func writeChanConfig(b io.Writer, c *common.ChannelConfig) error { // nolint: dupl + return WriteElements(b, + c.DustLimit, c.MaxPendingAmount, c.ChanReserve, c.MinHTLC, + c.MaxAcceptedHtlcs, c.CsvDelay, c.MultiSigKey, + c.RevocationBasePoint, c.PaymentBasePoint, c.DelayBasePoint, + c.HtlcBasePoint, + ) +} + +func SerializeChannelCloseSummary(w io.Writer, cs *common.ChannelCloseSummary) error { + err := WriteElements(w, + cs.ChanPoint, cs.ShortChanID, cs.ChainHash, cs.ClosingTXID, + cs.CloseHeight, cs.RemotePub, cs.Capacity, cs.SettledBalance, + cs.TimeLockedBalance, cs.CloseType, cs.IsPending, + ) + if err != nil { + return err + } + + // If this is a close channel summary created before the addition of + // the new fields, then we can exit here. + if cs.RemoteCurrentRevocation == nil { + return WriteElements(w, false) + } + + // If fields are present, write boolean to indicate this, and continue. + if err := WriteElements(w, true); err != nil { + return err + } + + if err := WriteElements(w, cs.RemoteCurrentRevocation); err != nil { + return err + } + + if err := writeChanConfig(w, &cs.LocalChanConfig); err != nil { + return err + } + + // The RemoteNextRevocation field is optional, as it's possible for a + // channel to be closed before we learn of the next unrevoked + // revocation point for the remote party. Write a boolen indicating + // whether this field is present or not. + if err := WriteElements(w, cs.RemoteNextRevocation != nil); err != nil { + return err + } + + // Write the field, if present. + if cs.RemoteNextRevocation != nil { + if err = WriteElements(w, cs.RemoteNextRevocation); err != nil { + return err + } + } + + // Write whether the channel sync message is present. + if err := WriteElements(w, cs.LastChanSyncMsg != nil); err != nil { + return err + } + + // Write the channel sync message, if present. + if cs.LastChanSyncMsg != nil { + _, err = lnwire.WriteMessage(w, cs.LastChanSyncMsg, 0) + if err != nil { + return err + } + } + + return nil +} + +// ErrCorruptedFwdPkg signals that the on-disk structure of the forwarding +// package has potentially been mangled. +var ErrCorruptedFwdPkg = errors.New("fwding package db has been corrupted") + +var ( + // fwdPackagesKey is the root-level bucket that all forwarding packages + // are written. This bucket is further subdivided based on the short + // channel ID of each channel. + fwdPackagesKey = []byte("fwd-packages") + + // addBucketKey is the bucket to which all Add log updates are written. + addBucketKey = []byte("add-updates") + + // failSettleBucketKey is the bucket to which all Settle/Fail log + // updates are written. + failSettleBucketKey = []byte("fail-settle-updates") + + // fwdFilterKey is a key used to write the set of Adds that passed + // validation and are to be forwarded to the switch. + // NOTE: The presence of this key within a forwarding package indicates + // that the package has reached FwdStateProcessed. + fwdFilterKey = []byte("fwd-filter-key") + + // ackFilterKey is a key used to access the PkgFilter indicating which + // Adds have received a Settle/Fail. This response may come from a + // number of sources, including: exitHop settle/fails, switch failures, + // chain arbiter interjections, as well as settle/fails from the + // next hop in the route. + ackFilterKey = []byte("ack-filter-key") + + // settleFailFilterKey is a key used to access the PkgFilter indicating + // which Settles/Fails in have been received and processed by the link + // that originally received the Add. + settleFailFilterKey = []byte("settle-fail-filter-key") +) + +func makeLogKey(updateNum uint64) [8]byte { + var key [8]byte + byteOrder.PutUint64(key[:], updateNum) + return key +} + +// uint16Key writes the provided 16-bit unsigned integer to a 2-byte slice. +func uint16Key(i uint16) []byte { + key := make([]byte, 2) + byteOrder.PutUint16(key, i) + return key +} + +// ChannelPackager is used by a channel to manage the lifecycle of its forwarding +// packages. The packager is tied to a particular source channel ID, allowing it +// to create and edit its own packages. Each packager also has the ability to +// remove fail/settle htlcs that correspond to an add contained in one of +// source's packages. +type ChannelPackager struct { + source lnwire.ShortChannelID +} + +// NewChannelPackager creates a new packager for a single channel. +func NewChannelPackager(source lnwire.ShortChannelID) *ChannelPackager { + return &ChannelPackager{ + source: source, + } +} + +// AddFwdPkg writes a newly locked in forwarding package to disk. +func (*ChannelPackager) AddFwdPkg(tx kvdb.RwTx, fwdPkg *common.FwdPkg) error { // nolint: dupl + fwdPkgBkt, err := tx.CreateTopLevelBucket(fwdPackagesKey) + if err != nil { + return err + } + + source := makeLogKey(fwdPkg.Source.ToUint64()) + sourceBkt, err := fwdPkgBkt.CreateBucketIfNotExists(source[:]) + if err != nil { + return err + } + + heightKey := makeLogKey(fwdPkg.Height) + heightBkt, err := sourceBkt.CreateBucketIfNotExists(heightKey[:]) + if err != nil { + return err + } + + // Write ADD updates we received at this commit height. + addBkt, err := heightBkt.CreateBucketIfNotExists(addBucketKey) + if err != nil { + return err + } + + // Write SETTLE/FAIL updates we received at this commit height. + failSettleBkt, err := heightBkt.CreateBucketIfNotExists(failSettleBucketKey) + if err != nil { + return err + } + + for i := range fwdPkg.Adds { + err = putLogUpdate(addBkt, uint16(i), &fwdPkg.Adds[i]) + if err != nil { + return err + } + } + + // Persist the initialized pkg filter, which will be used to determine + // when we can remove this forwarding package from disk. + var ackFilterBuf bytes.Buffer + if err := fwdPkg.AckFilter.Encode(&ackFilterBuf); err != nil { + return err + } + + if err := heightBkt.Put(ackFilterKey, ackFilterBuf.Bytes()); err != nil { + return err + } + + for i := range fwdPkg.SettleFails { + err = putLogUpdate(failSettleBkt, uint16(i), &fwdPkg.SettleFails[i]) + if err != nil { + return err + } + } + + var settleFailFilterBuf bytes.Buffer + err = fwdPkg.SettleFailFilter.Encode(&settleFailFilterBuf) + if err != nil { + return err + } + + return heightBkt.Put(settleFailFilterKey, settleFailFilterBuf.Bytes()) +} + +// putLogUpdate writes an htlc to the provided `bkt`, using `index` as the key. +func putLogUpdate(bkt kvdb.RwBucket, idx uint16, htlc *common.LogUpdate) error { + var b bytes.Buffer + if err := serializeLogUpdate(&b, htlc); err != nil { + return err + } + + return bkt.Put(uint16Key(idx), b.Bytes()) +} + +// LoadFwdPkgs scans the forwarding log for any packages that haven't been +// processed, and returns their deserialized log updates in a map indexed by the +// remote commitment height at which the updates were locked in. +func (p *ChannelPackager) LoadFwdPkgs(tx kvdb.RTx) ([]*common.FwdPkg, error) { + return loadChannelFwdPkgs(tx, p.source) +} + +// loadChannelFwdPkgs loads all forwarding packages owned by `source`. +func loadChannelFwdPkgs(tx kvdb.RTx, source lnwire.ShortChannelID) ([]*common.FwdPkg, error) { // nolint: dupl + fwdPkgBkt := tx.ReadBucket(fwdPackagesKey) + if fwdPkgBkt == nil { + return nil, nil + } + + sourceKey := makeLogKey(source.ToUint64()) + sourceBkt := fwdPkgBkt.NestedReadBucket(sourceKey[:]) + if sourceBkt == nil { + return nil, nil + } + + var heights []uint64 + if err := sourceBkt.ForEach(func(k, _ []byte) error { + if len(k) != 8 { + return ErrCorruptedFwdPkg + } + + heights = append(heights, byteOrder.Uint64(k)) + + return nil + }); err != nil { + return nil, err + } + + // Load the forwarding package for each retrieved height. + fwdPkgs := make([]*common.FwdPkg, 0, len(heights)) + for _, height := range heights { + fwdPkg, err := loadFwdPkg(fwdPkgBkt, source, height) + if err != nil { + return nil, err + } + + fwdPkgs = append(fwdPkgs, fwdPkg) + } + + return fwdPkgs, nil +} + +// loadFwdPkg reads the packager's fwd pkg at a given height, and determines the +// appropriate FwdState. +func loadFwdPkg(fwdPkgBkt kvdb.RBucket, source lnwire.ShortChannelID, + height uint64) (*common.FwdPkg, error) { + + sourceKey := makeLogKey(source.ToUint64()) + sourceBkt := fwdPkgBkt.NestedReadBucket(sourceKey[:]) + if sourceBkt == nil { + return nil, ErrCorruptedFwdPkg + } + + heightKey := makeLogKey(height) + heightBkt := sourceBkt.NestedReadBucket(heightKey[:]) + if heightBkt == nil { + return nil, ErrCorruptedFwdPkg + } + + // Load ADDs from disk. + addBkt := heightBkt.NestedReadBucket(addBucketKey) + if addBkt == nil { + return nil, ErrCorruptedFwdPkg + } + + adds, err := loadHtlcs(addBkt) + if err != nil { + return nil, err + } + + // Load ack filter from disk. + ackFilterBytes := heightBkt.Get(ackFilterKey) + if ackFilterBytes == nil { + return nil, ErrCorruptedFwdPkg + } + ackFilterReader := bytes.NewReader(ackFilterBytes) + + ackFilter := &common.PkgFilter{} + if err := ackFilter.Decode(ackFilterReader); err != nil { + return nil, err + } + + // Load SETTLE/FAILs from disk. + failSettleBkt := heightBkt.NestedReadBucket(failSettleBucketKey) + if failSettleBkt == nil { + return nil, ErrCorruptedFwdPkg + } + + failSettles, err := loadHtlcs(failSettleBkt) + if err != nil { + return nil, err + } + + // Load settle fail filter from disk. + settleFailFilterBytes := heightBkt.Get(settleFailFilterKey) + if settleFailFilterBytes == nil { + return nil, ErrCorruptedFwdPkg + } + settleFailFilterReader := bytes.NewReader(settleFailFilterBytes) + + settleFailFilter := &common.PkgFilter{} + if err := settleFailFilter.Decode(settleFailFilterReader); err != nil { + return nil, err + } + + // Initialize the fwding package, which always starts in the + // FwdStateLockedIn. We can determine what state the package was left in + // by examining constraints on the information loaded from disk. + fwdPkg := &common.FwdPkg{ + Source: source, + State: common.FwdStateLockedIn, + Height: height, + Adds: adds, + AckFilter: ackFilter, + SettleFails: failSettles, + SettleFailFilter: settleFailFilter, + } + + // Check to see if we have written the set exported filter adds to + // disk. If we haven't, processing of this package was never started, or + // failed during the last attempt. + fwdFilterBytes := heightBkt.Get(fwdFilterKey) + if fwdFilterBytes == nil { + nAdds := uint16(len(adds)) + fwdPkg.FwdFilter = common.NewPkgFilter(nAdds) + return fwdPkg, nil + } + + fwdFilterReader := bytes.NewReader(fwdFilterBytes) + fwdPkg.FwdFilter = &common.PkgFilter{} + if err := fwdPkg.FwdFilter.Decode(fwdFilterReader); err != nil { + return nil, err + } + + // Otherwise, a complete round of processing was completed, and we + // advance the package to FwdStateProcessed. + fwdPkg.State = common.FwdStateProcessed + + // If every add, settle, and fail has been fully acknowledged, we can + // safely set the package's state to FwdStateCompleted, signalling that + // it can be garbage collected. + if fwdPkg.AckFilter.IsFull() && fwdPkg.SettleFailFilter.IsFull() { + fwdPkg.State = common.FwdStateCompleted + } + + return fwdPkg, nil +} + +// loadHtlcs retrieves all serialized htlcs in a bucket, returning +// them in order of the indexes they were written under. +func loadHtlcs(bkt kvdb.RBucket) ([]common.LogUpdate, error) { + var htlcs []common.LogUpdate + if err := bkt.ForEach(func(_, v []byte) error { + htlc, err := deserializeLogUpdate(bytes.NewReader(v)) + if err != nil { + return err + } + + htlcs = append(htlcs, *htlc) + + return nil + }); err != nil { + return nil, err + } + + return htlcs, nil +} + +// serializeLogUpdate writes a log update to the provided io.Writer. +func serializeLogUpdate(w io.Writer, l *common.LogUpdate) error { + return WriteElements(w, l.LogIndex, l.UpdateMsg) +} + +// deserializeLogUpdate reads a log update from the provided io.Reader. +func deserializeLogUpdate(r io.Reader) (*common.LogUpdate, error) { + l := &common.LogUpdate{} + if err := ReadElements(r, &l.LogIndex, &l.UpdateMsg); err != nil { + return nil, err + } + + return l, nil +} diff --git a/channeldb/migration21/migration.go b/channeldb/migration21/migration.go new file mode 100644 index 00000000..2df7b4f8 --- /dev/null +++ b/channeldb/migration21/migration.go @@ -0,0 +1,387 @@ +package migration21 + +import ( + "bytes" + "encoding/binary" + "fmt" + + "github.com/lightningnetwork/lnd/channeldb/kvdb" + lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21" + "github.com/lightningnetwork/lnd/channeldb/migration21/common" + "github.com/lightningnetwork/lnd/channeldb/migration21/current" + "github.com/lightningnetwork/lnd/channeldb/migration21/legacy" +) + +var ( + byteOrder = binary.BigEndian + + // openChanBucket stores all the currently open channels. This bucket + // has a second, nested bucket which is keyed by a node's ID. Within + // that node ID bucket, all attributes required to track, update, and + // close a channel are stored. + // + // openChan -> nodeID -> chanPoint + // + // TODO(roasbeef): flesh out comment + openChannelBucket = []byte("open-chan-bucket") + + // commitDiffKey stores the current pending commitment state we've + // extended to the remote party (if any). Each time we propose a new + // state, we store the information necessary to reconstruct this state + // from the prior commitment. This allows us to resync the remote party + // to their expected state in the case of message loss. + // + // TODO(roasbeef): rename to commit chain? + commitDiffKey = []byte("commit-diff-key") + + // unsignedAckedUpdatesKey is an entry in the channel bucket that + // contains the remote updates that we have acked, but not yet signed + // for in one of our remote commits. + unsignedAckedUpdatesKey = []byte("unsigned-acked-updates-key") + + // remoteUnsignedLocalUpdatesKey is an entry in the channel bucket that + // contains the local updates that the remote party has acked, but + // has not yet signed for in one of their local commits. + remoteUnsignedLocalUpdatesKey = []byte("remote-unsigned-local-updates-key") + + // networkResultStoreBucketKey is used for the root level bucket that + // stores the network result for each payment ID. + networkResultStoreBucketKey = []byte("network-result-store-bucket") + + // closedChannelBucket stores summarization information concerning + // previously open, but now closed channels. + closedChannelBucket = []byte("closed-chan-bucket") + + // fwdPackagesKey is the root-level bucket that all forwarding packages + // are written. This bucket is further subdivided based on the short + // channel ID of each channel. + fwdPackagesKey = []byte("fwd-packages") +) + +// MigrateDatabaseWireMessages performs a migration in all areas that we +// currently store wire messages without length prefixes. This includes the +// CommitDiff struct, ChannelCloseSummary, LogUpdates, and also the +// networkResult struct as well. +func MigrateDatabaseWireMessages(tx kvdb.RwTx) error { + // The migration will proceed in three phases: we'll need to update any + // pending commit diffs, then any unsigned acked updates for all open + // channels, then finally we'll need to update all the current + // stored network results for payments in the switch. + // + // In this phase, we'll migrate the open channel data. + if err := migrateOpenChanBucket(tx); err != nil { + return err + } + + // Next, we'll update all the present close channel summaries as well. + if err := migrateCloseChanSummaries(tx); err != nil { + return err + } + + // We'll migrate forwarding packages, which have log updates as part of + // their serialized data. + if err := migrateForwardingPackages(tx); err != nil { + return err + } + + // Finally, we'll update the pending network results as well. + return migrateNetworkResults(tx) +} + +func migrateOpenChanBucket(tx kvdb.RwTx) error { + openChanBucket := tx.ReadWriteBucket(openChannelBucket) + + // If no bucket is found, we can exit early. + if openChanBucket == nil { + return nil + } + + type channelPath struct { + nodePub []byte + chainHash []byte + chanPoint []byte + } + var channelPaths []channelPath + err := openChanBucket.ForEach(func(nodePub, v []byte) error { + // Ensure that this is a key the same size as a pubkey, and + // also that it leads directly to a bucket. + if len(nodePub) != 33 || v != nil { + return nil + } + + nodeChanBucket := openChanBucket.NestedReadBucket(nodePub) + if nodeChanBucket == nil { + return fmt.Errorf("no bucket for node %x", nodePub) + } + + // The next layer down is all the chains that this node + // has channels on with us. + return nodeChanBucket.ForEach(func(chainHash, v []byte) error { + // If there's a value, it's not a bucket so + // ignore it. + if v != nil { + return nil + } + + chainBucket := nodeChanBucket.NestedReadBucket( + chainHash, + ) + if chainBucket == nil { + return fmt.Errorf("unable to read "+ + "bucket for chain=%x", chainHash) + } + + return chainBucket.ForEach(func(chanPoint, v []byte) error { + // If there's a value, it's not a bucket so + // ignore it. + if v != nil { + return nil + } + + channelPaths = append(channelPaths, channelPath{ + nodePub: nodePub, + chainHash: chainHash, + chanPoint: chanPoint, + }) + + return nil + }) + }) + }) + if err != nil { + return err + } + + // Now that we have all the paths of the channel we need to migrate, + // we'll update all the state in a distinct step to avoid weird + // behavior from modifying buckets in a ForEach statement. + for _, channelPath := range channelPaths { + // First, we'll extract it from the node's chain bucket. + nodeChanBucket := openChanBucket.NestedReadWriteBucket( + channelPath.nodePub, + ) + chainBucket := nodeChanBucket.NestedReadWriteBucket( + channelPath.chainHash, + ) + chanBucket := chainBucket.NestedReadWriteBucket( + channelPath.chanPoint, + ) + + // At this point, we have the channel bucket now, so we'll + // check to see if this channel has a pending commitment or + // not. + commitDiffBytes := chanBucket.Get(commitDiffKey) + if commitDiffBytes != nil { + // Now that we have the commit diff in the _old_ + // encoding, we'll write it back to disk using the new + // encoding which has a length prefix in front of the + // CommitSig. + commitDiff, err := legacy.DeserializeCommitDiff( + bytes.NewReader(commitDiffBytes), + ) + if err != nil { + return err + } + + var b bytes.Buffer + err = current.SerializeCommitDiff(&b, commitDiff) + if err != nil { + return err + } + + err = chanBucket.Put(commitDiffKey, b.Bytes()) + if err != nil { + return err + } + } + + // With the commit diff migrated, we'll now check to see if + // there're any un-acked updates we need to migrate as well. + updateBytes := chanBucket.Get(unsignedAckedUpdatesKey) + if updateBytes != nil { + // We have un-acked updates we need to migrate so we'll + // decode then re-encode them here using the new + // format. + legacyUnackedUpdates, err := legacy.DeserializeLogUpdates( + bytes.NewReader(updateBytes), + ) + if err != nil { + return err + } + + var b bytes.Buffer + err = current.SerializeLogUpdates(&b, legacyUnackedUpdates) + if err != nil { + return err + } + + err = chanBucket.Put(unsignedAckedUpdatesKey, b.Bytes()) + if err != nil { + return err + } + } + + // Remote unsiged updates as well. + updateBytes = chanBucket.Get(remoteUnsignedLocalUpdatesKey) + if updateBytes != nil { + legacyUnsignedUpdates, err := legacy.DeserializeLogUpdates( + bytes.NewReader(updateBytes), + ) + if err != nil { + return err + } + + var b bytes.Buffer + err = current.SerializeLogUpdates(&b, legacyUnsignedUpdates) + if err != nil { + return err + } + + err = chanBucket.Put(remoteUnsignedLocalUpdatesKey, b.Bytes()) + if err != nil { + return err + } + } + } + + return nil +} + +func migrateCloseChanSummaries(tx kvdb.RwTx) error { + closedChanBucket := tx.ReadWriteBucket(closedChannelBucket) + + // Exit early if bucket is not found. + if closedChannelBucket == nil { + return nil + } + + type closedChan struct { + chanKey []byte + summaryBytes []byte + } + var closedChans []closedChan + err := closedChanBucket.ForEach(func(k, v []byte) error { + closedChans = append(closedChans, closedChan{ + chanKey: k, + summaryBytes: v, + }) + return nil + }) + if err != nil { + return err + } + + for _, closedChan := range closedChans { + oldSummary, err := legacy.DeserializeCloseChannelSummary( + bytes.NewReader(closedChan.summaryBytes), + ) + if err != nil { + return err + } + + var newSummaryBytes bytes.Buffer + err = current.SerializeChannelCloseSummary( + &newSummaryBytes, oldSummary, + ) + if err != nil { + return err + } + + err = closedChanBucket.Put( + closedChan.chanKey, newSummaryBytes.Bytes(), + ) + if err != nil { + return err + } + } + return nil +} + +func migrateForwardingPackages(tx kvdb.RwTx) error { + fwdPkgBkt := tx.ReadWriteBucket(fwdPackagesKey) + + // Exit early if bucket is not found. + if fwdPkgBkt == nil { + return nil + } + + // We Go through the bucket and fetches all short channel IDs. + var sources []lnwire.ShortChannelID + err := fwdPkgBkt.ForEach(func(k, v []byte) error { + source := lnwire.NewShortChanIDFromInt(byteOrder.Uint64(k)) + sources = append(sources, source) + return nil + }) + if err != nil { + return err + } + + // Now load all forwading packages using the legacy encoding. + var pkgsToMigrate []*common.FwdPkg + for _, source := range sources { + packager := legacy.NewChannelPackager(source) + fwdPkgs, err := packager.LoadFwdPkgs(tx) + if err != nil { + return err + } + + pkgsToMigrate = append(pkgsToMigrate, fwdPkgs...) + } + + // Add back the packages using the current encoding. + for _, pkg := range pkgsToMigrate { + packager := current.NewChannelPackager(pkg.Source) + err := packager.AddFwdPkg(tx, pkg) + if err != nil { + return err + } + } + + return nil +} + +func migrateNetworkResults(tx kvdb.RwTx) error { + networkResults := tx.ReadWriteBucket(networkResultStoreBucketKey) + + // Exit early if bucket is not found. + if networkResults == nil { + return nil + } + + // Similar to the prior migrations, we'll do this one in two phases: + // we'll first grab all the keys we need to migrate in one loop, then + // update them all in another loop. + var netResultsToMigrate [][2][]byte + err := networkResults.ForEach(func(k, v []byte) error { + netResultsToMigrate = append(netResultsToMigrate, [2][]byte{ + k, v, + }) + return nil + }) + if err != nil { + return err + } + + for _, netResult := range netResultsToMigrate { + resKey := netResult[0] + resBytes := netResult[1] + oldResult, err := legacy.DeserializeNetworkResult( + bytes.NewReader(resBytes), + ) + if err != nil { + return err + } + + var newResultBuf bytes.Buffer + err = current.SerializeNetworkResult(&newResultBuf, oldResult) + if err != nil { + return err + } + + err = networkResults.Put(resKey, newResultBuf.Bytes()) + if err != nil { + return err + } + } + return nil +} diff --git a/channeldb/migration21/migration_test.go b/channeldb/migration21/migration_test.go new file mode 100644 index 00000000..f8a159b2 --- /dev/null +++ b/channeldb/migration21/migration_test.go @@ -0,0 +1,469 @@ +package migration21 + +import ( + "bytes" + "fmt" + "math/big" + "reflect" + "testing" + + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/channeldb/kvdb" + lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21" + "github.com/lightningnetwork/lnd/channeldb/migration21/common" + "github.com/lightningnetwork/lnd/channeldb/migration21/current" + "github.com/lightningnetwork/lnd/channeldb/migration21/legacy" + "github.com/lightningnetwork/lnd/channeldb/migtest" +) + +var ( + key = [chainhash.HashSize]byte{ + 0x81, 0xb6, 0x37, 0xd8, 0xfc, 0xd2, 0xc6, 0xda, + 0x68, 0x59, 0xe6, 0x96, 0x31, 0x13, 0xa1, 0x17, + 0xd, 0xe7, 0x93, 0xe4, 0xb7, 0x25, 0xb8, 0x4d, + 0x1e, 0xb, 0x4c, 0xf9, 0x9e, 0xc5, 0x8c, 0xe9, + } + + _, pubKey = btcec.PrivKeyFromBytes(btcec.S256(), key[:]) + + wireSig, _ = lnwire.NewSigFromSignature(testSig) + + testSig = &btcec.Signature{ + R: new(big.Int), + S: new(big.Int), + } + _, _ = testSig.R.SetString("63724406601629180062774974542967536251589935445068131219452686511677818569431", 10) + _, _ = testSig.S.SetString("18801056069249825825291287104931333862866033135609736119018462340006816851118", 10) + + testTx = &wire.MsgTx{ + Version: 1, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: wire.OutPoint{ + Hash: chainhash.Hash{}, + Index: 0xffffffff, + }, + SignatureScript: []byte{0x04, 0x31, 0xdc, 0x00, 0x1b, 0x01, 0x62}, + Sequence: 0xffffffff, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 5000000000, + PkScript: []byte{ + 0x41, // OP_DATA_65 + 0x04, 0xd6, 0x4b, 0xdf, 0xd0, 0x9e, 0xb1, 0xc5, + 0xfe, 0x29, 0x5a, 0xbd, 0xeb, 0x1d, 0xca, 0x42, + 0x81, 0xbe, 0x98, 0x8e, 0x2d, 0xa0, 0xb6, 0xc1, + 0xc6, 0xa5, 0x9d, 0xc2, 0x26, 0xc2, 0x86, 0x24, + 0xe1, 0x81, 0x75, 0xe8, 0x51, 0xc9, 0x6b, 0x97, + 0x3d, 0x81, 0xb0, 0x1c, 0xc3, 0x1f, 0x04, 0x78, + 0x34, 0xbc, 0x06, 0xd6, 0xd6, 0xed, 0xf6, 0x20, + 0xd1, 0x84, 0x24, 0x1a, 0x6a, 0xed, 0x8b, 0x63, + 0xa6, // 65-byte signature + 0xac, // OP_CHECKSIG + }, + }, + }, + LockTime: 5, + } + + testCommitDiff = &common.CommitDiff{ + Commitment: common.ChannelCommitment{ + CommitTx: testTx, + CommitSig: make([]byte, 0), + }, + CommitSig: &lnwire.CommitSig{ + ChanID: lnwire.ChannelID(key), + CommitSig: wireSig, + HtlcSigs: []lnwire.Sig{ + wireSig, + wireSig, + }, + }, + LogUpdates: []common.LogUpdate{ + { + LogIndex: 1, + UpdateMsg: &lnwire.UpdateAddHTLC{ + ID: 1, + Amount: lnwire.NewMSatFromSatoshis(100), + Expiry: 25, + }, + }, + { + LogIndex: 2, + UpdateMsg: &lnwire.UpdateAddHTLC{ + ID: 2, + Amount: lnwire.NewMSatFromSatoshis(200), + Expiry: 50, + }, + }, + }, + OpenedCircuitKeys: []common.CircuitKey{}, + ClosedCircuitKeys: []common.CircuitKey{}, + } + + testNetworkResult = &common.NetworkResult{ + Msg: testCommitDiff.CommitSig, + Unencrypted: true, + IsResolution: true, + } + + testChanCloseSummary = &common.ChannelCloseSummary{ + RemotePub: pubKey, + Capacity: 9, + RemoteCurrentRevocation: pubKey, + RemoteNextRevocation: pubKey, + LastChanSyncMsg: &lnwire.ChannelReestablish{ + LocalUnrevokedCommitPoint: pubKey, + }, + } + + netResultKey = []byte{3} + + chanID = lnwire.NewChanIDFromOutPoint(&wire.OutPoint{}) + + adds = []common.LogUpdate{ + { + LogIndex: 0, + UpdateMsg: &lnwire.UpdateAddHTLC{ + ChanID: chanID, + ID: 1, + Amount: 100, + Expiry: 1000, + PaymentHash: [32]byte{0}, + }, + }, + { + LogIndex: 1, + UpdateMsg: &lnwire.UpdateAddHTLC{ + ChanID: chanID, + ID: 1, + Amount: 101, + Expiry: 1001, + PaymentHash: [32]byte{1}, + }, + }, + } + + settleFails = []common.LogUpdate{ + { + LogIndex: 2, + UpdateMsg: &lnwire.UpdateFulfillHTLC{ + ChanID: chanID, + ID: 0, + PaymentPreimage: [32]byte{0}, + }, + }, + { + LogIndex: 3, + UpdateMsg: &lnwire.UpdateFailHTLC{ + ChanID: chanID, + ID: 1, + Reason: []byte{}, + }, + }, + } +) + +// TestMigrateDatabaseWireMessages tests that we're able to properly migrate +// all the wire messages in the database which are written without a length +// prefix in front of them. At the time this test was written we need to +// migrate three areas: open channel commit diffs, open channel unacked updates, +// and network results in the switch. +func TestMigrateDatabaseWireMessages(t *testing.T) { + + var pub [33]byte + copy(pub[:], key[:]) + + migtest.ApplyMigration( + t, + func(tx kvdb.RwTx) error { + t.Helper() + + // First, we'll insert a new fake channel (well just + // the commitment diff) at the expected location + // on-disk. + openChanBucket, err := tx.CreateTopLevelBucket( + openChannelBucket, + ) + if err != nil { + return err + } + nodeBucket, err := openChanBucket.CreateBucket(pub[:]) + if err != nil { + return err + } + chainBucket, err := nodeBucket.CreateBucket(key[:]) + if err != nil { + return err + } + chanBucket, err := chainBucket.CreateBucket(key[:]) + if err != nil { + return err + } + + var b bytes.Buffer + err = legacy.SerializeCommitDiff(&b, testCommitDiff) + if err != nil { + return err + } + + err = chanBucket.Put(commitDiffKey, b.Bytes()) + if err != nil { + return err + } + + var logUpdateBuf bytes.Buffer + err = legacy.SerializeLogUpdates( + &logUpdateBuf, testCommitDiff.LogUpdates, + ) + if err != nil { + return err + } + + // We'll re-use the same log updates to insert as a set + // of un-acked and unsigned pending log updateas as well. + err = chanBucket.Put( + unsignedAckedUpdatesKey, logUpdateBuf.Bytes(), + ) + if err != nil { + return err + } + + err = chanBucket.Put( + remoteUnsignedLocalUpdatesKey, logUpdateBuf.Bytes(), + ) + if err != nil { + return err + } + + // Next, we'll insert a sample closed channel summary + // for the 2nd part of our migration. + closedChanBucket, err := tx.CreateTopLevelBucket( + closedChannelBucket, + ) + if err != nil { + return err + } + + var summaryBuf bytes.Buffer + err = legacy.SerializeChannelCloseSummary( + &summaryBuf, testChanCloseSummary, + ) + if err != nil { + return err + } + + err = closedChanBucket.Put(key[:], summaryBuf.Bytes()) + if err != nil { + return err + } + + // Create a few forwarding packages to migrate. + for i := uint64(100); i < 200; i++ { + shortChanID := lnwire.NewShortChanIDFromInt(i) + packager := legacy.NewChannelPackager(shortChanID) + fwdPkg := common.NewFwdPkg(shortChanID, 0, adds, settleFails) + + if err := packager.AddFwdPkg(tx, fwdPkg); err != nil { + return err + } + } + + // Finally, we need to insert a sample network result + // as well for the final component of our migration. + var netResBuf bytes.Buffer + err = legacy.SerializeNetworkResult( + &netResBuf, testNetworkResult, + ) + if err != nil { + return err + } + + networkResults, err := tx.CreateTopLevelBucket( + networkResultStoreBucketKey, + ) + if err != nil { + return err + } + + return networkResults.Put( + netResultKey, netResBuf.Bytes(), + ) + }, + func(tx kvdb.RwTx) error { + t.Helper() + + // We'll now read the commit diff from disk using the + // _new_ decoding method. This should match the commit + // diff we inserted in the pre-migration step. + openChanBucket := tx.ReadWriteBucket(openChannelBucket) + nodeBucket := openChanBucket.NestedReadWriteBucket( + pub[:], + ) + chainBucket := nodeBucket.NestedReadWriteBucket(key[:]) + chanBucket := chainBucket.NestedReadWriteBucket(key[:]) + + commitDiffBytes := chanBucket.Get(commitDiffKey) + if commitDiffBytes == nil { + return fmt.Errorf("no commit diff found") + } + + newCommitDiff, err := current.DeserializeCommitDiff( + bytes.NewReader(commitDiffBytes), + ) + if err != nil { + return fmt.Errorf("unable to decode commit "+ + "diff: %v", err) + } + + if !reflect.DeepEqual(newCommitDiff, testCommitDiff) { + return fmt.Errorf("diff mismatch: expected "+ + "%v, got %v", spew.Sdump(testCommitDiff), + spew.Sdump(newCommitDiff)) + } + + // Next, we'll ensure that the un-acked updates match + // up as well. + updateBytes := chanBucket.Get(unsignedAckedUpdatesKey) + if updateBytes == nil { + return fmt.Errorf("no update bytes found") + } + + newUpdates, err := current.DeserializeLogUpdates( + bytes.NewReader(updateBytes), + ) + if err != nil { + return err + } + + if !reflect.DeepEqual( + newUpdates, testCommitDiff.LogUpdates, + ) { + return fmt.Errorf("updates mismatch: expected "+ + "%v, got %v", + spew.Sdump(testCommitDiff.LogUpdates), + spew.Sdump(newUpdates)) + } + + updateBytes = chanBucket.Get(remoteUnsignedLocalUpdatesKey) + if updateBytes == nil { + return fmt.Errorf("no update bytes found") + } + + newUpdates, err = current.DeserializeLogUpdates( + bytes.NewReader(updateBytes), + ) + if err != nil { + return err + } + + if !reflect.DeepEqual( + newUpdates, testCommitDiff.LogUpdates, + ) { + return fmt.Errorf("updates mismatch: expected "+ + "%v, got %v", + spew.Sdump(testCommitDiff.LogUpdates), + spew.Sdump(newUpdates)) + } + + // Next, we'll ensure that the inserted close channel + // summary bytes also mach up with what we inserted in + // the prior step. + closedChanBucket := tx.ReadWriteBucket( + closedChannelBucket, + ) + if closedChannelBucket == nil { + return fmt.Errorf("no closed channels found") + } + + chanSummaryBytes := closedChanBucket.Get(key[:]) + newChanCloseSummary, err := current.DeserializeCloseChannelSummary( + bytes.NewReader(chanSummaryBytes), + ) + if err != nil { + return err + } + + testChanCloseSummary.RemotePub.Curve = nil + testChanCloseSummary.RemoteCurrentRevocation.Curve = nil + testChanCloseSummary.RemoteNextRevocation.Curve = nil + testChanCloseSummary.LastChanSyncMsg.LocalUnrevokedCommitPoint.Curve = nil + + newChanCloseSummary.RemotePub.Curve = nil + newChanCloseSummary.RemoteCurrentRevocation.Curve = nil + newChanCloseSummary.RemoteNextRevocation.Curve = nil + newChanCloseSummary.LastChanSyncMsg.LocalUnrevokedCommitPoint.Curve = nil + + if !reflect.DeepEqual( + newChanCloseSummary, testChanCloseSummary, + ) { + return fmt.Errorf("summary mismatch: expected "+ + "%v, got %v", + spew.Sdump(testChanCloseSummary), + spew.Sdump(newChanCloseSummary)) + } + + // Fetch all forwarding packages. + for i := uint64(100); i < 200; i++ { + shortChanID := lnwire.NewShortChanIDFromInt(i) + packager := current.NewChannelPackager(shortChanID) + + fwdPkgs, err := packager.LoadFwdPkgs(tx) + if err != nil { + return err + } + + if len(fwdPkgs) != 1 { + return fmt.Errorf("expected 1 pkg") + } + + og := common.NewFwdPkg(shortChanID, 0, adds, settleFails) + + // Check that we deserialized the packages correctly. + if !reflect.DeepEqual(fwdPkgs[0], og) { + return fmt.Errorf("res mismatch: expected "+ + "%v, got %v", + spew.Sdump(fwdPkgs[0]), + spew.Sdump(og)) + } + } + + // Finally, we'll check the network results to ensure + // that was migrated properly as well. + networkResults := tx.ReadBucket( + networkResultStoreBucketKey, + ) + if networkResults == nil { + return fmt.Errorf("no net results found") + } + + netResBytes := networkResults.Get(netResultKey) + if netResBytes == nil { + return fmt.Errorf("no network res found") + } + + newNetRes, err := current.DeserializeNetworkResult( + bytes.NewReader(netResBytes), + ) + if err != nil { + return err + } + + if !reflect.DeepEqual(newNetRes, testNetworkResult) { + return fmt.Errorf("res mismatch: expected "+ + "%v, got %v", + spew.Sdump(testNetworkResult), + spew.Sdump(newNetRes)) + } + + return nil + }, + MigrateDatabaseWireMessages, + false, + ) +} diff --git a/channeldb/migration_01_to_11/channel.go b/channeldb/migration_01_to_11/channel.go index e67c0c69..2abdfd28 100644 --- a/channeldb/migration_01_to_11/channel.go +++ b/channeldb/migration_01_to_11/channel.go @@ -12,8 +12,8 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" + lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21" "github.com/lightningnetwork/lnd/keychain" - "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/shachain" ) diff --git a/channeldb/migration_01_to_11/channel_test.go b/channeldb/migration_01_to_11/channel_test.go index 1380828e..7e3ba6e0 100644 --- a/channeldb/migration_01_to_11/channel_test.go +++ b/channeldb/migration_01_to_11/channel_test.go @@ -11,8 +11,8 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" _ "github.com/btcsuite/btcwallet/walletdb/bdb" + lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21" "github.com/lightningnetwork/lnd/keychain" - "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/shachain" ) diff --git a/channeldb/migration_01_to_11/codec.go b/channeldb/migration_01_to_11/codec.go index 1727c8c9..6ee6f608 100644 --- a/channeldb/migration_01_to_11/codec.go +++ b/channeldb/migration_01_to_11/codec.go @@ -10,8 +10,8 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" + lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21" "github.com/lightningnetwork/lnd/keychain" - "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/shachain" ) diff --git a/channeldb/migration_01_to_11/graph.go b/channeldb/migration_01_to_11/graph.go index c7e78e74..9caa6ad8 100644 --- a/channeldb/migration_01_to_11/graph.go +++ b/channeldb/migration_01_to_11/graph.go @@ -14,7 +14,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" "github.com/lightningnetwork/lnd/channeldb/kvdb" - "github.com/lightningnetwork/lnd/lnwire" + lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21" ) var ( diff --git a/channeldb/migration_01_to_11/graph_test.go b/channeldb/migration_01_to_11/graph_test.go index dc21fccf..dc42ba20 100644 --- a/channeldb/migration_01_to_11/graph_test.go +++ b/channeldb/migration_01_to_11/graph_test.go @@ -8,7 +8,7 @@ import ( "time" "github.com/btcsuite/btcd/btcec" - "github.com/lightningnetwork/lnd/lnwire" + lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21" ) var ( diff --git a/channeldb/migration_01_to_11/invoices.go b/channeldb/migration_01_to_11/invoices.go index ceb21a33..b60008ee 100644 --- a/channeldb/migration_01_to_11/invoices.go +++ b/channeldb/migration_01_to_11/invoices.go @@ -9,8 +9,8 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb/kvdb" + lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21" "github.com/lightningnetwork/lnd/lntypes" - "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/tlv" ) diff --git a/channeldb/migration_01_to_11/migration_09_legacy_serialization.go b/channeldb/migration_01_to_11/migration_09_legacy_serialization.go index fda08226..461e983b 100644 --- a/channeldb/migration_01_to_11/migration_09_legacy_serialization.go +++ b/channeldb/migration_01_to_11/migration_09_legacy_serialization.go @@ -8,8 +8,8 @@ import ( "sort" "github.com/lightningnetwork/lnd/channeldb/kvdb" + lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21" "github.com/lightningnetwork/lnd/lntypes" - "github.com/lightningnetwork/lnd/lnwire" ) var ( diff --git a/channeldb/migration_01_to_11/migration_11_invoices.go b/channeldb/migration_01_to_11/migration_11_invoices.go index 7cb9ea88..cec7784e 100644 --- a/channeldb/migration_01_to_11/migration_11_invoices.go +++ b/channeldb/migration_01_to_11/migration_11_invoices.go @@ -9,8 +9,8 @@ import ( bitcoinCfg "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb/kvdb" + lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21" "github.com/lightningnetwork/lnd/channeldb/migration_01_to_11/zpay32" - "github.com/lightningnetwork/lnd/lnwire" litecoinCfg "github.com/ltcsuite/ltcd/chaincfg" ) diff --git a/channeldb/migration_01_to_11/migrations.go b/channeldb/migration_01_to_11/migrations.go index 35be510e..5232628d 100644 --- a/channeldb/migration_01_to_11/migrations.go +++ b/channeldb/migration_01_to_11/migrations.go @@ -8,7 +8,7 @@ import ( "github.com/btcsuite/btcd/btcec" "github.com/lightningnetwork/lnd/channeldb/kvdb" - "github.com/lightningnetwork/lnd/lnwire" + lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21" ) // MigrateNodeAndEdgeUpdateIndex is a migration function that will update the diff --git a/channeldb/migration_01_to_11/migrations_test.go b/channeldb/migration_01_to_11/migrations_test.go index 3677c90b..c0e57ec0 100644 --- a/channeldb/migration_01_to_11/migrations_test.go +++ b/channeldb/migration_01_to_11/migrations_test.go @@ -14,8 +14,8 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/channeldb/kvdb" + lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21" "github.com/lightningnetwork/lnd/lntypes" - "github.com/lightningnetwork/lnd/lnwire" ) // TestPaymentStatusesMigration checks that already completed payments will have diff --git a/channeldb/migration_01_to_11/payments.go b/channeldb/migration_01_to_11/payments.go index e44be003..2ccb8bca 100644 --- a/channeldb/migration_01_to_11/payments.go +++ b/channeldb/migration_01_to_11/payments.go @@ -12,8 +12,8 @@ import ( "github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb/kvdb" + lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21" "github.com/lightningnetwork/lnd/lntypes" - "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/tlv" ) diff --git a/channeldb/migration_01_to_11/payments_test.go b/channeldb/migration_01_to_11/payments_test.go index c5584079..17f7a59a 100644 --- a/channeldb/migration_01_to_11/payments_test.go +++ b/channeldb/migration_01_to_11/payments_test.go @@ -7,7 +7,7 @@ import ( "time" "github.com/btcsuite/btcd/btcec" - "github.com/lightningnetwork/lnd/lnwire" + lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21" ) var ( diff --git a/channeldb/migration_01_to_11/route.go b/channeldb/migration_01_to_11/route.go index 1dbfff60..2b43eaad 100644 --- a/channeldb/migration_01_to_11/route.go +++ b/channeldb/migration_01_to_11/route.go @@ -11,7 +11,7 @@ import ( "github.com/btcsuite/btcd/btcec" sphinx "github.com/lightningnetwork/lightning-onion" - "github.com/lightningnetwork/lnd/lnwire" + lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21" "github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/tlv" ) diff --git a/channeldb/migration_01_to_11/zpay32/amountunits.go b/channeldb/migration_01_to_11/zpay32/amountunits.go index f53f3ff0..0cc1fcdb 100644 --- a/channeldb/migration_01_to_11/zpay32/amountunits.go +++ b/channeldb/migration_01_to_11/zpay32/amountunits.go @@ -4,7 +4,7 @@ import ( "fmt" "strconv" - "github.com/lightningnetwork/lnd/lnwire" + lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21" ) var ( diff --git a/channeldb/migration_01_to_11/zpay32/decode.go b/channeldb/migration_01_to_11/zpay32/decode.go index 07929b45..0803cc9b 100644 --- a/channeldb/migration_01_to_11/zpay32/decode.go +++ b/channeldb/migration_01_to_11/zpay32/decode.go @@ -13,7 +13,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcutil" "github.com/btcsuite/btcutil/bech32" - "github.com/lightningnetwork/lnd/lnwire" + lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21" ) // Decode parses the provided encoded invoice and returns a decoded Invoice if diff --git a/channeldb/migration_01_to_11/zpay32/invoice.go b/channeldb/migration_01_to_11/zpay32/invoice.go index dbb991e6..83718c9b 100644 --- a/channeldb/migration_01_to_11/zpay32/invoice.go +++ b/channeldb/migration_01_to_11/zpay32/invoice.go @@ -8,7 +8,7 @@ import ( "github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcutil" - "github.com/lightningnetwork/lnd/lnwire" + lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21" ) const ( diff --git a/channeldb/migtest/migtest.go b/channeldb/migtest/migtest.go index d970769e..51a2006e 100644 --- a/channeldb/migtest/migtest.go +++ b/channeldb/migtest/migtest.go @@ -41,6 +41,8 @@ func ApplyMigration(t *testing.T, beforeMigration, afterMigration, migrationFunc func(tx kvdb.RwTx) error, shouldFail bool) { + t.Helper() + cdb, cleanUp, err := MakeDB() defer cleanUp() if err != nil { @@ -55,6 +57,8 @@ func ApplyMigration(t *testing.T, } defer func() { + t.Helper() + if r := recover(); r != nil { err = newError(r) } diff --git a/channeldb/waitingproof_test.go b/channeldb/waitingproof_test.go index ed1bf050..cb5b900f 100644 --- a/channeldb/waitingproof_test.go +++ b/channeldb/waitingproof_test.go @@ -5,6 +5,7 @@ import ( "reflect" + "github.com/davecgh/go-spew/spew" "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/lnwire" ) @@ -23,6 +24,7 @@ func TestWaitingProofStore(t *testing.T) { proof1 := NewWaitingProof(true, &lnwire.AnnounceSignatures{ NodeSignature: wireSig, BitcoinSignature: wireSig, + ExtraOpaqueData: make([]byte, 0), }) store, err := NewWaitingProofStore(db) @@ -40,7 +42,8 @@ func TestWaitingProofStore(t *testing.T) { t.Fatalf("unable retrieve proof from storage: %v", err) } if !reflect.DeepEqual(proof1, proof2) { - t.Fatal("wrong proof retrieved") + t.Fatalf("wrong proof retrieved: expected %v, got %v", + spew.Sdump(proof1), spew.Sdump(proof2)) } if _, err := store.Get(proof1.OppositeKey()); err != ErrWaitingProofNotFound { diff --git a/discovery/message_store_test.go b/discovery/message_store_test.go index 7d0fa1c2..12286e39 100644 --- a/discovery/message_store_test.go +++ b/discovery/message_store_test.go @@ -64,13 +64,15 @@ func randCompressedPubKey(t *testing.T) [33]byte { func randAnnounceSignatures() *lnwire.AnnounceSignatures { return &lnwire.AnnounceSignatures{ - ShortChannelID: lnwire.NewShortChanIDFromInt(rand.Uint64()), + ShortChannelID: lnwire.NewShortChanIDFromInt(rand.Uint64()), + ExtraOpaqueData: make([]byte, 0), } } func randChannelUpdate() *lnwire.ChannelUpdate { return &lnwire.ChannelUpdate{ - ShortChannelID: lnwire.NewShortChanIDFromInt(rand.Uint64()), + ShortChannelID: lnwire.NewShortChanIDFromInt(rand.Uint64()), + ExtraOpaqueData: make([]byte, 0), } } diff --git a/discovery/sync_manager_test.go b/discovery/sync_manager_test.go index a6bcb707..f2113ff5 100644 --- a/discovery/sync_manager_test.go +++ b/discovery/sync_manager_test.go @@ -629,8 +629,10 @@ func assertTransitionToChansSynced(t *testing.T, s *GossipSyncer, peer *mockPeer }, time.Second, 500*time.Millisecond) require.NoError(t, s.ProcessQueryMsg(&lnwire.ReplyChannelRange{ - QueryChannelRange: *query, - Complete: 1, + ChainHash: query.ChainHash, + FirstBlockHeight: query.FirstBlockHeight, + NumBlocks: query.NumBlocks, + Complete: 1, }, nil)) chanSeries := s.cfg.channelSeries.(*mockChannelGraphTimeSeries) diff --git a/discovery/syncer.go b/discovery/syncer.go index 36031f31..de6821e3 100644 --- a/discovery/syncer.go +++ b/discovery/syncer.go @@ -753,7 +753,9 @@ func (g *GossipSyncer) synchronizeChanIDs() (bool, error) { func isLegacyReplyChannelRange(query *lnwire.QueryChannelRange, reply *lnwire.ReplyChannelRange) bool { - return reply.QueryChannelRange == *query + return (reply.ChainHash == query.ChainHash && + reply.FirstBlockHeight == query.FirstBlockHeight && + reply.NumBlocks == query.NumBlocks) } // processChanRangeReply is called each time the GossipSyncer receives a new @@ -773,7 +775,7 @@ func (g *GossipSyncer) processChanRangeReply(msg *lnwire.ReplyChannelRange) erro // The last block should also be. We don't need to check the // intermediate ones because they should already be in sorted // order. - replyLastHeight := msg.QueryChannelRange.LastBlockHeight() + replyLastHeight := msg.LastBlockHeight() queryLastHeight := g.curQueryRangeMsg.LastBlockHeight() if replyLastHeight > queryLastHeight { return fmt.Errorf("reply includes channels for height "+ @@ -832,7 +834,7 @@ func (g *GossipSyncer) processChanRangeReply(msg *lnwire.ReplyChannelRange) erro // Otherwise, we'll look at the reply's height range. default: - replyLastHeight := msg.QueryChannelRange.LastBlockHeight() + replyLastHeight := msg.LastBlockHeight() queryLastHeight := g.curQueryRangeMsg.LastBlockHeight() // TODO(wilmer): This might require some padding if the remote @@ -997,10 +999,12 @@ func (g *GossipSyncer) replyChanRangeQuery(query *lnwire.QueryChannelRange) erro g.cfg.chainHash) return g.cfg.sendToPeerSync(&lnwire.ReplyChannelRange{ - QueryChannelRange: *query, - Complete: 0, - EncodingType: g.cfg.encodingType, - ShortChanIDs: nil, + ChainHash: query.ChainHash, + FirstBlockHeight: query.FirstBlockHeight, + NumBlocks: query.NumBlocks, + Complete: 0, + EncodingType: g.cfg.encodingType, + ShortChanIDs: nil, }) } @@ -1040,14 +1044,12 @@ func (g *GossipSyncer) replyChanRangeQuery(query *lnwire.QueryChannelRange) erro } return g.cfg.sendToPeerSync(&lnwire.ReplyChannelRange{ - QueryChannelRange: lnwire.QueryChannelRange{ - ChainHash: query.ChainHash, - NumBlocks: numBlocks, - FirstBlockHeight: firstHeight, - }, - Complete: complete, - EncodingType: g.cfg.encodingType, - ShortChanIDs: channelChunk, + ChainHash: query.ChainHash, + NumBlocks: numBlocks, + FirstBlockHeight: firstHeight, + Complete: complete, + EncodingType: g.cfg.encodingType, + ShortChanIDs: channelChunk, }) } diff --git a/discovery/syncer_test.go b/discovery/syncer_test.go index c3ad04f5..40e759f9 100644 --- a/discovery/syncer_test.go +++ b/discovery/syncer_test.go @@ -609,10 +609,9 @@ func TestGossipSyncerQueryChannelRangeWrongChainHash(t *testing.T) { t.Fatalf("expected lnwire.ReplyChannelRange, got %T", msg) } - if msg.QueryChannelRange != *query { - t.Fatalf("wrong query channel range in reply: "+ - "expected: %v\ngot: %v", spew.Sdump(*query), - spew.Sdump(msg.QueryChannelRange)) + if msg.ChainHash != query.ChainHash { + t.Fatalf("wrong chain hash: expected %v got %v", + query.ChainHash, msg.ChainHash) } if msg.Complete != 0 { t.Fatalf("expected complete set to 0, got %v", @@ -1227,34 +1226,13 @@ func testGossipSyncerProcessChanRangeReply(t *testing.T, legacy bool) { t.Fatalf("unable to generate channel range query: %v", err) } - var replyQueries []*lnwire.QueryChannelRange - if legacy { - // Each reply query is the same as the original query in the - // legacy mode. - replyQueries = []*lnwire.QueryChannelRange{query, query, query} - } else { - // When interpreting block ranges, the first reply should start - // from our requested first block, and the last should end at - // our requested last block. - replyQueries = []*lnwire.QueryChannelRange{ - { - FirstBlockHeight: 0, - NumBlocks: 11, - }, - { - FirstBlockHeight: 11, - NumBlocks: 1, - }, - { - FirstBlockHeight: 12, - NumBlocks: query.NumBlocks - 12, - }, - } - } - + // When interpreting block ranges, the first reply should start from + // our requested first block, and the last should end at our requested + // last block. replies := []*lnwire.ReplyChannelRange{ { - QueryChannelRange: *replyQueries[0], + FirstBlockHeight: 0, + NumBlocks: 11, ShortChanIDs: []lnwire.ShortChannelID{ { BlockHeight: 10, @@ -1262,7 +1240,8 @@ func testGossipSyncerProcessChanRangeReply(t *testing.T, legacy bool) { }, }, { - QueryChannelRange: *replyQueries[1], + FirstBlockHeight: 11, + NumBlocks: 1, ShortChanIDs: []lnwire.ShortChannelID{ { BlockHeight: 11, @@ -1270,8 +1249,9 @@ func testGossipSyncerProcessChanRangeReply(t *testing.T, legacy bool) { }, }, { - QueryChannelRange: *replyQueries[2], - Complete: 1, + FirstBlockHeight: 12, + NumBlocks: query.NumBlocks - 12, + Complete: 1, ShortChanIDs: []lnwire.ShortChannelID{ { BlockHeight: 12, @@ -1280,6 +1260,19 @@ func testGossipSyncerProcessChanRangeReply(t *testing.T, legacy bool) { }, } + // Each reply query is the same as the original query in the legacy + // mode. + if legacy { + replies[0].FirstBlockHeight = query.FirstBlockHeight + replies[0].NumBlocks = query.NumBlocks + + replies[1].FirstBlockHeight = query.FirstBlockHeight + replies[1].NumBlocks = query.NumBlocks + + replies[2].FirstBlockHeight = query.FirstBlockHeight + replies[2].NumBlocks = query.NumBlocks + } + // We'll begin by sending the syncer a set of non-complete channel // range replies. if err := syncer.processChanRangeReply(replies[0]); err != nil { @@ -2377,7 +2370,9 @@ func TestGossipSyncerMaxChannelRangeReplies(t *testing.T) { // order to transition the syncer's state. for i := uint32(0); i < syncer.cfg.maxQueryChanRangeReplies; i++ { reply := &lnwire.ReplyChannelRange{ - QueryChannelRange: *query, + ChainHash: query.ChainHash, + FirstBlockHeight: query.FirstBlockHeight, + NumBlocks: query.NumBlocks, ShortChanIDs: []lnwire.ShortChannelID{ { BlockHeight: query.FirstBlockHeight + i, @@ -2408,7 +2403,9 @@ func TestGossipSyncerMaxChannelRangeReplies(t *testing.T) { // Finally, attempting to process another reply for the same query // should result in an error. require.Error(t, syncer.ProcessQueryMsg(&lnwire.ReplyChannelRange{ - QueryChannelRange: *query, + ChainHash: query.ChainHash, + FirstBlockHeight: query.FirstBlockHeight, + NumBlocks: query.NumBlocks, ShortChanIDs: []lnwire.ShortChannelID{ { BlockHeight: query.LastBlockHeight() + 1, diff --git a/htlcswitch/payment_result.go b/htlcswitch/payment_result.go index cd5fe0a5..06345aff 100644 --- a/htlcswitch/payment_result.go +++ b/htlcswitch/payment_result.go @@ -61,28 +61,15 @@ type networkResult struct { // serializeNetworkResult serializes the networkResult. func serializeNetworkResult(w io.Writer, n *networkResult) error { - if _, err := lnwire.WriteMessage(w, n.msg, 0); err != nil { - return err - } - - return channeldb.WriteElements(w, n.unencrypted, n.isResolution) + return channeldb.WriteElements(w, n.msg, n.unencrypted, n.isResolution) } // deserializeNetworkResult deserializes the networkResult. func deserializeNetworkResult(r io.Reader) (*networkResult, error) { - var ( - err error - ) - n := &networkResult{} - n.msg, err = lnwire.ReadMessage(r, 0) - if err != nil { - return nil, err - } - if err := channeldb.ReadElements(r, - &n.unencrypted, &n.isResolution, + &n.msg, &n.unencrypted, &n.isResolution, ); err != nil { return nil, err } diff --git a/htlcswitch/payment_result_test.go b/htlcswitch/payment_result_test.go index 04ff57d8..aa7cbc17 100644 --- a/htlcswitch/payment_result_test.go +++ b/htlcswitch/payment_result_test.go @@ -39,18 +39,21 @@ func TestNetworkResultSerialization(t *testing.T) { ChanID: chanID, ID: 2, PaymentPreimage: preimage, + ExtraData: make([]byte, 0), } fail := &lnwire.UpdateFailHTLC{ - ChanID: chanID, - ID: 1, - Reason: []byte{}, + ChanID: chanID, + ID: 1, + Reason: []byte{}, + ExtraData: make([]byte, 0), } fail2 := &lnwire.UpdateFailHTLC{ - ChanID: chanID, - ID: 1, - Reason: reason[:], + ChanID: chanID, + ID: 1, + Reason: reason[:], + ExtraData: make([]byte, 0), } testCases := []*networkResult{ diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index ead91598..27ea6ac9 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -3176,6 +3176,7 @@ func TestChanSyncOweCommitment(t *testing.T) { Amount: htlcAmt, Expiry: uint32(10), OnionBlob: fakeOnionBlob, + ExtraData: make([]byte, 0), } htlcIndex, err := bobChannel.AddHTLC(h, nil) @@ -3220,6 +3221,7 @@ func TestChanSyncOweCommitment(t *testing.T) { Amount: htlcAmt, Expiry: uint32(10), OnionBlob: fakeOnionBlob, + ExtraData: make([]byte, 0), } aliceHtlcIndex, err := aliceChannel.AddHTLC(aliceHtlc, nil) if err != nil { diff --git a/lnwire/accept_channel.go b/lnwire/accept_channel.go index da9daa69..57f2ad40 100644 --- a/lnwire/accept_channel.go +++ b/lnwire/accept_channel.go @@ -1,6 +1,7 @@ package lnwire import ( + "fmt" "io" "github.com/btcsuite/btcd/btcec" @@ -92,6 +93,17 @@ type AcceptChannel struct { // and has a length prefix, so a zero will be written if it is not set // and its length followed by the script will be written if it is set. UpfrontShutdownScript DeliveryAddress + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + // + // NOTE: Since the upfront shutdown script MUST be present (though can + // be zero-length) if any TLV data is available, the script will be + // extracted and removed from this blob when decoding. ExtraData will + // contain all TLV records _except_ the DeliveryAddress record in that + // case. + ExtraData ExtraOpaqueData } // A compile time check to ensure AcceptChannel implements the lnwire.Message @@ -104,6 +116,15 @@ var _ Message = (*AcceptChannel)(nil) // // This is part of the lnwire.Message interface. func (a *AcceptChannel) Encode(w io.Writer, pver uint32) error { + // Since the upfront script is encoded as a TLV record, concatenate it + // with the ExtraData, and write them as one. + tlvRecords, err := packShutdownScript( + a.UpfrontShutdownScript, a.ExtraData, + ) + if err != nil { + return err + } + return WriteElements(w, a.PendingChannelID[:], a.DustLimit, @@ -119,7 +140,7 @@ func (a *AcceptChannel) Encode(w io.Writer, pver uint32) error { a.DelayedPaymentPoint, a.HtlcPoint, a.FirstCommitmentPoint, - a.UpfrontShutdownScript, + tlvRecords, ) } @@ -150,15 +171,82 @@ func (a *AcceptChannel) Decode(r io.Reader, pver uint32) error { return err } - // Check for the optional upfront shutdown script field. If it is not there, - // silence the EOF error. - err = ReadElement(r, &a.UpfrontShutdownScript) - if err != nil && err != io.EOF { + // For backwards compatibility, the optional extra data blob for + // AcceptChannel must contain an entry for the upfront shutdown script. + // We'll read it out and attempt to parse it. + var tlvRecords ExtraOpaqueData + if err := ReadElements(r, &tlvRecords); err != nil { return err } + + a.UpfrontShutdownScript, a.ExtraData, err = parseShutdownScript( + tlvRecords, + ) + if err != nil { + return err + } + return nil } +// packShutdownScript takes an upfront shutdown script and an opaque data blob +// and concatenates them. +func packShutdownScript(addr DeliveryAddress, extraData ExtraOpaqueData) ( + ExtraOpaqueData, error) { + + // We'll always write the upfront shutdown script record, regardless of + // the script being empty. + var tlvRecords ExtraOpaqueData + + // Pack it into a data blob as a TLV record. + err := tlvRecords.PackRecords(addr.NewRecord()) + if err != nil { + return nil, fmt.Errorf("unable to pack upfront shutdown "+ + "script as TLV record: %v", err) + } + + // Concatenate the remaining blob with the shutdown script record. + tlvRecords = append(tlvRecords, extraData...) + return tlvRecords, nil +} + +// parseShutdownScript reads and extract the upfront shutdown script from the +// passe data blob. It returns the script, if any, and the remainder of the +// data blob. +// +// This can be used to parse extra data for the OpenChannel and AcceptChannel +// messages, where the shutdown script is mandatory if extra TLV data is +// present. +func parseShutdownScript(tlvRecords ExtraOpaqueData) (DeliveryAddress, + ExtraOpaqueData, error) { + + // If no TLV data is present there can't be any script available. + if len(tlvRecords) == 0 { + return nil, tlvRecords, nil + } + + // Otherwise the shutdown script MUST be present. + var addr DeliveryAddress + tlvs, err := tlvRecords.ExtractRecords(addr.NewRecord()) + if err != nil { + return nil, nil, err + } + + // Not among TLV records, this means the data was invalid. + if _, ok := tlvs[DeliveryAddrType]; !ok { + return nil, nil, fmt.Errorf("no shutdown script in non-empty " + + "data blob") + } + + // Now that we have retrieved the address (which can be zero-length), + // we'll remove the bytes encoding it from the TLV data before + // returning it. + addrLen := len(addr) + tlvRecords = tlvRecords[addrLen+2:] + + return addr, tlvRecords, nil +} + // MsgType returns the MessageType code which uniquely identifies this message // as an AcceptChannel on the wire. // @@ -172,11 +260,5 @@ func (a *AcceptChannel) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (a *AcceptChannel) MaxPayloadLength(uint32) uint32 { - // 32 + (8 * 4) + (4 * 1) + (2 * 2) + (33 * 6) - var length uint32 = 270 // base length - - // Upfront shutdown script max length. - length += 2 + deliveryAddressMaxSize - - return length + return MaxMsgBody } diff --git a/lnwire/announcement_signatures.go b/lnwire/announcement_signatures.go index 639704de..14341392 100644 --- a/lnwire/announcement_signatures.go +++ b/lnwire/announcement_signatures.go @@ -2,7 +2,6 @@ package lnwire import ( "io" - "io/ioutil" ) // AnnounceSignatures is a direct message between two endpoints of a @@ -40,7 +39,7 @@ type AnnounceSignatures struct { // properly validate the set of signatures that cover these new fields, // and ensure we're able to make upgrades to the network in a forwards // compatible manner. - ExtraOpaqueData []byte + ExtraOpaqueData ExtraOpaqueData } // A compile time check to ensure AnnounceSignatures implements the @@ -52,29 +51,13 @@ var _ Message = (*AnnounceSignatures)(nil) // // This is part of the lnwire.Message interface. func (a *AnnounceSignatures) Decode(r io.Reader, pver uint32) error { - err := ReadElements(r, + return ReadElements(r, &a.ChannelID, &a.ShortChannelID, &a.NodeSignature, &a.BitcoinSignature, + &a.ExtraOpaqueData, ) - if err != nil { - return err - } - - // Now that we've read out all the fields that we explicitly know of, - // we'll collect the remainder into the ExtraOpaqueData field. If there - // aren't any bytes, then we'll snip off the slice to avoid carrying - // around excess capacity. - a.ExtraOpaqueData, err = ioutil.ReadAll(r) - if err != nil { - return err - } - if len(a.ExtraOpaqueData) == 0 { - a.ExtraOpaqueData = nil - } - - return nil } // Encode serializes the target AnnounceSignatures into the passed io.Writer @@ -104,5 +87,5 @@ func (a *AnnounceSignatures) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (a *AnnounceSignatures) MaxPayloadLength(pver uint32) uint32 { - return 65533 + return MaxMsgBody } diff --git a/lnwire/channel_announcement.go b/lnwire/channel_announcement.go index 46efeed8..de4b72b3 100644 --- a/lnwire/channel_announcement.go +++ b/lnwire/channel_announcement.go @@ -3,7 +3,6 @@ package lnwire import ( "bytes" "io" - "io/ioutil" "github.com/btcsuite/btcd/chaincfg/chainhash" ) @@ -56,7 +55,7 @@ type ChannelAnnouncement struct { // properly validate the set of signatures that cover these new fields, // and ensure we're able to make upgrades to the network in a forwards // compatible manner. - ExtraOpaqueData []byte + ExtraOpaqueData ExtraOpaqueData } // A compile time check to ensure ChannelAnnouncement implements the @@ -68,7 +67,7 @@ var _ Message = (*ChannelAnnouncement)(nil) // // This is part of the lnwire.Message interface. func (a *ChannelAnnouncement) Decode(r io.Reader, pver uint32) error { - err := ReadElements(r, + return ReadElements(r, &a.NodeSig1, &a.NodeSig2, &a.BitcoinSig1, @@ -80,24 +79,8 @@ func (a *ChannelAnnouncement) Decode(r io.Reader, pver uint32) error { &a.NodeID2, &a.BitcoinKey1, &a.BitcoinKey2, + &a.ExtraOpaqueData, ) - if err != nil { - return err - } - - // Now that we've read out all the fields that we explicitly know of, - // we'll collect the remainder into the ExtraOpaqueData field. If there - // aren't any bytes, then we'll snip off the slice to avoid carrying - // around excess capacity. - a.ExtraOpaqueData, err = ioutil.ReadAll(r) - if err != nil { - return err - } - if len(a.ExtraOpaqueData) == 0 { - a.ExtraOpaqueData = nil - } - - return nil } // Encode serializes the target ChannelAnnouncement into the passed io.Writer @@ -134,7 +117,7 @@ func (a *ChannelAnnouncement) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (a *ChannelAnnouncement) MaxPayloadLength(pver uint32) uint32 { - return 65533 + return MaxMsgBody } // DataToSign is used to retrieve part of the announcement message which should diff --git a/lnwire/channel_reestablish.go b/lnwire/channel_reestablish.go index 6fa8f8ac..bfe7c53a 100644 --- a/lnwire/channel_reestablish.go +++ b/lnwire/channel_reestablish.go @@ -60,6 +60,11 @@ type ChannelReestablish struct { // LocalUnrevokedCommitPoint is the commitment point used in the // current un-revoked commitment transaction of the sending party. LocalUnrevokedCommitPoint *btcec.PublicKey + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // A compile time check to ensure ChannelReestablish implements the @@ -83,12 +88,20 @@ func (a *ChannelReestablish) Encode(w io.Writer, pver uint32) error { // If the commit point wasn't sent, then we won't write out any of the // remaining fields as they're optional. if a.LocalUnrevokedCommitPoint == nil { - return nil + // However, we'll still write out the extra data if it's + // present. + // + // NOTE: This is here primarily for the quickcheck tests, in + // practice, we'll always populate this field. + return WriteElements(w, a.ExtraData) } // Otherwise, we'll write out the remaining elements. - return WriteElements(w, a.LastRemoteCommitSecret[:], - a.LocalUnrevokedCommitPoint) + return WriteElements(w, + a.LastRemoteCommitSecret[:], + a.LocalUnrevokedCommitPoint, + a.ExtraData, + ) } // Decode deserializes a serialized ChannelReestablish stored in the passed @@ -118,6 +131,9 @@ func (a *ChannelReestablish) Decode(r io.Reader, pver uint32) error { var buf [32]byte _, err = io.ReadFull(r, buf[:32]) if err == io.EOF { + // If there aren't any more bytes, then we'll emplace an empty + // extra data to make our quickcheck tests happy. + a.ExtraData = make([]byte, 0) return nil } else if err != nil { return err @@ -129,7 +145,11 @@ func (a *ChannelReestablish) Decode(r io.Reader, pver uint32) error { // We'll conclude by parsing out the commitment point. We don't check // the error in this case, as it has included the commit secret, then // they MUST also include the commit point. - return ReadElement(r, &a.LocalUnrevokedCommitPoint) + if err = ReadElement(r, &a.LocalUnrevokedCommitPoint); err != nil { + return err + } + + return a.ExtraData.Decode(r) } // MsgType returns the integer uniquely identifying this message type on the @@ -145,22 +165,5 @@ func (a *ChannelReestablish) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (a *ChannelReestablish) MaxPayloadLength(pver uint32) uint32 { - var length uint32 - - // ChanID - 32 bytes - length += 32 - - // NextLocalCommitHeight - 8 bytes - length += 8 - - // RemoteCommitTailHeight - 8 bytes - length += 8 - - // LastRemoteCommitSecret - 32 bytes - length += 32 - - // LocalUnrevokedCommitPoint - 33 bytes - length += 33 - - return length + return MaxMsgBody } diff --git a/lnwire/channel_update.go b/lnwire/channel_update.go index 037f3d55..80000056 100644 --- a/lnwire/channel_update.go +++ b/lnwire/channel_update.go @@ -4,7 +4,6 @@ import ( "bytes" "fmt" "io" - "io/ioutil" "github.com/btcsuite/btcd/chaincfg/chainhash" ) @@ -115,13 +114,10 @@ type ChannelUpdate struct { // HtlcMaximumMsat is the maximum HTLC value which will be accepted. HtlcMaximumMsat MilliSatoshi - // ExtraOpaqueData is the set of data that was appended to this - // message, some of which we may not actually know how to iterate or - // parse. By holding onto this data, we ensure that we're able to - // properly validate the set of signatures that cover these new fields, - // and ensure we're able to make upgrades to the network in a forwards - // compatible manner. - ExtraOpaqueData []byte + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraOpaqueData ExtraOpaqueData } // A compile time check to ensure ChannelUpdate implements the lnwire.Message @@ -156,19 +152,7 @@ func (a *ChannelUpdate) Decode(r io.Reader, pver uint32) error { } } - // Now that we've read out all the fields that we explicitly know of, - // we'll collect the remainder into the ExtraOpaqueData field. If there - // aren't any bytes, then we'll snip off the slice to avoid carrying - // around excess capacity. - a.ExtraOpaqueData, err = ioutil.ReadAll(r) - if err != nil { - return err - } - if len(a.ExtraOpaqueData) == 0 { - a.ExtraOpaqueData = nil - } - - return nil + return a.ExtraOpaqueData.Decode(r) } // Encode serializes the target ChannelUpdate into the passed io.Writer @@ -201,7 +185,7 @@ func (a *ChannelUpdate) Encode(w io.Writer, pver uint32) error { } // Finally, append any extra opaque data. - return WriteElements(w, a.ExtraOpaqueData) + return a.ExtraOpaqueData.Encode(w) } // MsgType returns the integer uniquely identifying this message type on the @@ -217,7 +201,7 @@ func (a *ChannelUpdate) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (a *ChannelUpdate) MaxPayloadLength(pver uint32) uint32 { - return 65533 + return MaxMsgBody } // DataToSign is used to retrieve part of the announcement message which should @@ -250,7 +234,7 @@ func (a *ChannelUpdate) DataToSign() ([]byte, error) { } // Finally, append any extra opaque data. - if err := WriteElements(&w, a.ExtraOpaqueData); err != nil { + if err := a.ExtraOpaqueData.Encode(&w); err != nil { return nil, err } diff --git a/lnwire/closing_signed.go b/lnwire/closing_signed.go index 91b90646..7732715b 100644 --- a/lnwire/closing_signed.go +++ b/lnwire/closing_signed.go @@ -27,6 +27,11 @@ type ClosingSigned struct { // Signature is for the proposed channel close transaction. Signature Sig + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // NewClosingSigned creates a new empty ClosingSigned message. @@ -49,7 +54,9 @@ var _ Message = (*ClosingSigned)(nil) // // This is part of the lnwire.Message interface. func (c *ClosingSigned) Decode(r io.Reader, pver uint32) error { - return ReadElements(r, &c.ChannelID, &c.FeeSatoshis, &c.Signature) + return ReadElements( + r, &c.ChannelID, &c.FeeSatoshis, &c.Signature, &c.ExtraData, + ) } // Encode serializes the target ClosingSigned into the passed io.Writer @@ -57,7 +64,9 @@ func (c *ClosingSigned) Decode(r io.Reader, pver uint32) error { // // This is part of the lnwire.Message interface. func (c *ClosingSigned) Encode(w io.Writer, pver uint32) error { - return WriteElements(w, c.ChannelID, c.FeeSatoshis, c.Signature) + return WriteElements( + w, c.ChannelID, c.FeeSatoshis, c.Signature, c.ExtraData, + ) } // MsgType returns the integer uniquely identifying this message type on the @@ -73,16 +82,5 @@ func (c *ClosingSigned) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (c *ClosingSigned) MaxPayloadLength(uint32) uint32 { - var length uint32 - - // ChannelID - 32 bytes - length += 32 - - // FeeSatoshis - 8 bytes - length += 8 - - // Signature - 64 bytes - length += 64 - - return length + return MaxMsgBody } diff --git a/lnwire/commit_sig.go b/lnwire/commit_sig.go index f15a9738..8856389f 100644 --- a/lnwire/commit_sig.go +++ b/lnwire/commit_sig.go @@ -34,11 +34,18 @@ type CommitSig struct { // should be signed, for each incoming HTLC the HTLC timeout // transaction should be signed. HtlcSigs []Sig + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // NewCommitSig creates a new empty CommitSig message. func NewCommitSig() *CommitSig { - return &CommitSig{} + return &CommitSig{ + ExtraData: make([]byte, 0), + } } // A compile time check to ensure CommitSig implements the lnwire.Message @@ -54,6 +61,7 @@ func (c *CommitSig) Decode(r io.Reader, pver uint32) error { &c.ChanID, &c.CommitSig, &c.HtlcSigs, + &c.ExtraData, ) } @@ -66,6 +74,7 @@ func (c *CommitSig) Encode(w io.Writer, pver uint32) error { c.ChanID, c.CommitSig, c.HtlcSigs, + c.ExtraData, ) } @@ -82,8 +91,7 @@ func (c *CommitSig) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (c *CommitSig) MaxPayloadLength(uint32) uint32 { - // 32 + 64 + 2 + max_allowed_htlcs - return MaxMessagePayload + return MaxMsgBody } // TargetChanID returns the channel id of the link for which this message is diff --git a/lnwire/error.go b/lnwire/error.go index 19911d1f..02c07aea 100644 --- a/lnwire/error.go +++ b/lnwire/error.go @@ -123,8 +123,7 @@ func (c *Error) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (c *Error) MaxPayloadLength(uint32) uint32 { - // 32 + 2 + 65501 - return MaxMessagePayload + return MaxMsgBody } // isASCII is a helper method that checks whether all bytes in `data` would be diff --git a/lnwire/extra_bytes.go b/lnwire/extra_bytes.go new file mode 100644 index 00000000..22fd20bd --- /dev/null +++ b/lnwire/extra_bytes.go @@ -0,0 +1,84 @@ +package lnwire + +import ( + "bytes" + "io" + "io/ioutil" + + "github.com/lightningnetwork/lnd/tlv" +) + +// ExtraOpaqueData is the set of data that was appended to this message, some +// of which we may not actually know how to iterate or parse. By holding onto +// this data, we ensure that we're able to properly validate the set of +// signatures that cover these new fields, and ensure we're able to make +// upgrades to the network in a forwards compatible manner. +type ExtraOpaqueData []byte + +// Encode attempts to encode the raw extra bytes into the passed io.Writer. +func (e *ExtraOpaqueData) Encode(w io.Writer) error { + eBytes := []byte((*e)[:]) + if err := WriteElements(w, eBytes); err != nil { + return err + } + + return nil +} + +// Decode attempts to unpack the raw bytes encoded in the passed io.Reader as a +// set of extra opaque data. +func (e *ExtraOpaqueData) Decode(r io.Reader) error { + // First, we'll attempt to read a set of bytes contained within the + // passed io.Reader (if any exist). + rawBytes, err := ioutil.ReadAll(r) + if err != nil { + return err + } + + // If we _do_ have some bytes, then we'll swap out our backing pointer. + // This ensures that any struct that embeds this type will properly + // store the bytes once this method exits. + if len(rawBytes) > 0 { + *e = ExtraOpaqueData(rawBytes) + } else { + *e = make([]byte, 0) + } + + return nil +} + +// PackRecords attempts to encode the set of tlv records into the target +// ExtraOpaqueData instance. The records will be encoded as a raw TLV stream +// and stored within the backing slice pointer. +func (e *ExtraOpaqueData) PackRecords(records ...tlv.Record) error { + tlvStream, err := tlv.NewStream(records...) + if err != nil { + return err + } + + var extraBytesWriter bytes.Buffer + if err := tlvStream.Encode(&extraBytesWriter); err != nil { + return err + } + + *e = ExtraOpaqueData(extraBytesWriter.Bytes()) + + return nil +} + +// ExtractRecords attempts to decode any types in the internal raw bytes as if +// it were a tlv stream. The set of raw parsed types is returned, and any +// passed records (if found in the stream) will be parsed into the proper +// tlv.Record. +func (e *ExtraOpaqueData) ExtractRecords(records ...tlv.Record) ( + tlv.TypeMap, error) { + + extraBytesReader := bytes.NewReader(*e) + + tlvStream, err := tlv.NewStream(records...) + if err != nil { + return nil, err + } + + return tlvStream.DecodeWithParsedTypes(extraBytesReader) +} diff --git a/lnwire/extra_bytes_test.go b/lnwire/extra_bytes_test.go new file mode 100644 index 00000000..39271d6a --- /dev/null +++ b/lnwire/extra_bytes_test.go @@ -0,0 +1,147 @@ +package lnwire + +import ( + "bytes" + "math/rand" + "reflect" + "testing" + "testing/quick" + + "github.com/lightningnetwork/lnd/tlv" +) + +// TestExtraOpaqueDataEncodeDecode tests that we're able to encode/decode +// arbitrary payloads. +func TestExtraOpaqueDataEncodeDecode(t *testing.T) { + t.Parallel() + + type testCase struct { + // emptyBytes indicates if we should try to encode empty bytes + // or not. + emptyBytes bool + + // inputBytes if emptyBytes is false, then we'll read in this + // set of bytes instead. + inputBytes []byte + } + + // We should be able to read in an arbitrary set of bytes as an + // ExtraOpaqueData, then encode those new bytes into a new instance. + // The final two instances should be identical. + scenario := func(test testCase) bool { + var ( + extraData ExtraOpaqueData + b bytes.Buffer + ) + + copy(extraData[:], test.inputBytes) + + if err := extraData.Encode(&b); err != nil { + t.Fatalf("unable to encode extra data: %v", err) + return false + } + + var newBytes ExtraOpaqueData + if err := newBytes.Decode(&b); err != nil { + t.Fatalf("unable to decode extra bytes: %v", err) + return false + } + + if !bytes.Equal(extraData[:], newBytes[:]) { + t.Fatalf("expected %x, got %x", extraData, + newBytes) + return false + } + + return true + } + + // We'll make a function to generate random test data. Half of the + // time, we'll actually feed in blank bytes. + quickCfg := &quick.Config{ + Values: func(v []reflect.Value, r *rand.Rand) { + + var newTestCase testCase + if r.Int31()%2 == 0 { + newTestCase.emptyBytes = true + } + + if !newTestCase.emptyBytes { + numBytes := r.Int31n(1000) + newTestCase.inputBytes = make([]byte, numBytes) + + _, err := r.Read(newTestCase.inputBytes) + if err != nil { + t.Fatalf("unable to gen random bytes: %v", err) + return + } + } + + v[0] = reflect.ValueOf(newTestCase) + }, + } + + if err := quick.Check(scenario, quickCfg); err != nil { + t.Fatalf("encode+decode test failed: %v", err) + } +} + +// TestExtraOpaqueDataPackUnpackRecords tests that we're able to pack a set of +// tlv.Records into a stream, and unpack them on the other side to obtain the +// same set of records. +func TestExtraOpaqueDataPackUnpackRecords(t *testing.T) { + t.Parallel() + + var ( + type1 tlv.Type = 1 + type2 tlv.Type = 2 + + channelType1 uint8 = 2 + channelType2 uint8 + + hop1 uint32 = 99 + hop2 uint32 + ) + testRecords := []tlv.Record{ + tlv.MakePrimitiveRecord(type1, &channelType1), + tlv.MakePrimitiveRecord(type2, &hop1), + } + + // Now that we have our set of sample records and types, we'll encode + // them into the passed ExtraOpaqueData instance. + var extraBytes ExtraOpaqueData + if err := extraBytes.PackRecords(testRecords...); err != nil { + t.Fatalf("unable to pack records: %v", err) + } + + // We'll now simulate decoding these types _back_ into records on the + // other side. + newRecords := []tlv.Record{ + tlv.MakePrimitiveRecord(type1, &channelType2), + tlv.MakePrimitiveRecord(type2, &hop2), + } + typeMap, err := extraBytes.ExtractRecords(newRecords...) + if err != nil { + t.Fatalf("unable to extract record: %v", err) + } + + // We should find that the new backing values have been populated with + // the proper value. + switch { + case channelType1 != channelType2: + t.Fatalf("wrong record for channel type: expected %v, got %v", + channelType1, channelType2) + + case hop1 != hop2: + t.Fatalf("wrong record for hop: expected %v, got %v", hop1, + hop2) + } + + // Both types we created above should be found in the type map. + if _, ok := typeMap[type1]; !ok { + t.Fatalf("type1 not found in typeMap") + } + if _, ok := typeMap[type2]; !ok { + t.Fatalf("type2 not found in typeMap") + } +} diff --git a/lnwire/funding_created.go b/lnwire/funding_created.go index c14321ec..437b1b6a 100644 --- a/lnwire/funding_created.go +++ b/lnwire/funding_created.go @@ -24,6 +24,11 @@ type FundingCreated struct { // CommitSig is Alice's signature from Bob's version of the commitment // transaction. CommitSig Sig + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // A compile time check to ensure FundingCreated implements the lnwire.Message @@ -36,7 +41,10 @@ var _ Message = (*FundingCreated)(nil) // // This is part of the lnwire.Message interface. func (f *FundingCreated) Encode(w io.Writer, pver uint32) error { - return WriteElements(w, f.PendingChannelID[:], f.FundingPoint, f.CommitSig) + return WriteElements( + w, f.PendingChannelID[:], f.FundingPoint, f.CommitSig, + f.ExtraData, + ) } // Decode deserializes the serialized FundingCreated stored in the passed @@ -45,7 +53,10 @@ func (f *FundingCreated) Encode(w io.Writer, pver uint32) error { // // This is part of the lnwire.Message interface. func (f *FundingCreated) Decode(r io.Reader, pver uint32) error { - return ReadElements(r, f.PendingChannelID[:], &f.FundingPoint, &f.CommitSig) + return ReadElements( + r, f.PendingChannelID[:], &f.FundingPoint, &f.CommitSig, + &f.ExtraData, + ) } // MsgType returns the uint32 code which uniquely identifies this message as a @@ -61,6 +72,5 @@ func (f *FundingCreated) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (f *FundingCreated) MaxPayloadLength(uint32) uint32 { - // 32 + 32 + 2 + 64 - return 130 + return MaxMsgBody } diff --git a/lnwire/funding_locked.go b/lnwire/funding_locked.go index c441b0be..1eeddfb6 100644 --- a/lnwire/funding_locked.go +++ b/lnwire/funding_locked.go @@ -19,6 +19,11 @@ type FundingLocked struct { // NextPerCommitmentPoint is the secret that can be used to revoke the // next commitment transaction for the channel. NextPerCommitmentPoint *btcec.PublicKey + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // NewFundingLocked creates a new FundingLocked message, populating it with the @@ -27,6 +32,7 @@ func NewFundingLocked(cid ChannelID, npcp *btcec.PublicKey) *FundingLocked { return &FundingLocked{ ChanID: cid, NextPerCommitmentPoint: npcp, + ExtraData: make([]byte, 0), } } @@ -42,7 +48,9 @@ var _ Message = (*FundingLocked)(nil) func (c *FundingLocked) Decode(r io.Reader, pver uint32) error { return ReadElements(r, &c.ChanID, - &c.NextPerCommitmentPoint) + &c.NextPerCommitmentPoint, + &c.ExtraData, + ) } // Encode serializes the target FundingLocked message into the passed io.Writer @@ -53,7 +61,9 @@ func (c *FundingLocked) Decode(r io.Reader, pver uint32) error { func (c *FundingLocked) Encode(w io.Writer, pver uint32) error { return WriteElements(w, c.ChanID, - c.NextPerCommitmentPoint) + c.NextPerCommitmentPoint, + c.ExtraData, + ) } // MsgType returns the uint32 code which uniquely identifies this message as a @@ -70,14 +80,5 @@ func (c *FundingLocked) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (c *FundingLocked) MaxPayloadLength(uint32) uint32 { - var length uint32 - - // ChanID - 32 bytes - length += 32 - - // NextPerCommitmentPoint - 33 bytes - length += 33 - - // 65 bytes - return length + return MaxMsgBody } diff --git a/lnwire/funding_signed.go b/lnwire/funding_signed.go index 620f8b37..1ef15568 100644 --- a/lnwire/funding_signed.go +++ b/lnwire/funding_signed.go @@ -13,6 +13,11 @@ type FundingSigned struct { // CommitSig is Bob's signature for Alice's version of the commitment // transaction. CommitSig Sig + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // A compile time check to ensure FundingSigned implements the lnwire.Message @@ -25,7 +30,7 @@ var _ Message = (*FundingSigned)(nil) // // This is part of the lnwire.Message interface. func (f *FundingSigned) Encode(w io.Writer, pver uint32) error { - return WriteElements(w, f.ChanID, f.CommitSig) + return WriteElements(w, f.ChanID, f.CommitSig, f.ExtraData) } // Decode deserializes the serialized FundingSigned stored in the passed @@ -34,7 +39,7 @@ func (f *FundingSigned) Encode(w io.Writer, pver uint32) error { // // This is part of the lnwire.Message interface. func (f *FundingSigned) Decode(r io.Reader, pver uint32) error { - return ReadElements(r, &f.ChanID, &f.CommitSig) + return ReadElements(r, &f.ChanID, &f.CommitSig, &f.ExtraData) } // MsgType returns the uint32 code which uniquely identifies this message as a @@ -50,6 +55,5 @@ func (f *FundingSigned) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (f *FundingSigned) MaxPayloadLength(uint32) uint32 { - // 32 + 64 - return 96 + return MaxMsgBody } diff --git a/lnwire/gossip_timestamp_range.go b/lnwire/gossip_timestamp_range.go index 3c28cd05..fb62e272 100644 --- a/lnwire/gossip_timestamp_range.go +++ b/lnwire/gossip_timestamp_range.go @@ -24,6 +24,11 @@ type GossipTimestampRange struct { // NOT send any announcements that have a timestamp greater than // FirstTimestamp + TimestampRange. TimestampRange uint32 + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // NewGossipTimestampRange creates a new empty GossipTimestampRange message. @@ -44,6 +49,7 @@ func (g *GossipTimestampRange) Decode(r io.Reader, pver uint32) error { g.ChainHash[:], &g.FirstTimestamp, &g.TimestampRange, + &g.ExtraData, ) } @@ -56,6 +62,7 @@ func (g *GossipTimestampRange) Encode(w io.Writer, pver uint32) error { g.ChainHash[:], g.FirstTimestamp, g.TimestampRange, + g.ExtraData, ) } @@ -73,8 +80,5 @@ func (g *GossipTimestampRange) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (g *GossipTimestampRange) MaxPayloadLength(uint32) uint32 { - // 32 + 4 + 4 - // - // TODO(roasbeef): update to 8 byte timestmaps? - return 40 + return MaxMsgBody } diff --git a/lnwire/init_message.go b/lnwire/init_message.go index e1ddbb01..4e33fbb6 100644 --- a/lnwire/init_message.go +++ b/lnwire/init_message.go @@ -20,6 +20,11 @@ type Init struct { // message, any GlobalFeatures should be merged into the unified // Features field. Features *RawFeatureVector + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // NewInitMessage creates new instance of init message object. @@ -27,6 +32,7 @@ func NewInitMessage(gf *RawFeatureVector, f *RawFeatureVector) *Init { return &Init{ GlobalFeatures: gf, Features: f, + ExtraData: make([]byte, 0), } } @@ -42,6 +48,7 @@ func (msg *Init) Decode(r io.Reader, pver uint32) error { return ReadElements(r, &msg.GlobalFeatures, &msg.Features, + &msg.ExtraData, ) } @@ -53,6 +60,7 @@ func (msg *Init) Encode(w io.Writer, pver uint32) error { return WriteElements(w, msg.GlobalFeatures, msg.Features, + msg.ExtraData, ) } @@ -69,5 +77,5 @@ func (msg *Init) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (msg *Init) MaxPayloadLength(uint32) uint32 { - return 2 + 2 + maxAllowedSize + 2 + maxAllowedSize + return MaxMsgBody } diff --git a/lnwire/lnwire.go b/lnwire/lnwire.go index ca0e449e..c180cad3 100644 --- a/lnwire/lnwire.go +++ b/lnwire/lnwire.go @@ -18,9 +18,16 @@ import ( "github.com/lightningnetwork/lnd/tor" ) -// MaxSliceLength is the maximum allowed length for any opaque byte slices in -// the wire protocol. -const MaxSliceLength = 65535 +const ( + // MaxSliceLength is the maximum allowed length for any opaque byte + // slices in the wire protocol. + MaxSliceLength = 65535 + + // MaxMsgBody is the largest payload any message is allowed to provide. + // This is two less than the MaxSliceLength as each message has a 2 + // byte type that precedes the message body. + MaxMsgBody = 65533 +) // PkScript is simple type definition which represents a raw serialized public // key script. @@ -418,6 +425,10 @@ func WriteElement(w io.Writer, element interface{}) error { if _, err := w.Write(b[:]); err != nil { return err } + + case ExtraOpaqueData: + return e.Encode(w) + default: return fmt.Errorf("unknown type in WriteElement: %T", e) } @@ -824,6 +835,10 @@ func ReadElement(r io.Reader, element interface{}) error { return err } *e = addrBytes[:length] + + case *ExtraOpaqueData: + return e.Decode(r) + default: return fmt.Errorf("unknown type in ReadElement: %T", e) } diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index 02023b02..f9c48d38 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -376,6 +376,15 @@ func TestLightningWireProtocol(t *testing.T) { req.UpfrontShutdownScript = []byte{} } + // 1/2 chance how having more TLV data after the + // shutdown script. + if r.Intn(2) == 0 { + // TLV type 1 of length 2. + req.ExtraData = []byte{1, 2, 0xff, 0xff} + } else { + req.ExtraData = []byte{} + } + v[0] = reflect.ValueOf(req) }, MsgAcceptChannel: func(v []reflect.Value, r *rand.Rand) { @@ -436,11 +445,21 @@ func TestLightningWireProtocol(t *testing.T) { } else { req.UpfrontShutdownScript = []byte{} } + // 1/2 chance how having more TLV data after the + // shutdown script. + if r.Intn(2) == 0 { + // TLV type 1 of length 2. + req.ExtraData = []byte{1, 2, 0xff, 0xff} + } else { + req.ExtraData = []byte{} + } v[0] = reflect.ValueOf(req) }, MsgFundingCreated: func(v []reflect.Value, r *rand.Rand) { - req := FundingCreated{} + req := FundingCreated{ + ExtraData: make([]byte, 0), + } if _, err := r.Read(req.PendingChannelID[:]); err != nil { t.Fatalf("unable to generate pending chan id: %v", err) @@ -471,7 +490,8 @@ func TestLightningWireProtocol(t *testing.T) { } req := FundingSigned{ - ChanID: ChannelID(c), + ChanID: ChannelID(c), + ExtraData: make([]byte, 0), } req.CommitSig, err = NewSigFromSignature(testSig) if err != nil { @@ -502,6 +522,7 @@ func TestLightningWireProtocol(t *testing.T) { MsgClosingSigned: func(v []reflect.Value, r *rand.Rand) { req := ClosingSigned{ FeeSatoshis: btcutil.Amount(r.Int63()), + ExtraData: make([]byte, 0), } var err error req.Signature, err = NewSigFromSignature(testSig) @@ -570,8 +591,9 @@ func TestLightningWireProtocol(t *testing.T) { MsgChannelAnnouncement: func(v []reflect.Value, r *rand.Rand) { var err error req := ChannelAnnouncement{ - ShortChannelID: NewShortChanIDFromInt(uint64(r.Int63())), - Features: randRawFeatureVector(r), + ShortChannelID: NewShortChanIDFromInt(uint64(r.Int63())), + Features: randRawFeatureVector(r), + ExtraOpaqueData: make([]byte, 0), } req.NodeSig1, err = NewSigFromSignature(testSig) if err != nil { @@ -643,6 +665,7 @@ func TestLightningWireProtocol(t *testing.T) { G: uint8(r.Int31()), B: uint8(r.Int31()), }, + ExtraOpaqueData: make([]byte, 0), } req.Signature, err = NewSigFromSignature(testSig) if err != nil { @@ -698,6 +721,7 @@ func TestLightningWireProtocol(t *testing.T) { HtlcMaximumMsat: maxHtlc, BaseFee: uint32(r.Int31()), FeeRate: uint32(r.Int31()), + ExtraOpaqueData: make([]byte, 0), } req.Signature, err = NewSigFromSignature(testSig) if err != nil { @@ -726,7 +750,8 @@ func TestLightningWireProtocol(t *testing.T) { MsgAnnounceSignatures: func(v []reflect.Value, r *rand.Rand) { var err error req := AnnounceSignatures{ - ShortChannelID: NewShortChanIDFromInt(uint64(r.Int63())), + ShortChannelID: NewShortChanIDFromInt(uint64(r.Int63())), + ExtraOpaqueData: make([]byte, 0), } req.NodeSignature, err = NewSigFromSignature(testSig) @@ -763,6 +788,7 @@ func TestLightningWireProtocol(t *testing.T) { req := ChannelReestablish{ NextLocalCommitHeight: uint64(r.Int63()), RemoteCommitTailHeight: uint64(r.Int63()), + ExtraData: make([]byte, 0), } // With a 50/50 probability, we'll include the @@ -785,7 +811,9 @@ func TestLightningWireProtocol(t *testing.T) { v[0] = reflect.ValueOf(req) }, MsgQueryShortChanIDs: func(v []reflect.Value, r *rand.Rand) { - req := QueryShortChanIDs{} + req := QueryShortChanIDs{ + ExtraData: make([]byte, 0), + } // With a 50/50 change, we'll either use zlib encoding, // or regular encoding. @@ -810,10 +838,9 @@ func TestLightningWireProtocol(t *testing.T) { }, MsgReplyChannelRange: func(v []reflect.Value, r *rand.Rand) { req := ReplyChannelRange{ - QueryChannelRange: QueryChannelRange{ - FirstBlockHeight: uint32(r.Int31()), - NumBlocks: uint32(r.Int31()), - }, + FirstBlockHeight: uint32(r.Int31()), + NumBlocks: uint32(r.Int31()), + ExtraData: make([]byte, 0), } if _, err := rand.Read(req.ChainHash[:]); err != nil { diff --git a/lnwire/node_announcement.go b/lnwire/node_announcement.go index 35534352..62414d4b 100644 --- a/lnwire/node_announcement.go +++ b/lnwire/node_announcement.go @@ -5,7 +5,6 @@ import ( "fmt" "image/color" "io" - "io/ioutil" "net" "unicode/utf8" ) @@ -98,7 +97,7 @@ type NodeAnnouncement struct { // properly validate the set of signatures that cover these new fields, // and ensure we're able to make upgrades to the network in a forwards // compatible manner. - ExtraOpaqueData []byte + ExtraOpaqueData ExtraOpaqueData } // A compile time check to ensure NodeAnnouncement implements the @@ -110,7 +109,7 @@ var _ Message = (*NodeAnnouncement)(nil) // // This is part of the lnwire.Message interface. func (a *NodeAnnouncement) Decode(r io.Reader, pver uint32) error { - err := ReadElements(r, + return ReadElements(r, &a.Signature, &a.Features, &a.Timestamp, @@ -118,24 +117,8 @@ func (a *NodeAnnouncement) Decode(r io.Reader, pver uint32) error { &a.RGBColor, &a.Alias, &a.Addresses, + &a.ExtraOpaqueData, ) - if err != nil { - return err - } - - // Now that we've read out all the fields that we explicitly know of, - // we'll collect the remainder into the ExtraOpaqueData field. If there - // aren't any bytes, then we'll snip off the slice to avoid carrying - // around excess capacity. - a.ExtraOpaqueData, err = ioutil.ReadAll(r) - if err != nil { - return err - } - if len(a.ExtraOpaqueData) == 0 { - a.ExtraOpaqueData = nil - } - - return nil } // Encode serializes the target NodeAnnouncement into the passed io.Writer @@ -167,7 +150,7 @@ func (a *NodeAnnouncement) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (a *NodeAnnouncement) MaxPayloadLength(pver uint32) uint32 { - return 65533 + return MaxMsgBody } // DataToSign returns the part of the message that should be signed. diff --git a/lnwire/onion_error_test.go b/lnwire/onion_error_test.go index 3ec147d1..8c4c131c 100644 --- a/lnwire/onion_error_test.go +++ b/lnwire/onion_error_test.go @@ -20,11 +20,12 @@ var ( testOffset = uint16(24) sig, _ = NewSigFromSignature(testSig) testChannelUpdate = ChannelUpdate{ - Signature: sig, - ShortChannelID: NewShortChanIDFromInt(1), - Timestamp: 1, - MessageFlags: 0, - ChannelFlags: 1, + Signature: sig, + ShortChannelID: NewShortChanIDFromInt(1), + Timestamp: 1, + MessageFlags: 0, + ChannelFlags: 1, + ExtraOpaqueData: make([]byte, 0), } ) diff --git a/lnwire/open_channel.go b/lnwire/open_channel.go index a165ef75..70dbe790 100644 --- a/lnwire/open_channel.go +++ b/lnwire/open_channel.go @@ -128,6 +128,17 @@ type OpenChannel struct { // and has a length prefix, so a zero will be written if it is not set // and its length followed by the script will be written if it is set. UpfrontShutdownScript DeliveryAddress + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + // + // NOTE: Since the upfront shutdown script MUST be present (though can + // be zero-length) if any TLV data is available, the script will be + // extracted and removed from this blob when decoding. ExtraData will + // contain all TLV records _except_ the DeliveryAddress record in that + // case. + ExtraData ExtraOpaqueData } // A compile time check to ensure OpenChannel implements the lnwire.Message @@ -140,6 +151,15 @@ var _ Message = (*OpenChannel)(nil) // // This is part of the lnwire.Message interface. func (o *OpenChannel) Encode(w io.Writer, pver uint32) error { + // Since the upfront script is encoded as a TLV record, concatenate it + // with the ExtraData, and write them as one. + tlvRecords, err := packShutdownScript( + o.UpfrontShutdownScript, o.ExtraData, + ) + if err != nil { + return err + } + return WriteElements(w, o.ChainHash[:], o.PendingChannelID[:], @@ -159,7 +179,7 @@ func (o *OpenChannel) Encode(w io.Writer, pver uint32) error { o.HtlcPoint, o.FirstCommitmentPoint, o.ChannelFlags, - o.UpfrontShutdownScript, + tlvRecords, ) } @@ -169,7 +189,8 @@ func (o *OpenChannel) Encode(w io.Writer, pver uint32) error { // // This is part of the lnwire.Message interface. func (o *OpenChannel) Decode(r io.Reader, pver uint32) error { - if err := ReadElements(r, + // Read all the mandatory fields in the open message. + err := ReadElements(r, o.ChainHash[:], o.PendingChannelID[:], &o.FundingAmount, @@ -188,14 +209,23 @@ func (o *OpenChannel) Decode(r io.Reader, pver uint32) error { &o.HtlcPoint, &o.FirstCommitmentPoint, &o.ChannelFlags, - ); err != nil { + ) + if err != nil { return err } - // Check for the optional upfront shutdown script field. If it is not there, - // silence the EOF error. - err := ReadElement(r, &o.UpfrontShutdownScript) - if err != nil && err != io.EOF { + // For backwards compatibility, the optional extra data blob for + // OpenChannel must contain an entry for the upfront shutdown script. + // We'll read it out and attempt to parse it. + var tlvRecords ExtraOpaqueData + if err := ReadElements(r, &tlvRecords); err != nil { + return err + } + + o.UpfrontShutdownScript, o.ExtraData, err = parseShutdownScript( + tlvRecords, + ) + if err != nil { return err } @@ -215,11 +245,5 @@ func (o *OpenChannel) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (o *OpenChannel) MaxPayloadLength(uint32) uint32 { - // (32 * 2) + (8 * 6) + (4 * 1) + (2 * 2) + (33 * 6) + 1 - var length uint32 = 319 // base length - - // Upfront shutdown script max length. - length += 2 + deliveryAddressMaxSize - - return length + return MaxMsgBody } diff --git a/lnwire/query_channel_range.go b/lnwire/query_channel_range.go index 9546fcd3..3bdb30e5 100644 --- a/lnwire/query_channel_range.go +++ b/lnwire/query_channel_range.go @@ -25,6 +25,11 @@ type QueryChannelRange struct { // NumBlocks is the number of blocks beyond the first block that short // channel ID's should be sent for. NumBlocks uint32 + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // NewQueryChannelRange creates a new empty QueryChannelRange message. @@ -45,6 +50,7 @@ func (q *QueryChannelRange) Decode(r io.Reader, pver uint32) error { q.ChainHash[:], &q.FirstBlockHeight, &q.NumBlocks, + &q.ExtraData, ) } @@ -57,6 +63,7 @@ func (q *QueryChannelRange) Encode(w io.Writer, pver uint32) error { q.ChainHash[:], q.FirstBlockHeight, q.NumBlocks, + q.ExtraData, ) } @@ -73,8 +80,7 @@ func (q *QueryChannelRange) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (q *QueryChannelRange) MaxPayloadLength(uint32) uint32 { - // 32 + 4 + 4 - return 40 + return MaxMsgBody } // LastBlockHeight returns the last block height covered by the range of a diff --git a/lnwire/query_short_chan_ids.go b/lnwire/query_short_chan_ids.go index 3c2b9948..bae23d9f 100644 --- a/lnwire/query_short_chan_ids.go +++ b/lnwire/query_short_chan_ids.go @@ -81,6 +81,11 @@ type QueryShortChanIDs struct { // ShortChanIDs is a slice of decoded short channel ID's. ShortChanIDs []ShortChannelID + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData + // noSort indicates whether or not to sort the short channel ids before // writing them out. // @@ -114,8 +119,11 @@ func (q *QueryShortChanIDs) Decode(r io.Reader, pver uint32) error { } q.EncodingType, q.ShortChanIDs, err = decodeShortChanIDs(r) + if err != nil { + return err + } - return err + return q.ExtraData.Decode(r) } // decodeShortChanIDs decodes a set of short channel ID's that have been @@ -292,7 +300,12 @@ func (q *QueryShortChanIDs) Encode(w io.Writer, pver uint32) error { // Base on our encoding type, we'll write out the set of short channel // ID's. - return encodeShortChanIDs(w, q.EncodingType, q.ShortChanIDs, q.noSort) + err = encodeShortChanIDs(w, q.EncodingType, q.ShortChanIDs, q.noSort) + if err != nil { + return err + } + + return q.ExtraData.Encode(w) } // encodeShortChanIDs encodes the passed short channel ID's into the passed @@ -425,5 +438,5 @@ func (q *QueryShortChanIDs) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (q *QueryShortChanIDs) MaxPayloadLength(uint32) uint32 { - return MaxMessagePayload + return MaxMsgBody } diff --git a/lnwire/reply_channel_range.go b/lnwire/reply_channel_range.go index 43060602..5167cc5a 100644 --- a/lnwire/reply_channel_range.go +++ b/lnwire/reply_channel_range.go @@ -1,14 +1,29 @@ package lnwire -import "io" +import ( + "io" + "math" + + "github.com/btcsuite/btcd/chaincfg/chainhash" +) // ReplyChannelRange is the response to the QueryChannelRange message. It // includes the original query, and the next streaming chunk of encoded short // channel ID's as the response. We'll also include a byte that indicates if // this is the last query in the message. type ReplyChannelRange struct { - // QueryChannelRange is the corresponding query to this response. - QueryChannelRange + // ChainHash denotes the target chain that we're trying to synchronize + // channel graph state for. + ChainHash chainhash.Hash + + // FirstBlockHeight is the first block in the query range. The + // responder should send all new short channel IDs from this block + // until this block plus the specified number of blocks. + FirstBlockHeight uint32 + + // NumBlocks is the number of blocks beyond the first block that short + // channel ID's should be sent for. + NumBlocks uint32 // Complete denotes if this is the conclusion of the set of streaming // responses to the original query. @@ -22,6 +37,11 @@ type ReplyChannelRange struct { // ShortChanIDs is a slice of decoded short channel ID's. ShortChanIDs []ShortChannelID + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData + // noSort indicates whether or not to sort the short channel ids before // writing them out. // @@ -43,18 +63,22 @@ var _ Message = (*ReplyChannelRange)(nil) // // This is part of the lnwire.Message interface. func (c *ReplyChannelRange) Decode(r io.Reader, pver uint32) error { - err := c.QueryChannelRange.Decode(r, pver) + err := ReadElements(r, + c.ChainHash[:], + &c.FirstBlockHeight, + &c.NumBlocks, + &c.Complete, + ) if err != nil { return err } - if err := ReadElements(r, &c.Complete); err != nil { + c.EncodingType, c.ShortChanIDs, err = decodeShortChanIDs(r) + if err != nil { return err } - c.EncodingType, c.ShortChanIDs, err = decodeShortChanIDs(r) - - return err + return c.ExtraData.Decode(r) } // Encode serializes the target ReplyChannelRange into the passed io.Writer @@ -62,15 +86,22 @@ func (c *ReplyChannelRange) Decode(r io.Reader, pver uint32) error { // // This is part of the lnwire.Message interface. func (c *ReplyChannelRange) Encode(w io.Writer, pver uint32) error { - if err := c.QueryChannelRange.Encode(w, pver); err != nil { + err := WriteElements(w, + c.ChainHash[:], + c.FirstBlockHeight, + c.NumBlocks, + c.Complete, + ) + if err != nil { return err } - if err := WriteElements(w, c.Complete); err != nil { + err = encodeShortChanIDs(w, c.EncodingType, c.ShortChanIDs, c.noSort) + if err != nil { return err } - return encodeShortChanIDs(w, c.EncodingType, c.ShortChanIDs, c.noSort) + return c.ExtraData.Encode(w) } // MsgType returns the integer uniquely identifying this message type on the @@ -86,5 +117,16 @@ func (c *ReplyChannelRange) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (c *ReplyChannelRange) MaxPayloadLength(uint32) uint32 { - return MaxMessagePayload + return MaxMsgBody +} + +// LastBlockHeight returns the last block height covered by the range of a +// QueryChannelRange message. +func (c *ReplyChannelRange) LastBlockHeight() uint32 { + // Handle overflows by casting to uint64. + lastBlockHeight := uint64(c.FirstBlockHeight) + uint64(c.NumBlocks) - 1 + if lastBlockHeight > math.MaxUint32 { + return math.MaxUint32 + } + return uint32(lastBlockHeight) } diff --git a/lnwire/reply_channel_range_test.go b/lnwire/reply_channel_range_test.go index d2c8df68..ff341495 100644 --- a/lnwire/reply_channel_range_test.go +++ b/lnwire/reply_channel_range_test.go @@ -30,7 +30,7 @@ func TestReplyChannelRangeUnsorted(t *testing.T) { var req2 ReplyChannelRange err = req2.Decode(bytes.NewReader(b.Bytes()), 0) if _, ok := err.(ErrUnsortedSIDs); !ok { - t.Fatalf("expected ErrUnsortedSIDs, got: %T", + t.Fatalf("expected ErrUnsortedSIDs, got: %v", err) } }) @@ -67,13 +67,12 @@ func TestReplyChannelRangeEmpty(t *testing.T) { test := test t.Run(test.name, func(t *testing.T) { req := ReplyChannelRange{ - QueryChannelRange: QueryChannelRange{ - FirstBlockHeight: 1, - NumBlocks: 2, - }, - Complete: 1, - EncodingType: test.encType, - ShortChanIDs: nil, + FirstBlockHeight: 1, + NumBlocks: 2, + Complete: 1, + EncodingType: test.encType, + ShortChanIDs: nil, + ExtraData: make([]byte, 0), } // First decode the hex string in the test case into a diff --git a/lnwire/reply_short_chan_ids_end.go b/lnwire/reply_short_chan_ids_end.go index d77aa0b5..92f1e8fc 100644 --- a/lnwire/reply_short_chan_ids_end.go +++ b/lnwire/reply_short_chan_ids_end.go @@ -22,6 +22,11 @@ type ReplyShortChanIDsEnd struct { // set of short chan ID's in the corresponding QueryShortChanIDs // message. Complete uint8 + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // NewReplyShortChanIDsEnd creates a new empty ReplyShortChanIDsEnd message. @@ -41,6 +46,7 @@ func (c *ReplyShortChanIDsEnd) Decode(r io.Reader, pver uint32) error { return ReadElements(r, c.ChainHash[:], &c.Complete, + &c.ExtraData, ) } @@ -52,6 +58,7 @@ func (c *ReplyShortChanIDsEnd) Encode(w io.Writer, pver uint32) error { return WriteElements(w, c.ChainHash[:], c.Complete, + c.ExtraData, ) } @@ -69,6 +76,5 @@ func (c *ReplyShortChanIDsEnd) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (c *ReplyShortChanIDsEnd) MaxPayloadLength(uint32) uint32 { - // 32 (chain hash) + 1 (complete) - return 33 + return MaxMsgBody } diff --git a/lnwire/revoke_and_ack.go b/lnwire/revoke_and_ack.go index 0cfa2bc2..b187fae6 100644 --- a/lnwire/revoke_and_ack.go +++ b/lnwire/revoke_and_ack.go @@ -30,11 +30,18 @@ type RevokeAndAck struct { // create the proper revocation key used within the commitment // transaction. NextRevocationKey *btcec.PublicKey + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // NewRevokeAndAck creates a new RevokeAndAck message. func NewRevokeAndAck() *RevokeAndAck { - return &RevokeAndAck{} + return &RevokeAndAck{ + ExtraData: make([]byte, 0), + } } // A compile time check to ensure RevokeAndAck implements the lnwire.Message @@ -50,6 +57,7 @@ func (c *RevokeAndAck) Decode(r io.Reader, pver uint32) error { &c.ChanID, c.Revocation[:], &c.NextRevocationKey, + &c.ExtraData, ) } @@ -62,6 +70,7 @@ func (c *RevokeAndAck) Encode(w io.Writer, pver uint32) error { c.ChanID, c.Revocation[:], c.NextRevocationKey, + c.ExtraData, ) } @@ -78,8 +87,7 @@ func (c *RevokeAndAck) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (c *RevokeAndAck) MaxPayloadLength(uint32) uint32 { - // 32 + 32 + 33 - return 97 + return MaxMsgBody } // TargetChanID returns the channel id of the link for which this message is diff --git a/lnwire/shutdown.go b/lnwire/shutdown.go index 94d10a90..8def329c 100644 --- a/lnwire/shutdown.go +++ b/lnwire/shutdown.go @@ -15,22 +15,13 @@ type Shutdown struct { // Address is the script to which the channel funds will be paid. Address DeliveryAddress + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } -// DeliveryAddress is used to communicate the address to which funds from a -// closed channel should be sent. The address can be a p2wsh, p2pkh, p2sh or -// p2wpkh. -type DeliveryAddress []byte - -// deliveryAddressMaxSize is the maximum expected size in bytes of a -// DeliveryAddress based on the types of scripts we know. -// Following are the known scripts and their sizes in bytes. -// - pay to witness script hash: 34 -// - pay to pubkey hash: 25 -// - pay to script hash: 22 -// - pay to witness pubkey hash: 22. -const deliveryAddressMaxSize = 34 - // NewShutdown creates a new Shutdown message. func NewShutdown(cid ChannelID, addr DeliveryAddress) *Shutdown { return &Shutdown{ @@ -48,7 +39,7 @@ var _ Message = (*Shutdown)(nil) // // This is part of the lnwire.Message interface. func (s *Shutdown) Decode(r io.Reader, pver uint32) error { - return ReadElements(r, &s.ChannelID, &s.Address) + return ReadElements(r, &s.ChannelID, &s.Address, &s.ExtraData) } // Encode serializes the target Shutdown into the passed io.Writer observing @@ -56,7 +47,7 @@ func (s *Shutdown) Decode(r io.Reader, pver uint32) error { // // This is part of the lnwire.Message interface. func (s *Shutdown) Encode(w io.Writer, pver uint32) error { - return WriteElements(w, s.ChannelID, s.Address) + return WriteElements(w, s.ChannelID, s.Address, s.ExtraData) } // MsgType returns the integer uniquely identifying this message type on the @@ -72,16 +63,5 @@ func (s *Shutdown) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (s *Shutdown) MaxPayloadLength(pver uint32) uint32 { - var length uint32 - - // ChannelID - 32bytes - length += 32 - - // Len - 2 bytes - length += 2 - - // ScriptPubKey - maximum delivery address size. - length += deliveryAddressMaxSize - - return length + return MaxMsgBody } diff --git a/lnwire/typed_delivery_addr.go b/lnwire/typed_delivery_addr.go new file mode 100644 index 00000000..9ad53b1a --- /dev/null +++ b/lnwire/typed_delivery_addr.go @@ -0,0 +1,41 @@ +package lnwire + +import ( + "github.com/lightningnetwork/lnd/tlv" +) + +const ( + // DeliveryAddrType is the TLV record type for delivery addreses within + // the name space of the OpenChannel and AcceptChannel messages. + DeliveryAddrType = 0 + + // deliveryAddressMaxSize is the maximum expected size in bytes of a + // DeliveryAddress based on the types of scripts we know. + // Following are the known scripts and their sizes in bytes. + // - pay to witness script hash: 34 + // - pay to pubkey hash: 25 + // - pay to script hash: 22 + // - pay to witness pubkey hash: 22. + deliveryAddressMaxSize = 34 +) + +// DeliveryAddress is used to communicate the address to which funds from a +// closed channel should be sent. The address can be a p2wsh, p2pkh, p2sh or +// p2wpkh. +type DeliveryAddress []byte + +// NewRecord returns a TLV record that can be used to encode the delivery +// address within the ExtraData TLV stream. This was intorudced in order to +// allow the OpenChannel/AcceptChannel messages to properly be extended with +// TLV types. +func (d *DeliveryAddress) NewRecord() tlv.Record { + addrBytes := (*[]byte)(d) + + return tlv.MakeDynamicRecord( + DeliveryAddrType, addrBytes, + func() uint64 { + return uint64(len(*addrBytes)) + }, + tlv.EVarBytes, tlv.DVarBytes, + ) +} diff --git a/lnwire/typed_delivery_addr_test.go b/lnwire/typed_delivery_addr_test.go new file mode 100644 index 00000000..d5d9c703 --- /dev/null +++ b/lnwire/typed_delivery_addr_test.go @@ -0,0 +1,37 @@ +package lnwire + +import ( + "bytes" + "testing" +) + +// TestDeliveryAddressEncodeDecode tests that we're able to properly +// encode and decode delivery addresses within TLV streams. +func TestDeliveryAddressEncodeDecode(t *testing.T) { + t.Parallel() + + addr := DeliveryAddress( + bytes.Repeat([]byte("a"), deliveryAddressMaxSize), + ) + + var extraData ExtraOpaqueData + err := extraData.PackRecords(addr.NewRecord()) + if err != nil { + t.Fatal(err) + } + + var addr2 DeliveryAddress + tlvs, err := extraData.ExtractRecords(addr2.NewRecord()) + if err != nil { + t.Fatal(err) + } + + if _, ok := tlvs[DeliveryAddrType]; !ok { + t.Fatalf("DeliveryAddrType not found in records") + } + + if !bytes.Equal(addr, addr2) { + t.Fatalf("addr mismatch: expected %x, got %x", addr[:], + addr2[:]) + } +} diff --git a/lnwire/update_add_htlc.go b/lnwire/update_add_htlc.go index 028c6320..9211d39f 100644 --- a/lnwire/update_add_htlc.go +++ b/lnwire/update_add_htlc.go @@ -52,6 +52,11 @@ type UpdateAddHTLC struct { // should strip off a layer of encryption, exposing the next hop to be // used in the subsequent UpdateAddHTLC message. OnionBlob [OnionPacketSize]byte + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // NewUpdateAddHTLC returns a new empty UpdateAddHTLC message. @@ -75,6 +80,7 @@ func (c *UpdateAddHTLC) Decode(r io.Reader, pver uint32) error { c.PaymentHash[:], &c.Expiry, c.OnionBlob[:], + &c.ExtraData, ) } @@ -90,6 +96,7 @@ func (c *UpdateAddHTLC) Encode(w io.Writer, pver uint32) error { c.PaymentHash[:], c.Expiry, c.OnionBlob[:], + c.ExtraData, ) } @@ -106,8 +113,7 @@ func (c *UpdateAddHTLC) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (c *UpdateAddHTLC) MaxPayloadLength(uint32) uint32 { - // 1450 - return 32 + 8 + 4 + 8 + 32 + 1366 + return MaxMsgBody } // TargetChanID returns the channel id of the link for which this message is diff --git a/lnwire/update_fail_htlc.go b/lnwire/update_fail_htlc.go index 194f2ecd..09666ac2 100644 --- a/lnwire/update_fail_htlc.go +++ b/lnwire/update_fail_htlc.go @@ -26,6 +26,11 @@ type UpdateFailHTLC struct { // failed. This blob is only fully decryptable by the initiator of the // HTLC message. Reason OpaqueReason + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // A compile time check to ensure UpdateFailHTLC implements the lnwire.Message @@ -41,6 +46,7 @@ func (c *UpdateFailHTLC) Decode(r io.Reader, pver uint32) error { &c.ChanID, &c.ID, &c.Reason, + &c.ExtraData, ) } @@ -53,6 +59,7 @@ func (c *UpdateFailHTLC) Encode(w io.Writer, pver uint32) error { c.ChanID, c.ID, c.Reason, + c.ExtraData, ) } @@ -69,21 +76,7 @@ func (c *UpdateFailHTLC) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (c *UpdateFailHTLC) MaxPayloadLength(uint32) uint32 { - var length uint32 - - // Length of the ChanID - length += 32 - - // Length of the ID - length += 8 - - // Length of the length opaque reason - length += 2 - - // Length of the Reason - length += 292 - - return length + return MaxMsgBody } // TargetChanID returns the channel id of the link for which this message is diff --git a/lnwire/update_fail_malformed_htlc.go b/lnwire/update_fail_malformed_htlc.go index 39d4b870..b28ec29f 100644 --- a/lnwire/update_fail_malformed_htlc.go +++ b/lnwire/update_fail_malformed_htlc.go @@ -24,6 +24,11 @@ type UpdateFailMalformedHTLC struct { // FailureCode the exact reason why onion blob haven't been parsed. FailureCode FailCode + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // A compile time check to ensure UpdateFailMalformedHTLC implements the @@ -40,6 +45,7 @@ func (c *UpdateFailMalformedHTLC) Decode(r io.Reader, pver uint32) error { &c.ID, c.ShaOnionBlob[:], &c.FailureCode, + &c.ExtraData, ) } @@ -53,6 +59,7 @@ func (c *UpdateFailMalformedHTLC) Encode(w io.Writer, pver uint32) error { c.ID, c.ShaOnionBlob[:], c.FailureCode, + c.ExtraData, ) } @@ -70,8 +77,7 @@ func (c *UpdateFailMalformedHTLC) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (c *UpdateFailMalformedHTLC) MaxPayloadLength(uint32) uint32 { - // 32 + 8 + 32 + 2 - return 74 + return MaxMsgBody } // TargetChanID returns the channel id of the link for which this message is diff --git a/lnwire/update_fee.go b/lnwire/update_fee.go index 2d27c377..25ab180c 100644 --- a/lnwire/update_fee.go +++ b/lnwire/update_fee.go @@ -16,6 +16,11 @@ type UpdateFee struct { // TODO(halseth): make SatPerKWeight when fee estimation is moved to // own package. Currently this will cause an import cycle. FeePerKw uint32 + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // NewUpdateFee creates a new UpdateFee message. @@ -38,6 +43,7 @@ func (c *UpdateFee) Decode(r io.Reader, pver uint32) error { return ReadElements(r, &c.ChanID, &c.FeePerKw, + &c.ExtraData, ) } @@ -49,6 +55,7 @@ func (c *UpdateFee) Encode(w io.Writer, pver uint32) error { return WriteElements(w, c.ChanID, c.FeePerKw, + c.ExtraData, ) } @@ -65,8 +72,7 @@ func (c *UpdateFee) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (c *UpdateFee) MaxPayloadLength(uint32) uint32 { - // 32 + 4 - return 36 + return MaxMsgBody } // TargetChanID returns the channel id of the link for which this message is diff --git a/lnwire/update_fulfill_htlc.go b/lnwire/update_fulfill_htlc.go index 6c0e6339..36977b1e 100644 --- a/lnwire/update_fulfill_htlc.go +++ b/lnwire/update_fulfill_htlc.go @@ -21,6 +21,11 @@ type UpdateFulfillHTLC struct { // PaymentPreimage is the R-value preimage required to fully settle an // HTLC. PaymentPreimage [32]byte + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData } // NewUpdateFulfillHTLC returns a new empty UpdateFulfillHTLC. @@ -47,6 +52,7 @@ func (c *UpdateFulfillHTLC) Decode(r io.Reader, pver uint32) error { &c.ChanID, &c.ID, c.PaymentPreimage[:], + &c.ExtraData, ) } @@ -59,6 +65,7 @@ func (c *UpdateFulfillHTLC) Encode(w io.Writer, pver uint32) error { c.ChanID, c.ID, c.PaymentPreimage[:], + c.ExtraData, ) } @@ -75,8 +82,7 @@ func (c *UpdateFulfillHTLC) MsgType() MessageType { // // This is part of the lnwire.Message interface. func (c *UpdateFulfillHTLC) MaxPayloadLength(uint32) uint32 { - // 32 + 8 + 32 - return 72 + return MaxMsgBody } // TargetChanID returns the channel id of the link for which this message is