diff --git a/channeldb/migration21/current/current_codec.go b/channeldb/migration21/current/current_codec.go new file mode 100644 index 00000000..0257f40a --- /dev/null +++ b/channeldb/migration21/current/current_codec.go @@ -0,0 +1,420 @@ +package current + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btcutil" + lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21" + "github.com/lightningnetwork/lnd/channeldb/migration21/common" + "github.com/lightningnetwork/lnd/keychain" + "github.com/lightningnetwork/lnd/shachain" +) + +var ( + // Big endian is the preferred byte order, due to cursor scans over + // integer keys iterating in order. + byteOrder = binary.BigEndian +) + +// writeOutpoint writes an outpoint to the passed writer using the minimal +// amount of bytes possible. +func writeOutpoint(w io.Writer, o *wire.OutPoint) error { + if _, err := w.Write(o.Hash[:]); err != nil { + return err + } + if err := binary.Write(w, byteOrder, o.Index); err != nil { + return err + } + + return nil +} + +// readOutpoint reads an outpoint from the passed reader that was previously +// written using the writeOutpoint struct. +func readOutpoint(r io.Reader, o *wire.OutPoint) error { + if _, err := io.ReadFull(r, o.Hash[:]); err != nil { + return err + } + if err := binary.Read(r, byteOrder, &o.Index); err != nil { + return err + } + + return nil +} + +// UnknownElementType is an error returned when the codec is unable to encode or +// decode a particular type. +type UnknownElementType struct { + method string + element interface{} +} + +// NewUnknownElementType creates a new UnknownElementType error from the passed +// method name and element. +func NewUnknownElementType(method string, el interface{}) UnknownElementType { + return UnknownElementType{method: method, element: el} +} + +// Error returns the name of the method that encountered the error, as well as +// the type that was unsupported. +func (e UnknownElementType) Error() string { + return fmt.Sprintf("Unknown type in %s: %T", e.method, e.element) +} + +// WriteElement is a one-stop shop to write the big endian representation of +// any element which is to be serialized for storage on disk. The passed +// io.Writer should be backed by an appropriately sized byte slice, or be able +// to dynamically expand to accommodate additional data. +func WriteElement(w io.Writer, element interface{}) error { + switch e := element.(type) { + case keychain.KeyDescriptor: + if err := binary.Write(w, byteOrder, e.Family); err != nil { + return err + } + if err := binary.Write(w, byteOrder, e.Index); err != nil { + return err + } + + if e.PubKey != nil { + if err := binary.Write(w, byteOrder, true); err != nil { + return fmt.Errorf("error writing serialized element: %s", err) + } + + return WriteElement(w, e.PubKey) + } + + return binary.Write(w, byteOrder, false) + + case chainhash.Hash: + if _, err := w.Write(e[:]); err != nil { + return err + } + + case wire.OutPoint: + return writeOutpoint(w, &e) + + case lnwire.ShortChannelID: + if err := binary.Write(w, byteOrder, e.ToUint64()); err != nil { + return err + } + + case lnwire.ChannelID: + if _, err := w.Write(e[:]); err != nil { + return err + } + + case int64, uint64: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case uint32: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case int32: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case uint16: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case uint8: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case bool: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case btcutil.Amount: + if err := binary.Write(w, byteOrder, uint64(e)); err != nil { + return err + } + + case lnwire.MilliSatoshi: + if err := binary.Write(w, byteOrder, uint64(e)); err != nil { + return err + } + + case *btcec.PrivateKey: + b := e.Serialize() + if _, err := w.Write(b); err != nil { + return err + } + + case *btcec.PublicKey: + b := e.SerializeCompressed() + if _, err := w.Write(b); err != nil { + return err + } + + case shachain.Producer: + return e.Encode(w) + + case shachain.Store: + return e.Encode(w) + + case *wire.MsgTx: + return e.Serialize(w) + + case [32]byte: + if _, err := w.Write(e[:]); err != nil { + return err + } + + case []byte: + if err := wire.WriteVarBytes(w, 0, e); err != nil { + return err + } + + case lnwire.Message: + var msgBuf bytes.Buffer + if _, err := lnwire.WriteMessage(&msgBuf, e, 0); err != nil { + return err + } + + msgLen := uint16(len(msgBuf.Bytes())) + if err := WriteElements(w, msgLen); err != nil { + return err + } + + if _, err := w.Write(msgBuf.Bytes()); err != nil { + return err + } + + case common.ClosureType: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case lnwire.FundingFlag: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + default: + return UnknownElementType{"WriteElement", e} + } + + return nil +} + +// WriteElements is writes each element in the elements slice to the passed +// io.Writer using WriteElement. +func WriteElements(w io.Writer, elements ...interface{}) error { + for _, element := range elements { + err := WriteElement(w, element) + if err != nil { + return err + } + } + return nil +} + +// ReadElement is a one-stop utility function to deserialize any datastructure +// encoded using the serialization format of the database. +func ReadElement(r io.Reader, element interface{}) error { + switch e := element.(type) { + case *keychain.KeyDescriptor: + if err := binary.Read(r, byteOrder, &e.Family); err != nil { + return err + } + if err := binary.Read(r, byteOrder, &e.Index); err != nil { + return err + } + + var hasPubKey bool + if err := binary.Read(r, byteOrder, &hasPubKey); err != nil { + return err + } + + if hasPubKey { + return ReadElement(r, &e.PubKey) + } + + case *chainhash.Hash: + if _, err := io.ReadFull(r, e[:]); err != nil { + return err + } + + case *wire.OutPoint: + return readOutpoint(r, e) + + case *lnwire.ShortChannelID: + var a uint64 + if err := binary.Read(r, byteOrder, &a); err != nil { + return err + } + *e = lnwire.NewShortChanIDFromInt(a) + + case *lnwire.ChannelID: + if _, err := io.ReadFull(r, e[:]); err != nil { + return err + } + + case *int64, *uint64: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *uint32: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *int32: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *uint16: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *uint8: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *bool: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *btcutil.Amount: + var a uint64 + if err := binary.Read(r, byteOrder, &a); err != nil { + return err + } + + *e = btcutil.Amount(a) + + case *lnwire.MilliSatoshi: + var a uint64 + if err := binary.Read(r, byteOrder, &a); err != nil { + return err + } + + *e = lnwire.MilliSatoshi(a) + + case **btcec.PrivateKey: + var b [btcec.PrivKeyBytesLen]byte + if _, err := io.ReadFull(r, b[:]); err != nil { + return err + } + + priv, _ := btcec.PrivKeyFromBytes(btcec.S256(), b[:]) + *e = priv + + case **btcec.PublicKey: + var b [btcec.PubKeyBytesLenCompressed]byte + if _, err := io.ReadFull(r, b[:]); err != nil { + return err + } + + pubKey, err := btcec.ParsePubKey(b[:], btcec.S256()) + if err != nil { + return err + } + *e = pubKey + + case *shachain.Producer: + var root [32]byte + if _, err := io.ReadFull(r, root[:]); err != nil { + return err + } + + // TODO(roasbeef): remove + producer, err := shachain.NewRevocationProducerFromBytes(root[:]) + if err != nil { + return err + } + + *e = producer + + case *shachain.Store: + store, err := shachain.NewRevocationStoreFromBytes(r) + if err != nil { + return err + } + + *e = store + + case **wire.MsgTx: + tx := wire.NewMsgTx(2) + if err := tx.Deserialize(r); err != nil { + return err + } + + *e = tx + + case *[32]byte: + if _, err := io.ReadFull(r, e[:]); err != nil { + return err + } + + case *[]byte: + bytes, err := wire.ReadVarBytes(r, 0, 66000, "[]byte") + if err != nil { + return err + } + + *e = bytes + + case *lnwire.Message: + var msgLen uint16 + if err := ReadElement(r, &msgLen); err != nil { + return err + } + + msgReader := io.LimitReader(r, int64(msgLen)) + msg, err := lnwire.ReadMessage(msgReader, 0) + if err != nil { + return err + } + + *e = msg + + case *common.ClosureType: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *lnwire.FundingFlag: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + default: + return UnknownElementType{"ReadElement", e} + } + + return nil +} + +// ReadElements deserializes a variable number of elements into the passed +// io.Reader, with each element being deserialized according to the ReadElement +// function. +func ReadElements(r io.Reader, elements ...interface{}) error { + for _, element := range elements { + err := ReadElement(r, element) + if err != nil { + return err + } + } + return nil +} diff --git a/channeldb/migration21/current/current_encoding.go b/channeldb/migration21/current/current_encoding.go new file mode 100644 index 00000000..3aaf4415 --- /dev/null +++ b/channeldb/migration21/current/current_encoding.go @@ -0,0 +1,728 @@ +package current + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + + "github.com/lightningnetwork/lnd/channeldb/kvdb" + lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21" + "github.com/lightningnetwork/lnd/channeldb/migration21/common" +) + +func serializeChanCommit(w io.Writer, c *common.ChannelCommitment) error { // nolint: dupl + if err := WriteElements(w, + c.CommitHeight, c.LocalLogIndex, c.LocalHtlcIndex, + c.RemoteLogIndex, c.RemoteHtlcIndex, c.LocalBalance, + c.RemoteBalance, c.CommitFee, c.FeePerKw, c.CommitTx, + c.CommitSig, + ); err != nil { + return err + } + + return serializeHtlcs(w, c.Htlcs...) +} + +func SerializeLogUpdates(w io.Writer, logUpdates []common.LogUpdate) error { // nolint: dupl + numUpdates := uint16(len(logUpdates)) + if err := binary.Write(w, byteOrder, numUpdates); err != nil { + return err + } + + for _, diff := range logUpdates { + err := WriteElements(w, diff.LogIndex, diff.UpdateMsg) + if err != nil { + return err + } + } + + return nil +} + +func serializeHtlcs(b io.Writer, htlcs ...common.HTLC) error { // nolint: dupl + numHtlcs := uint16(len(htlcs)) + if err := WriteElement(b, numHtlcs); err != nil { + return err + } + + for _, htlc := range htlcs { + if err := WriteElements(b, + htlc.Signature, htlc.RHash, htlc.Amt, htlc.RefundTimeout, + htlc.OutputIndex, htlc.Incoming, htlc.OnionBlob, + htlc.HtlcIndex, htlc.LogIndex, + ); err != nil { + return err + } + } + + return nil +} + +func SerializeCommitDiff(w io.Writer, diff *common.CommitDiff) error { // nolint: dupl + if err := serializeChanCommit(w, &diff.Commitment); err != nil { + return err + } + + if err := WriteElements(w, diff.CommitSig); err != nil { + return err + } + + if err := SerializeLogUpdates(w, diff.LogUpdates); err != nil { + return err + } + + numOpenRefs := uint16(len(diff.OpenedCircuitKeys)) + if err := binary.Write(w, byteOrder, numOpenRefs); err != nil { + return err + } + + for _, openRef := range diff.OpenedCircuitKeys { + err := WriteElements(w, openRef.ChanID, openRef.HtlcID) + if err != nil { + return err + } + } + + numClosedRefs := uint16(len(diff.ClosedCircuitKeys)) + if err := binary.Write(w, byteOrder, numClosedRefs); err != nil { + return err + } + + for _, closedRef := range diff.ClosedCircuitKeys { + err := WriteElements(w, closedRef.ChanID, closedRef.HtlcID) + if err != nil { + return err + } + } + + return nil +} + +func deserializeHtlcs(r io.Reader) ([]common.HTLC, error) { // nolint: dupl + var numHtlcs uint16 + if err := ReadElement(r, &numHtlcs); err != nil { + return nil, err + } + + var htlcs []common.HTLC + if numHtlcs == 0 { + return htlcs, nil + } + + htlcs = make([]common.HTLC, numHtlcs) + for i := uint16(0); i < numHtlcs; i++ { + if err := ReadElements(r, + &htlcs[i].Signature, &htlcs[i].RHash, &htlcs[i].Amt, + &htlcs[i].RefundTimeout, &htlcs[i].OutputIndex, + &htlcs[i].Incoming, &htlcs[i].OnionBlob, + &htlcs[i].HtlcIndex, &htlcs[i].LogIndex, + ); err != nil { + return htlcs, err + } + } + + return htlcs, nil +} + +func deserializeChanCommit(r io.Reader) (common.ChannelCommitment, error) { // nolint: dupl + var c common.ChannelCommitment + + err := ReadElements(r, + &c.CommitHeight, &c.LocalLogIndex, &c.LocalHtlcIndex, &c.RemoteLogIndex, + &c.RemoteHtlcIndex, &c.LocalBalance, &c.RemoteBalance, + &c.CommitFee, &c.FeePerKw, &c.CommitTx, &c.CommitSig, + ) + if err != nil { + return c, err + } + + c.Htlcs, err = deserializeHtlcs(r) + if err != nil { + return c, err + } + + return c, nil +} + +func DeserializeLogUpdates(r io.Reader) ([]common.LogUpdate, error) { // nolint: dupl + var numUpdates uint16 + if err := binary.Read(r, byteOrder, &numUpdates); err != nil { + return nil, err + } + + logUpdates := make([]common.LogUpdate, numUpdates) + for i := 0; i < int(numUpdates); i++ { + err := ReadElements(r, + &logUpdates[i].LogIndex, &logUpdates[i].UpdateMsg, + ) + if err != nil { + return nil, err + } + } + return logUpdates, nil +} + +func DeserializeCommitDiff(r io.Reader) (*common.CommitDiff, error) { // nolint: dupl + var ( + d common.CommitDiff + err error + ) + + d.Commitment, err = deserializeChanCommit(r) + if err != nil { + return nil, err + } + + var msg lnwire.Message + if err := ReadElements(r, &msg); err != nil { + return nil, err + } + commitSig, ok := msg.(*lnwire.CommitSig) + if !ok { + return nil, fmt.Errorf("expected lnwire.CommitSig, instead "+ + "read: %T", msg) + } + d.CommitSig = commitSig + + d.LogUpdates, err = DeserializeLogUpdates(r) + if err != nil { + return nil, err + } + + var numOpenRefs uint16 + if err := binary.Read(r, byteOrder, &numOpenRefs); err != nil { + return nil, err + } + + d.OpenedCircuitKeys = make([]common.CircuitKey, numOpenRefs) + for i := 0; i < int(numOpenRefs); i++ { + err := ReadElements(r, + &d.OpenedCircuitKeys[i].ChanID, + &d.OpenedCircuitKeys[i].HtlcID) + if err != nil { + return nil, err + } + } + + var numClosedRefs uint16 + if err := binary.Read(r, byteOrder, &numClosedRefs); err != nil { + return nil, err + } + + d.ClosedCircuitKeys = make([]common.CircuitKey, numClosedRefs) + for i := 0; i < int(numClosedRefs); i++ { + err := ReadElements(r, + &d.ClosedCircuitKeys[i].ChanID, + &d.ClosedCircuitKeys[i].HtlcID) + if err != nil { + return nil, err + } + } + + return &d, nil +} + +func SerializeNetworkResult(w io.Writer, n *common.NetworkResult) error { // nolint: dupl + return WriteElements(w, n.Msg, n.Unencrypted, n.IsResolution) +} + +func DeserializeNetworkResult(r io.Reader) (*common.NetworkResult, error) { // nolint: dupl + n := &common.NetworkResult{} + + if err := ReadElements(r, + &n.Msg, &n.Unencrypted, &n.IsResolution, + ); err != nil { + return nil, err + } + + return n, nil +} + +func writeChanConfig(b io.Writer, c *common.ChannelConfig) error { // nolint: dupl + return WriteElements(b, + c.DustLimit, c.MaxPendingAmount, c.ChanReserve, c.MinHTLC, + c.MaxAcceptedHtlcs, c.CsvDelay, c.MultiSigKey, + c.RevocationBasePoint, c.PaymentBasePoint, c.DelayBasePoint, + c.HtlcBasePoint, + ) +} + +func SerializeChannelCloseSummary(w io.Writer, cs *common.ChannelCloseSummary) error { // nolint: dupl + err := WriteElements(w, + cs.ChanPoint, cs.ShortChanID, cs.ChainHash, cs.ClosingTXID, + cs.CloseHeight, cs.RemotePub, cs.Capacity, cs.SettledBalance, + cs.TimeLockedBalance, cs.CloseType, cs.IsPending, + ) + if err != nil { + return err + } + + // If this is a close channel summary created before the addition of + // the new fields, then we can exit here. + if cs.RemoteCurrentRevocation == nil { + return WriteElements(w, false) + } + + // If fields are present, write boolean to indicate this, and continue. + if err := WriteElements(w, true); err != nil { + return err + } + + if err := WriteElements(w, cs.RemoteCurrentRevocation); err != nil { + return err + } + + if err := writeChanConfig(w, &cs.LocalChanConfig); err != nil { + return err + } + + // The RemoteNextRevocation field is optional, as it's possible for a + // channel to be closed before we learn of the next unrevoked + // revocation point for the remote party. Write a boolen indicating + // whether this field is present or not. + if err := WriteElements(w, cs.RemoteNextRevocation != nil); err != nil { + return err + } + + // Write the field, if present. + if cs.RemoteNextRevocation != nil { + if err = WriteElements(w, cs.RemoteNextRevocation); err != nil { + return err + } + } + + // Write whether the channel sync message is present. + if err := WriteElements(w, cs.LastChanSyncMsg != nil); err != nil { + return err + } + + // Write the channel sync message, if present. + if cs.LastChanSyncMsg != nil { + if err := WriteElements(w, cs.LastChanSyncMsg); err != nil { + return err + } + } + + return nil +} + +func readChanConfig(b io.Reader, c *common.ChannelConfig) error { + return ReadElements(b, + &c.DustLimit, &c.MaxPendingAmount, &c.ChanReserve, + &c.MinHTLC, &c.MaxAcceptedHtlcs, &c.CsvDelay, + &c.MultiSigKey, &c.RevocationBasePoint, + &c.PaymentBasePoint, &c.DelayBasePoint, + &c.HtlcBasePoint, + ) +} + +func DeserializeCloseChannelSummary(r io.Reader) (*common.ChannelCloseSummary, error) { // nolint: dupl + c := &common.ChannelCloseSummary{} + + err := ReadElements(r, + &c.ChanPoint, &c.ShortChanID, &c.ChainHash, &c.ClosingTXID, + &c.CloseHeight, &c.RemotePub, &c.Capacity, &c.SettledBalance, + &c.TimeLockedBalance, &c.CloseType, &c.IsPending, + ) + if err != nil { + return nil, err + } + + // We'll now check to see if the channel close summary was encoded with + // any of the additional optional fields. + var hasNewFields bool + err = ReadElements(r, &hasNewFields) + if err != nil { + return nil, err + } + + // If fields are not present, we can return. + if !hasNewFields { + return c, nil + } + + // Otherwise read the new fields. + if err := ReadElements(r, &c.RemoteCurrentRevocation); err != nil { + return nil, err + } + + if err := readChanConfig(r, &c.LocalChanConfig); err != nil { + return nil, err + } + + // Finally, we'll attempt to read the next unrevoked commitment point + // for the remote party. If we closed the channel before receiving a + // funding locked message then this might not be present. A boolean + // indicating whether the field is present will come first. + var hasRemoteNextRevocation bool + err = ReadElements(r, &hasRemoteNextRevocation) + if err != nil { + return nil, err + } + + // If this field was written, read it. + if hasRemoteNextRevocation { + err = ReadElements(r, &c.RemoteNextRevocation) + if err != nil { + return nil, err + } + } + + // Check if we have a channel sync message to read. + var hasChanSyncMsg bool + err = ReadElements(r, &hasChanSyncMsg) + if err == io.EOF { + return c, nil + } else if err != nil { + return nil, err + } + + // If a chan sync message is present, read it. + if hasChanSyncMsg { + // We must pass in reference to a lnwire.Message for the codec + // to support it. + var msg lnwire.Message + if err := ReadElements(r, &msg); err != nil { + return nil, err + } + + chanSync, ok := msg.(*lnwire.ChannelReestablish) + if !ok { + return nil, errors.New("unable cast db Message to " + + "ChannelReestablish") + } + c.LastChanSyncMsg = chanSync + } + + return c, nil +} + +// ErrCorruptedFwdPkg signals that the on-disk structure of the forwarding +// package has potentially been mangled. +var ErrCorruptedFwdPkg = errors.New("fwding package db has been corrupted") + +var ( + // fwdPackagesKey is the root-level bucket that all forwarding packages + // are written. This bucket is further subdivided based on the short + // channel ID of each channel. + fwdPackagesKey = []byte("fwd-packages") + + // addBucketKey is the bucket to which all Add log updates are written. + addBucketKey = []byte("add-updates") + + // failSettleBucketKey is the bucket to which all Settle/Fail log + // updates are written. + failSettleBucketKey = []byte("fail-settle-updates") + + // fwdFilterKey is a key used to write the set of Adds that passed + // validation and are to be forwarded to the switch. + // NOTE: The presence of this key within a forwarding package indicates + // that the package has reached FwdStateProcessed. + fwdFilterKey = []byte("fwd-filter-key") + + // ackFilterKey is a key used to access the PkgFilter indicating which + // Adds have received a Settle/Fail. This response may come from a + // number of sources, including: exitHop settle/fails, switch failures, + // chain arbiter interjections, as well as settle/fails from the + // next hop in the route. + ackFilterKey = []byte("ack-filter-key") + + // settleFailFilterKey is a key used to access the PkgFilter indicating + // which Settles/Fails in have been received and processed by the link + // that originally received the Add. + settleFailFilterKey = []byte("settle-fail-filter-key") +) + +func makeLogKey(updateNum uint64) [8]byte { + var key [8]byte + byteOrder.PutUint64(key[:], updateNum) + return key +} + +// uint16Key writes the provided 16-bit unsigned integer to a 2-byte slice. +func uint16Key(i uint16) []byte { + key := make([]byte, 2) + byteOrder.PutUint16(key, i) + return key +} + +// ChannelPackager is used by a channel to manage the lifecycle of its forwarding +// packages. The packager is tied to a particular source channel ID, allowing it +// to create and edit its own packages. Each packager also has the ability to +// remove fail/settle htlcs that correspond to an add contained in one of +// source's packages. +type ChannelPackager struct { + source lnwire.ShortChannelID +} + +// NewChannelPackager creates a new packager for a single channel. +func NewChannelPackager(source lnwire.ShortChannelID) *ChannelPackager { + return &ChannelPackager{ + source: source, + } +} + +// AddFwdPkg writes a newly locked in forwarding package to disk. +func (*ChannelPackager) AddFwdPkg(tx kvdb.RwTx, fwdPkg *common.FwdPkg) error { // nolint: dupl + fwdPkgBkt, err := tx.CreateTopLevelBucket(fwdPackagesKey) + if err != nil { + return err + } + + source := makeLogKey(fwdPkg.Source.ToUint64()) + sourceBkt, err := fwdPkgBkt.CreateBucketIfNotExists(source[:]) + if err != nil { + return err + } + + heightKey := makeLogKey(fwdPkg.Height) + heightBkt, err := sourceBkt.CreateBucketIfNotExists(heightKey[:]) + if err != nil { + return err + } + + // Write ADD updates we received at this commit height. + addBkt, err := heightBkt.CreateBucketIfNotExists(addBucketKey) + if err != nil { + return err + } + + // Write SETTLE/FAIL updates we received at this commit height. + failSettleBkt, err := heightBkt.CreateBucketIfNotExists(failSettleBucketKey) + if err != nil { + return err + } + + for i := range fwdPkg.Adds { + err = putLogUpdate(addBkt, uint16(i), &fwdPkg.Adds[i]) + if err != nil { + return err + } + } + + // Persist the initialized pkg filter, which will be used to determine + // when we can remove this forwarding package from disk. + var ackFilterBuf bytes.Buffer + if err := fwdPkg.AckFilter.Encode(&ackFilterBuf); err != nil { + return err + } + + if err := heightBkt.Put(ackFilterKey, ackFilterBuf.Bytes()); err != nil { + return err + } + + for i := range fwdPkg.SettleFails { + err = putLogUpdate(failSettleBkt, uint16(i), &fwdPkg.SettleFails[i]) + if err != nil { + return err + } + } + + var settleFailFilterBuf bytes.Buffer + err = fwdPkg.SettleFailFilter.Encode(&settleFailFilterBuf) + if err != nil { + return err + } + + return heightBkt.Put(settleFailFilterKey, settleFailFilterBuf.Bytes()) +} + +// putLogUpdate writes an htlc to the provided `bkt`, using `index` as the key. +func putLogUpdate(bkt kvdb.RwBucket, idx uint16, htlc *common.LogUpdate) error { + var b bytes.Buffer + if err := serializeLogUpdate(&b, htlc); err != nil { + return err + } + + return bkt.Put(uint16Key(idx), b.Bytes()) +} + +// LoadFwdPkgs scans the forwarding log for any packages that haven't been +// processed, and returns their deserialized log updates in a map indexed by the +// remote commitment height at which the updates were locked in. +func (p *ChannelPackager) LoadFwdPkgs(tx kvdb.RTx) ([]*common.FwdPkg, error) { + return loadChannelFwdPkgs(tx, p.source) +} + +// loadChannelFwdPkgs loads all forwarding packages owned by `source`. +func loadChannelFwdPkgs(tx kvdb.RTx, source lnwire.ShortChannelID) ([]*common.FwdPkg, error) { // nolint: dupl + fwdPkgBkt := tx.ReadBucket(fwdPackagesKey) + if fwdPkgBkt == nil { + return nil, nil + } + + sourceKey := makeLogKey(source.ToUint64()) + sourceBkt := fwdPkgBkt.NestedReadBucket(sourceKey[:]) + if sourceBkt == nil { + return nil, nil + } + + var heights []uint64 + if err := sourceBkt.ForEach(func(k, _ []byte) error { + if len(k) != 8 { + return ErrCorruptedFwdPkg + } + + heights = append(heights, byteOrder.Uint64(k)) + + return nil + }); err != nil { + return nil, err + } + + // Load the forwarding package for each retrieved height. + fwdPkgs := make([]*common.FwdPkg, 0, len(heights)) + for _, height := range heights { + fwdPkg, err := loadFwdPkg(fwdPkgBkt, source, height) + if err != nil { + return nil, err + } + + fwdPkgs = append(fwdPkgs, fwdPkg) + } + + return fwdPkgs, nil +} + +// loadFwdPkg reads the packager's fwd pkg at a given height, and determines the +// appropriate FwdState. +func loadFwdPkg(fwdPkgBkt kvdb.RBucket, source lnwire.ShortChannelID, + height uint64) (*common.FwdPkg, error) { + + sourceKey := makeLogKey(source.ToUint64()) + sourceBkt := fwdPkgBkt.NestedReadBucket(sourceKey[:]) + if sourceBkt == nil { + return nil, ErrCorruptedFwdPkg + } + + heightKey := makeLogKey(height) + heightBkt := sourceBkt.NestedReadBucket(heightKey[:]) + if heightBkt == nil { + return nil, ErrCorruptedFwdPkg + } + + // Load ADDs from disk. + addBkt := heightBkt.NestedReadBucket(addBucketKey) + if addBkt == nil { + return nil, ErrCorruptedFwdPkg + } + + adds, err := loadHtlcs(addBkt) + if err != nil { + return nil, err + } + + // Load ack filter from disk. + ackFilterBytes := heightBkt.Get(ackFilterKey) + if ackFilterBytes == nil { + return nil, ErrCorruptedFwdPkg + } + ackFilterReader := bytes.NewReader(ackFilterBytes) + + ackFilter := &common.PkgFilter{} + if err := ackFilter.Decode(ackFilterReader); err != nil { + return nil, err + } + + // Load SETTLE/FAILs from disk. + failSettleBkt := heightBkt.NestedReadBucket(failSettleBucketKey) + if failSettleBkt == nil { + return nil, ErrCorruptedFwdPkg + } + + failSettles, err := loadHtlcs(failSettleBkt) + if err != nil { + return nil, err + } + + // Load settle fail filter from disk. + settleFailFilterBytes := heightBkt.Get(settleFailFilterKey) + if settleFailFilterBytes == nil { + return nil, ErrCorruptedFwdPkg + } + settleFailFilterReader := bytes.NewReader(settleFailFilterBytes) + + settleFailFilter := &common.PkgFilter{} + if err := settleFailFilter.Decode(settleFailFilterReader); err != nil { + return nil, err + } + + // Initialize the fwding package, which always starts in the + // FwdStateLockedIn. We can determine what state the package was left in + // by examining constraints on the information loaded from disk. + fwdPkg := &common.FwdPkg{ + Source: source, + State: common.FwdStateLockedIn, + Height: height, + Adds: adds, + AckFilter: ackFilter, + SettleFails: failSettles, + SettleFailFilter: settleFailFilter, + } + + // Check to see if we have written the set exported filter adds to + // disk. If we haven't, processing of this package was never started, or + // failed during the last attempt. + fwdFilterBytes := heightBkt.Get(fwdFilterKey) + if fwdFilterBytes == nil { + nAdds := uint16(len(adds)) + fwdPkg.FwdFilter = common.NewPkgFilter(nAdds) + return fwdPkg, nil + } + + fwdFilterReader := bytes.NewReader(fwdFilterBytes) + fwdPkg.FwdFilter = &common.PkgFilter{} + if err := fwdPkg.FwdFilter.Decode(fwdFilterReader); err != nil { + return nil, err + } + + // Otherwise, a complete round of processing was completed, and we + // advance the package to FwdStateProcessed. + fwdPkg.State = common.FwdStateProcessed + + // If every add, settle, and fail has been fully acknowledged, we can + // safely set the package's state to FwdStateCompleted, signalling that + // it can be garbage collected. + if fwdPkg.AckFilter.IsFull() && fwdPkg.SettleFailFilter.IsFull() { + fwdPkg.State = common.FwdStateCompleted + } + + return fwdPkg, nil +} + +// loadHtlcs retrieves all serialized htlcs in a bucket, returning +// them in order of the indexes they were written under. +func loadHtlcs(bkt kvdb.RBucket) ([]common.LogUpdate, error) { + var htlcs []common.LogUpdate + if err := bkt.ForEach(func(_, v []byte) error { + htlc, err := deserializeLogUpdate(bytes.NewReader(v)) + if err != nil { + return err + } + + htlcs = append(htlcs, *htlc) + + return nil + }); err != nil { + return nil, err + } + + return htlcs, nil +} + +// serializeLogUpdate writes a log update to the provided io.Writer. +func serializeLogUpdate(w io.Writer, l *common.LogUpdate) error { + return WriteElements(w, l.LogIndex, l.UpdateMsg) +} + +// deserializeLogUpdate reads a log update from the provided io.Reader. +func deserializeLogUpdate(r io.Reader) (*common.LogUpdate, error) { + l := &common.LogUpdate{} + if err := ReadElements(r, &l.LogIndex, &l.UpdateMsg); err != nil { + return nil, err + } + + return l, nil +} diff --git a/channeldb/migration21/legacy/legacy_decoding.go b/channeldb/migration21/legacy/legacy_decoding.go index 96954cb3..923cceda 100644 --- a/channeldb/migration21/legacy/legacy_decoding.go +++ b/channeldb/migration21/legacy/legacy_decoding.go @@ -556,7 +556,7 @@ func (p *ChannelPackager) LoadFwdPkgs(tx kvdb.RTx) ([]*common.FwdPkg, error) { } // loadChannelFwdPkgs loads all forwarding packages owned by `source`. -func loadChannelFwdPkgs(tx kvdb.RTx, source lnwire.ShortChannelID) ([]*common.FwdPkg, error) { +func loadChannelFwdPkgs(tx kvdb.RTx, source lnwire.ShortChannelID) ([]*common.FwdPkg, error) { // nolint: dupl fwdPkgBkt := tx.ReadBucket(fwdPackagesKey) if fwdPkgBkt == nil { return nil, nil