diff --git a/channeldb/channel.go b/channeldb/channel.go index 0153f3d7..5a5c7f76 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -66,6 +66,12 @@ var ( // channel closure. This key should be accessed from within the // sub-bucket of a target channel, identified by its channel point. revocationLogBucket = []byte("revocation-log-key") + + // fwdPackageLogBucket is a bucket that stores the locked-in htlcs after + // having received a revocation from the remote party. The keys in this + // bucket represent the remote height at which these htlcs were + // accepted. + fwdPackageLogBucket = []byte("fwd-package-log-key") ) var ( @@ -86,6 +92,11 @@ var ( // each time we write a new state in order to be properly fault // tolerant. ErrNoPendingCommit = fmt.Errorf("no pending commits found") + + // ErrInvalidCircuitKeyLen signals that a circuit key could not be + // decoded because the byte slice is of an invalid length. + ErrInvalidCircuitKeyLen = fmt.Errorf( + "length of serialized circuit key must be 16 bytes") ) // ChannelType is an enum-like type that describes one of several possible @@ -387,6 +398,11 @@ type OpenChannel struct { // implementation of secret store is shachain store. RevocationStore shachain.Store + // Packager is used to create and update forwarding packages for this + // channel, which encodes all necessary information to recover from + // failures and reforward HTLCs that were not fully processed. + Packager FwdPackager + // TODO(roasbeef): eww Db *DB @@ -615,6 +631,8 @@ func fetchOpenChannel(chanBucket *bolt.Bucket, return nil, fmt.Errorf("unable to fetch chan revocations: %v", err) } + channel.Packager = NewChannelPackager(channel.ShortChanID) + return channel, nil } @@ -837,6 +855,84 @@ type LogUpdate struct { UpdateMsg lnwire.Message } +// Encode writes a log update to the provided io.Writer. +func (l *LogUpdate) Encode(w io.Writer) error { + return writeElements(w, l.LogIndex, l.UpdateMsg) +} + +// Decode reads a log update from the provided io.Reader. +func (l *LogUpdate) Decode(r io.Reader) error { + return readElements(r, &l.LogIndex, &l.UpdateMsg) +} + +// CircuitKey is used by a channel to uniquely identify the HTLCs it receives +// from the switch, and is used to purge our in-memory state of HTLCs that have +// already been processed by a link. Two list of CircuitKeys are included in +// each CommitDiff to allow a link to determine which in-memory htlcs directed +// the opening and closing of circuits in the switch's circuit map. +type CircuitKey struct { + // ChanID is the short chanid indicating the HTLC's origin. + // + // NOTE: It is fine for this value to be blank, as this indicates a + // locally-sourced payment. + ChanID lnwire.ShortChannelID + + // HtlcID is the unique htlc index predominately assigned by links, + // though can also be assigned by switch in the case of locally-sourced + // payments. + HtlcID uint64 +} + +// SetBytes deserializes the given bytes into this CircuitKey. +func (k *CircuitKey) SetBytes(bs []byte) error { + if len(bs) != 16 { + return ErrInvalidCircuitKeyLen + } + + k.ChanID = lnwire.NewShortChanIDFromInt( + binary.BigEndian.Uint64(bs[:8])) + k.HtlcID = binary.BigEndian.Uint64(bs[8:]) + + return nil +} + +// Bytes returns the serialized bytes for this circuit key. +func (k CircuitKey) Bytes() []byte { + var bs = make([]byte, 16) + binary.BigEndian.PutUint64(bs[:8], k.ChanID.ToUint64()) + binary.BigEndian.PutUint64(bs[8:], k.HtlcID) + return bs +} + +// Encode writes a CircuitKey to the provided io.Writer. +func (k *CircuitKey) Encode(w io.Writer) error { + var scratch [16]byte + binary.BigEndian.PutUint64(scratch[:8], k.ChanID.ToUint64()) + binary.BigEndian.PutUint64(scratch[8:], k.HtlcID) + + _, err := w.Write(scratch[:]) + return err +} + +// Decode reads a CircuitKey from the provided io.Reader. +func (k *CircuitKey) Decode(r io.Reader) error { + var scratch [16]byte + + if _, err := io.ReadFull(r, scratch[:]); err != nil { + return err + } + k.ChanID = lnwire.NewShortChanIDFromInt( + binary.BigEndian.Uint64(scratch[:8])) + k.HtlcID = binary.BigEndian.Uint64(scratch[8:]) + + return nil +} + +// String returns a string representation of the CircuitKey. +func (k CircuitKey) String() string { + return fmt.Sprintf("(Chan ID=%s, HTLC ID=%d)", k.ChanID, k.HtlcID) +} + // CommitDiff represents the delta needed to apply the state transition between // two subsequent commitment states. Given state N and state N+1, one is able // to apply the set of messages contained within the CommitDiff to N to arrive @@ -860,6 +956,36 @@ type CommitDiff struct { // within this message should properly cover the new commitment state // and also the HTLC's within the new commitment state. CommitSig *lnwire.CommitSig + + // OpenedCircuitKeys is a set of unique identifiers for any downstream + // Add packets included in this commitment txn. After a restart, this + // set of htlcs is acked from the link's incoming mailbox to ensure + // there isn't an attempt to re-add them to this commitment txn. + OpenedCircuitKeys []CircuitKey + + // ClosedCircuitKeys records the unique identifiers for any settle/fail + // packets that were resolved by this commitment txn. After a restart, + // this is used to ensure those circuits are removed from the circuit + // map, and the downstream packets in the link's mailbox are removed. + ClosedCircuitKeys []CircuitKey + + // AddAcks specifies the locations (commit height, pkg index) of any + // Adds that were failed/settled in this commit diff. This will ack + // entries in *this* channel's forwarding packages. + // + // NOTE: This value is not serialized, it is used to atomically mark the + // resolution of adds, such that they will not be reprocessed after a + // restart. + AddAcks []AddRef + + // SettleFailAcks specifies the locations (chan id, commit height, pkg + // index) of any Settles or Fails that were locked into this commit + // diff, and originate from *another* channel, i.e. the outgoing link. + // + // NOTE: This value is not serialized, it is used to atomically acks + // settles and fails from the forwarding packages of other channels, + // such that they will not be reforwarded internally after a restart. + SettleFailAcks []SettleFailRef } func serializeCommitDiff(w io.Writer, diff *CommitDiff) error { @@ -883,8 +1009,33 @@ func serializeCommitDiff(w io.Writer, diff *CommitDiff) error { } } + numOpenRefs := uint16(len(diff.OpenedCircuitKeys)) + if err := binary.Write(w, byteOrder, numOpenRefs); err != nil { + return err + } + + for _, openRef := range diff.OpenedCircuitKeys { + err := writeElements(w, openRef.ChanID, openRef.HtlcID) + if err != nil { + return err + } + } + + numClosedRefs := uint16(len(diff.ClosedCircuitKeys)) + if err := binary.Write(w, byteOrder, numClosedRefs); err != nil { + return err + } + + for _, closedRef := range diff.ClosedCircuitKeys { + err := writeElements(w, closedRef.ChanID, closedRef.HtlcID) + if err != nil { + return err + } + } + return nil } + func deserializeCommitDiff(r io.Reader) (*CommitDiff, error) { var ( d CommitDiff @@ -916,6 +1067,36 @@ func deserializeCommitDiff(r io.Reader) (*CommitDiff, error) { } } + var numOpenRefs uint16 + if err := binary.Read(r, byteOrder, &numOpenRefs); err != nil { + return nil, err + } + + d.OpenedCircuitKeys = make([]CircuitKey, numOpenRefs) + for i := 0; i < int(numOpenRefs); i++ { + err := readElements(r, + &d.OpenedCircuitKeys[i].ChanID, + &d.OpenedCircuitKeys[i].HtlcID) + if err != nil { + return nil, err + } + } + + var numClosedRefs uint16 + if err := binary.Read(r, byteOrder, &numClosedRefs); err != nil { + return nil, err + } + + d.ClosedCircuitKeys = make([]CircuitKey, numClosedRefs) + for i := 0; i < int(numClosedRefs); i++ { + err := readElements(r, + &d.ClosedCircuitKeys[i].ChanID, + &d.ClosedCircuitKeys[i].HtlcID) + if err != nil { + return nil, err + } + } + return &d, nil } @@ -938,6 +1119,26 @@ func (c *OpenChannel) AppendRemoteCommitChain(diff *CommitDiff) error { return err } + // Any outgoing settles and fails necessarily have a + // corresponding adds in this channel's forwarding packages. + // Mark all of these as being fully processed in our forwarding + // package, which prevents us from reprocessing them after + // startup. + err = c.Packager.AckAddHtlcs(tx, diff.AddAcks...) + if err != nil { + return err + } + + // Additionally, we ack from any fails or settles that are + // persisted in another channel's forwarding package. This + // prevents the same fails and settles from being retransmitted + // after restarts. The actual fail or settle we need to + // propagate to the remote party is now in the commit diff. + err = c.Packager.AckSettleFails(tx, diff.SettleFailAcks...) + if err != nil { + return err + } + // TODO(roasbeef): use seqno to derive key for later LCP // With the bucket retrieved, we'll now serialize the commit @@ -1021,15 +1222,15 @@ func (c *OpenChannel) InsertNextRevocation(revKey *btcec.PublicKey) error { // this log can be consulted in order to reconstruct the state needed to // rectify the situation. This method will add the current commitment for the // remote party to the revocation log, and promote the current pending -// commitment to the current remove commitment. -func (c *OpenChannel) AdvanceCommitChainTail() error { +// commitment to the current remote commitment. +func (c *OpenChannel) AdvanceCommitChainTail(fwdPkg *FwdPkg) error { c.Lock() defer c.Unlock() var newRemoteCommit *ChannelCommitment err := c.Db.Update(func(tx *bolt.Tx) error { - chanBucket, err := readChanBucket(tx, c.IdentityPub, + chanBucket, err := updateChanBucket(tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash) if err != nil { return err @@ -1081,7 +1282,15 @@ func (c *OpenChannel) AdvanceCommitChainTail() error { return err } + // Lastly, we write the forwarding package to disk so that we + // can properly recover from failures and reforward HTLCs that + // have not received a corresponding settle/fail. + if err := c.Packager.AddFwdPkg(tx, fwdPkg); err != nil { + return err + } + newRemoteCommit = &newCommit.Commitment + return nil }) if err != nil { @@ -1096,6 +1305,40 @@ func (c *OpenChannel) AdvanceCommitChainTail() error { return nil } +// LoadFwdPkgs scans the forwarding log for any packages that haven't been +// processed, and returns their deserialized log updates in map indexed by the +// remote commitment height at which the updates were locked in. +func (c *OpenChannel) LoadFwdPkgs() ([]*FwdPkg, error) { + var fwdPkgs []*FwdPkg + if err := c.Db.View(func(tx *bolt.Tx) error { + var err error + fwdPkgs, err = c.Packager.LoadFwdPkgs(tx) + return err + }); err != nil { + return nil, err + } + + return fwdPkgs, nil +} + +// SetFwdFilter atomically sets the forwarding filter for the forwarding package +// identified by `height`. +func (c *OpenChannel) SetFwdFilter(height uint64, fwdFilter *PkgFilter) error { + return c.Db.Update(func(tx *bolt.Tx) error { + return c.Packager.SetFwdFilter(tx, height, fwdFilter) + }) +} + +// RemoveFwdPkg atomically removes a forwarding package specified by the remote +// commitment height. +// +// NOTE: This method should only be called on packages marked FwdStateCompleted. +func (c *OpenChannel) RemoveFwdPkg(height uint64) error { + return c.Db.Update(func(tx *bolt.Tx) error { + return c.Packager.RemovePkg(tx, height) + }) +} + // RevocationLogTail returns the "tail", or the end of the current revocation // log. This entry represents the last previous state for the remote node's // commitment chain. The ChannelDelta returned by this method will always lag @@ -1671,6 +1914,8 @@ func fetchChanInfo(chanBucket *bolt.Bucket, channel *OpenChannel) error { return err } + channel.Packager = NewChannelPackager(channel.ShortChanID) + return nil }