diff --git a/htlcswitch/circuit_map.go b/htlcswitch/circuit_map.go index 01108d50..e5076544 100644 --- a/htlcswitch/circuit_map.go +++ b/htlcswitch/circuit_map.go @@ -144,11 +144,7 @@ var ( // always identifiable by their incoming CircuitKey, in addition to their // outgoing CircuitKey if the circuit is fully-opened. type circuitMap struct { - // db provides the persistent storage engine for the circuit map. - // - // TODO(conner): create abstraction to allow for the substitution of - // other persistence engines. - db *channeldb.DB + cfg *CircuitMapConfig mtx sync.RWMutex @@ -172,10 +168,23 @@ type circuitMap struct { hashIndex map[[32]byte]map[CircuitKey]struct{} } +// CircuitMapConfig houses the critical interfaces and references necessary to +// parameterize an instance of circuitMap. +type CircuitMapConfig struct { + // DB provides the persistent storage engine for the circuit map. + // TODO(conner): create abstraction to allow for the substitution of + // other persistence engines. + DB *channeldb.DB + + // ExtractErrorEncrypter derives the shared secret used to encrypt + // errors from the obfuscator's ephemeral public key. + ExtractErrorEncrypter ErrorEncrypterExtracter +} + // NewCircuitMap creates a new instance of the circuitMap. -func NewCircuitMap(db *channeldb.DB) (CircuitMap, error) { +func NewCircuitMap(cfg *CircuitMapConfig) (CircuitMap, error) { cm := &circuitMap{ - db: db, + cfg: cfg, } // Initialize the on-disk buckets used by the circuit map. @@ -203,7 +212,7 @@ func NewCircuitMap(db *channeldb.DB) (CircuitMap, error) { // initBuckets ensures that the primary buckets used by the circuit are // initialized so that we can assume their existence after startup. func (cm *circuitMap) initBuckets() error { - return cm.db.Update(func(tx *bolt.Tx) error { + return cm.cfg.DB.Update(func(tx *bolt.Tx) error { _, err := tx.CreateBucketIfNotExists(circuitKeystoneKey) if err != nil { return err @@ -226,7 +235,7 @@ func (cm *circuitMap) restoreMemState() error { pending = make(map[CircuitKey]*PaymentCircuit) ) - if err := cm.db.View(func(tx *bolt.Tx) error { + if err := cm.cfg.DB.View(func(tx *bolt.Tx) error { // Restore any of the circuits persisted in the circuit bucket // back into memory. circuitBkt := tx.Bucket(circuitAddKey) @@ -235,7 +244,7 @@ func (cm *circuitMap) restoreMemState() error { } if err := circuitBkt.ForEach(func(_, v []byte) error { - circuit, err := decodeCircuit(v) + circuit, err := cm.decodeCircuit(v) if err != nil { return err } @@ -305,8 +314,9 @@ func (cm *circuitMap) restoreMemState() error { // decodeCircuit reconstructs an in-memory payment circuit from a byte slice. // The byte slice is assumed to have been generated by the circuit's Encode -// method. -func decodeCircuit(v []byte) (*PaymentCircuit, error) { +// method. If the decoding is successful, the onion obfuscator will be +// reextracted, since it is not stored in plaintext on disk. +func (cm *circuitMap) decodeCircuit(v []byte) (*PaymentCircuit, error) { var circuit = &PaymentCircuit{} circuitReader := bytes.NewReader(v) @@ -314,6 +324,21 @@ func decodeCircuit(v []byte) (*PaymentCircuit, error) { return nil, err } + // If the error encrypter is nil, this is locally-source payment so + // there is no encrypter. + if circuit.ErrorEncrypter == nil { + return circuit, nil + } + + // Otherwise, we need to reextract the encrypter, so that the shared + // secret is rederived from what was decoded. + err := circuit.ErrorEncrypter.Reextract( + cm.cfg.ExtractErrorEncrypter, + ) + if err != nil { + return nil, err + } + return circuit, nil } @@ -325,7 +350,7 @@ func decodeCircuit(v []byte) (*PaymentCircuit, error) { // channels. Therefore, it must be called before any links are created to avoid // interfering with normal operation. func (cm *circuitMap) trimAllOpenCircuits() error { - activeChannels, err := cm.db.FetchAllChannels() + activeChannels, err := cm.cfg.DB.FetchAllChannels() if err != nil { return err } @@ -385,7 +410,7 @@ func (cm *circuitMap) TrimOpenCircuits(chanID lnwire.ShortChannelID, return nil } - return cm.db.Update(func(tx *bolt.Tx) error { + return cm.cfg.DB.Update(func(tx *bolt.Tx) error { keystoneBkt := tx.Bucket(circuitKeystoneKey) if keystoneBkt == nil { return ErrCorruptedCircuitMap @@ -533,7 +558,7 @@ func (cm *circuitMap) CommitCircuits(circuits ...*PaymentCircuit) ( // Write the entire batch of circuits to the persistent circuit bucket // using bolt's Batch write. This method must be called from multiple, // distinct goroutines to have any impact on performance. - err := cm.db.Batch(func(tx *bolt.Tx) error { + err := cm.cfg.DB.Batch(func(tx *bolt.Tx) error { circuitBkt := tx.Bucket(circuitAddKey) if circuitBkt == nil { return ErrCorruptedCircuitMap @@ -623,7 +648,7 @@ func (cm *circuitMap) OpenCircuits(keystones ...Keystone) error { } cm.mtx.RUnlock() - err := cm.db.Update(func(tx *bolt.Tx) error { + err := cm.cfg.DB.Update(func(tx *bolt.Tx) error { // Now, load the circuit bucket to which we will write the // already serialized circuit. keystoneBkt := tx.Bucket(circuitKeystoneKey) @@ -769,7 +794,7 @@ func (cm *circuitMap) DeleteCircuits(inKeys ...CircuitKey) error { } cm.mtx.Unlock() - err := cm.db.Batch(func(tx *bolt.Tx) error { + err := cm.cfg.DB.Batch(func(tx *bolt.Tx) error { for _, circuit := range removedCircuits { // If this htlc made it to an outgoing link, load the // keystone bucket from which we will remove the