diff --git a/htlcswitch/circuit.go b/htlcswitch/circuit.go index 06bce167..8db22deb 100644 --- a/htlcswitch/circuit.go +++ b/htlcswitch/circuit.go @@ -1,184 +1,229 @@ package htlcswitch import ( - "fmt" - "sync" + "encoding/binary" + "io" - "github.com/go-errors/errors" + "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/lnwire" ) -// PaymentCircuit is used by the HTLC switch subsystem to determine the -// backwards path for the settle/fail HTLC messages. A payment circuit -// will be created once a channel link forwards the HTLC add request and -// removed when we receive a settle/fail HTLC message. +// EmptyCircuitKey is a default value for an outgoing circuit key returned when +// a circuit's keystone has not been set. Note that this value is invalid for +// use as a keystone, since the outgoing channel id can never be equal to +// sourceHop. +var EmptyCircuitKey CircuitKey + +// CircuitKey is a tuple of channel ID and HTLC ID, used to uniquely identify +// HTLCs in a circuit. Circuits are identified primarily by the circuit key of +// the incoming HTLC. However, a circuit may also be referenced by its outgoing +// circuit key after the HTLC has been forwarded via the outgoing link. +type CircuitKey = channeldb.CircuitKey + +// PaymentCircuit is used by the switch as placeholder between when the +// switch makes a forwarding decision and the outgoing link determines the +// proper HTLC ID for the local log. After the outgoing HTLC ID has been +// determined, the half circuit will be converted into a full PaymentCircuit. type PaymentCircuit struct { + // AddRef is the forward reference of the Add update in the incoming + // link's forwarding package. This value is set on the htlcPacket of the + // returned settle/fail so that it can be removed from disk. + AddRef channeldb.AddRef + + // Incoming is the circuit key identifying the incoming channel and htlc + // index from which this ADD originates. + Incoming CircuitKey + + // Outgoing is the circuit key identifying the outgoing channel, and the + // HTLC index that was used to forward the ADD. It will be nil if this + // circuit's keystone has not been set. + Outgoing *CircuitKey + // PaymentHash used as unique identifier of payment. PaymentHash [32]byte - // IncomingChanID identifies the channel from which add HTLC request - // came and to which settle/fail HTLC request will be returned back. - // Once the switch forwards the settle/fail message to the src the - // circuit is considered to be completed. - IncomingChanID lnwire.ShortChannelID + // IncomingAmount is the value of the HTLC from the incoming link. + IncomingAmount lnwire.MilliSatoshi - // IncomingHTLCID is the ID in the update_add_htlc message we received - // from the incoming channel, which will be included in any settle/fail - // messages we send back. - IncomingHTLCID uint64 - - // IncomingAmt is the value of the incoming HTLC. If we take this and - // subtract it from the OutgoingAmt, then we'll compute the total fee - // attached to this payment circuit. - IncomingAmt lnwire.MilliSatoshi - - // OutgoingChanID identifies the channel to which we propagate the HTLC - // add update and from which we are expecting to receive HTLC - // settle/fail request back. - OutgoingChanID lnwire.ShortChannelID - - // OutgoingHTLCID is the ID in the update_add_htlc message we sent to - // the outgoing channel. - OutgoingHTLCID uint64 - - // OutgoingAmt is the value of the outgoing HTLC. If we subtract this - // from the IncomingAmt, then we'll compute the total fee attached to - // this payment circuit. - OutgoingAmt lnwire.MilliSatoshi + // OutgoingAmount specifies the value of the HTLC leaving the switch, + // either as a payment or forwarded amount. + OutgoingAmount lnwire.MilliSatoshi // ErrorEncrypter is used to re-encrypt the onion failure before // sending it back to the originator of the payment. ErrorEncrypter ErrorEncrypter + + // LoadedFromDisk is set true for any circuits loaded after the circuit + // map is reloaded from disk. + // + // NOTE: This value is determined implicitly during a restart. It is not + // persisted, and should never be set outside the circuit map. + LoadedFromDisk bool } -// circuitKey is a channel ID, HTLC ID tuple used as an identifying key for a -// payment circuit. The circuit map is keyed with the identifier for the -// outgoing HTLC -type circuitKey struct { - chanID lnwire.ShortChannelID - htlcID uint64 +// HasKeystone returns true if an outgoing link has assigned this circuit's +// outgoing circuit key. +func (c *PaymentCircuit) HasKeystone() bool { + return c.Outgoing != nil } -// String returns a string representation of the circuitKey. -func (k *circuitKey) String() string { - return fmt.Sprintf("(Chan ID=%s, HTLC ID=%d)", k.chanID, k.htlcID) -} +// newPaymentCircuit initializes a payment circuit on the heap using the payment +// hash and an in-memory htlc packet. +func newPaymentCircuit(hash *[32]byte, pkt *htlcPacket) *PaymentCircuit { + var addRef channeldb.AddRef + if pkt.sourceRef != nil { + addRef = *pkt.sourceRef + } -// CircuitMap is a data structure that implements thread safe storage of -// circuit routing information. The switch consults a circuit map to determine -// where to forward HTLC update messages. Each circuit is stored with its -// outgoing HTLC as the primary key because, each offered HTLC has at most one -// received HTLC, but there may be multiple offered or received HTLCs with the -// same payment hash. Circuits are also indexed to provide fast lookups by -// payment hash. -// -// TODO(andrew.shvv) make it persistent -type CircuitMap struct { - mtx sync.RWMutex - circuits map[circuitKey]*PaymentCircuit - hashIndex map[[32]byte]map[PaymentCircuit]struct{} -} - -// NewCircuitMap creates a new instance of the CircuitMap. -func NewCircuitMap() *CircuitMap { - return &CircuitMap{ - circuits: make(map[circuitKey]*PaymentCircuit), - hashIndex: make(map[[32]byte]map[PaymentCircuit]struct{}), + return &PaymentCircuit{ + AddRef: addRef, + Incoming: CircuitKey{ + ChanID: pkt.incomingChanID, + HtlcID: pkt.incomingHTLCID, + }, + PaymentHash: *hash, + IncomingAmount: pkt.incomingAmount, + OutgoingAmount: pkt.amount, + ErrorEncrypter: pkt.obfuscator, } } -// LookupByHTLC looks up the payment circuit by the outgoing channel and HTLC -// IDs. Returns nil if there is no such circuit. -func (cm *CircuitMap) LookupByHTLC(chanID lnwire.ShortChannelID, htlcID uint64) *PaymentCircuit { - cm.mtx.RLock() - - key := circuitKey{ - chanID: chanID, - htlcID: htlcID, +// makePaymentCircuit initalizes a payment circuit on the stack using the +// payment hash and an in-memory htlc packet. +func makePaymentCircuit(hash *[32]byte, pkt *htlcPacket) PaymentCircuit { + var addRef channeldb.AddRef + if pkt.sourceRef != nil { + addRef = *pkt.sourceRef } - circuit := cm.circuits[key] - cm.mtx.RUnlock() - return circuit + return PaymentCircuit{ + AddRef: addRef, + Incoming: CircuitKey{ + ChanID: pkt.incomingChanID, + HtlcID: pkt.incomingHTLCID, + }, + PaymentHash: *hash, + IncomingAmount: pkt.incomingAmount, + OutgoingAmount: pkt.amount, + ErrorEncrypter: pkt.obfuscator, + } } -// LookupByPaymentHash looks up and returns any payment circuits with a given -// payment hash. -func (cm *CircuitMap) LookupByPaymentHash(hash [32]byte) []*PaymentCircuit { - cm.mtx.RLock() - - var circuits []*PaymentCircuit - if circuitSet, ok := cm.hashIndex[hash]; ok { - circuits = make([]*PaymentCircuit, 0, len(circuitSet)) - for circuit := range circuitSet { - circuits = append(circuits, &circuit) - } +// Encode writes a PaymentCircuit to the provided io.Writer. +func (c *PaymentCircuit) Encode(w io.Writer) error { + if err := c.AddRef.Encode(w); err != nil { + return err } - cm.mtx.RUnlock() - return circuits + if err := c.Incoming.Encode(w); err != nil { + return err + } + + if _, err := w.Write(c.PaymentHash[:]); err != nil { + return err + } + + var scratch [8]byte + + binary.BigEndian.PutUint64(scratch[:], uint64(c.IncomingAmount)) + if _, err := w.Write(scratch[:]); err != nil { + return err + } + + binary.BigEndian.PutUint64(scratch[:], uint64(c.OutgoingAmount)) + if _, err := w.Write(scratch[:]); err != nil { + return err + } + + // Defaults to EncrypterTypeNone. + var encrypterType EncrypterType + if c.ErrorEncrypter != nil { + encrypterType = c.ErrorEncrypter.Type() + } + + err := binary.Write(w, binary.BigEndian, encrypterType) + if err != nil { + return err + } + + // Skip encoding of error encrypter if this half add does not have one. + if encrypterType == EncrypterTypeNone { + return nil + } + + return c.ErrorEncrypter.Encode(w) } -// Add adds a new active payment circuit to the CircuitMap. -func (cm *CircuitMap) Add(circuit *PaymentCircuit) error { - cm.mtx.Lock() - - key := circuitKey{ - chanID: circuit.OutgoingChanID, - htlcID: circuit.OutgoingHTLCID, +// Decode reads a PaymentCircuit from the provided io.Reader. +func (c *PaymentCircuit) Decode(r io.Reader) error { + if err := c.AddRef.Decode(r); err != nil { + return err } - cm.circuits[key] = circuit - // Add circuit to the hash index. - if _, ok := cm.hashIndex[circuit.PaymentHash]; !ok { - cm.hashIndex[circuit.PaymentHash] = make(map[PaymentCircuit]struct{}) + if err := c.Incoming.Decode(r); err != nil { + return err } - cm.hashIndex[circuit.PaymentHash][*circuit] = struct{}{} - cm.mtx.Unlock() - return nil + if _, err := io.ReadFull(r, c.PaymentHash[:]); err != nil { + return err + } + + var scratch [8]byte + + if _, err := io.ReadFull(r, scratch[:]); err != nil { + return err + } + c.IncomingAmount = lnwire.MilliSatoshi( + binary.BigEndian.Uint64(scratch[:])) + + if _, err := io.ReadFull(r, scratch[:]); err != nil { + return err + } + c.OutgoingAmount = lnwire.MilliSatoshi( + binary.BigEndian.Uint64(scratch[:])) + + // Read the encrypter type used for this circuit. + var encrypterType EncrypterType + err := binary.Read(r, binary.BigEndian, &encrypterType) + if err != nil { + return err + } + + switch encrypterType { + case EncrypterTypeNone: + // No encrypter was provided, such as when the payment is + // locally initiated. + return nil + + case EncrypterTypeSphinx: + // Sphinx encrypter was used as this is a forwarded HTLC. + c.ErrorEncrypter = NewSphinxErrorEncrypter() + + case EncrypterTypeMock: + // Test encrypter. + c.ErrorEncrypter = NewMockObfuscator() + + default: + return UnknownEncrypterType(encrypterType) + } + + return c.ErrorEncrypter.Decode(r) } -// Remove destroys the target circuit by removing it from the circuit map. -func (cm *CircuitMap) Remove(chanID lnwire.ShortChannelID, htlcID uint64) error { - cm.mtx.Lock() - defer cm.mtx.Unlock() - - // Look up circuit so that pointer can be matched in the hash index. - key := circuitKey{ - chanID: chanID, - htlcID: htlcID, - } - circuit, found := cm.circuits[key] - if !found { - return errors.Errorf("Can't find circuit for HTLC %v", key) - } - delete(cm.circuits, key) - - // Remove circuit from hash index. - circuitsWithHash, ok := cm.hashIndex[circuit.PaymentHash] - if !ok { - return errors.Errorf("Can't find circuit in hash index for HTLC %v", - key) - } - - if _, ok = circuitsWithHash[*circuit]; !ok { - return errors.Errorf("Can't find circuit in hash index for HTLC %v", - key) - } - - delete(circuitsWithHash, *circuit) - if len(circuitsWithHash) == 0 { - delete(cm.hashIndex, circuit.PaymentHash) - } - return nil +// InKey returns the primary identifier for the circuit corresponding to the +// incoming HTLC. +func (c *PaymentCircuit) InKey() CircuitKey { + return c.Incoming } -// pending returns number of circuits which are waiting for to be completed -// (settle/fail responses to be received). -func (cm *CircuitMap) pending() int { - cm.mtx.RLock() - count := len(cm.circuits) - cm.mtx.RUnlock() - return count +// OutKey returns the keystone identifying the outgoing link and HTLC ID. If the +// circuit hasn't been completed, this method returns an EmptyKeystone, which is +// an invalid outgoing circuit key. Only call this method if HasKeystone returns +// true. +func (c *PaymentCircuit) OutKey() CircuitKey { + if c.Outgoing != nil { + return *c.Outgoing + } + + return EmptyCircuitKey }