Merge pull request #826 from cfromknecht/reextract-circuit-encrypters

Reextract Circuit Error Encrypters
This commit is contained in:
Olaoluwa Osuntokun 2018-03-13 18:10:12 -07:00 committed by GitHub
commit 0fb7804e4a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 320 additions and 137 deletions

@ -144,11 +144,7 @@ var (
// always identifiable by their incoming CircuitKey, in addition to their // always identifiable by their incoming CircuitKey, in addition to their
// outgoing CircuitKey if the circuit is fully-opened. // outgoing CircuitKey if the circuit is fully-opened.
type circuitMap struct { type circuitMap struct {
// db provides the persistent storage engine for the circuit map. cfg *CircuitMapConfig
//
// TODO(conner): create abstraction to allow for the substitution of
// other persistence engines.
db *channeldb.DB
mtx sync.RWMutex mtx sync.RWMutex
@ -172,10 +168,23 @@ type circuitMap struct {
hashIndex map[[32]byte]map[CircuitKey]struct{} hashIndex map[[32]byte]map[CircuitKey]struct{}
} }
// CircuitMapConfig houses the critical interfaces and references necessary to
// parameterize an instance of circuitMap.
type CircuitMapConfig struct {
// DB provides the persistent storage engine for the circuit map.
// TODO(conner): create abstraction to allow for the substitution of
// other persistence engines.
DB *channeldb.DB
// ExtractErrorEncrypter derives the shared secret used to encrypt
// errors from the obfuscator's ephemeral public key.
ExtractErrorEncrypter ErrorEncrypterExtracter
}
// NewCircuitMap creates a new instance of the circuitMap. // NewCircuitMap creates a new instance of the circuitMap.
func NewCircuitMap(db *channeldb.DB) (CircuitMap, error) { func NewCircuitMap(cfg *CircuitMapConfig) (CircuitMap, error) {
cm := &circuitMap{ cm := &circuitMap{
db: db, cfg: cfg,
} }
// Initialize the on-disk buckets used by the circuit map. // Initialize the on-disk buckets used by the circuit map.
@ -203,7 +212,7 @@ func NewCircuitMap(db *channeldb.DB) (CircuitMap, error) {
// initBuckets ensures that the primary buckets used by the circuit are // initBuckets ensures that the primary buckets used by the circuit are
// initialized so that we can assume their existence after startup. // initialized so that we can assume their existence after startup.
func (cm *circuitMap) initBuckets() error { func (cm *circuitMap) initBuckets() error {
return cm.db.Update(func(tx *bolt.Tx) error { return cm.cfg.DB.Update(func(tx *bolt.Tx) error {
_, err := tx.CreateBucketIfNotExists(circuitKeystoneKey) _, err := tx.CreateBucketIfNotExists(circuitKeystoneKey)
if err != nil { if err != nil {
return err return err
@ -226,7 +235,7 @@ func (cm *circuitMap) restoreMemState() error {
pending = make(map[CircuitKey]*PaymentCircuit) pending = make(map[CircuitKey]*PaymentCircuit)
) )
if err := cm.db.View(func(tx *bolt.Tx) error { if err := cm.cfg.DB.View(func(tx *bolt.Tx) error {
// Restore any of the circuits persisted in the circuit bucket // Restore any of the circuits persisted in the circuit bucket
// back into memory. // back into memory.
circuitBkt := tx.Bucket(circuitAddKey) circuitBkt := tx.Bucket(circuitAddKey)
@ -235,7 +244,7 @@ func (cm *circuitMap) restoreMemState() error {
} }
if err := circuitBkt.ForEach(func(_, v []byte) error { if err := circuitBkt.ForEach(func(_, v []byte) error {
circuit, err := decodeCircuit(v) circuit, err := cm.decodeCircuit(v)
if err != nil { if err != nil {
return err return err
} }
@ -305,8 +314,9 @@ func (cm *circuitMap) restoreMemState() error {
// decodeCircuit reconstructs an in-memory payment circuit from a byte slice. // decodeCircuit reconstructs an in-memory payment circuit from a byte slice.
// The byte slice is assumed to have been generated by the circuit's Encode // The byte slice is assumed to have been generated by the circuit's Encode
// method. // method. If the decoding is successful, the onion obfuscator will be
func decodeCircuit(v []byte) (*PaymentCircuit, error) { // reextracted, since it is not stored in plaintext on disk.
func (cm *circuitMap) decodeCircuit(v []byte) (*PaymentCircuit, error) {
var circuit = &PaymentCircuit{} var circuit = &PaymentCircuit{}
circuitReader := bytes.NewReader(v) circuitReader := bytes.NewReader(v)
@ -314,6 +324,21 @@ func decodeCircuit(v []byte) (*PaymentCircuit, error) {
return nil, err return nil, err
} }
// If the error encrypter is nil, this is locally-source payment so
// there is no encrypter.
if circuit.ErrorEncrypter == nil {
return circuit, nil
}
// Otherwise, we need to reextract the encrypter, so that the shared
// secret is rederived from what was decoded.
err := circuit.ErrorEncrypter.Reextract(
cm.cfg.ExtractErrorEncrypter,
)
if err != nil {
return nil, err
}
return circuit, nil return circuit, nil
} }
@ -325,7 +350,7 @@ func decodeCircuit(v []byte) (*PaymentCircuit, error) {
// channels. Therefore, it must be called before any links are created to avoid // channels. Therefore, it must be called before any links are created to avoid
// interfering with normal operation. // interfering with normal operation.
func (cm *circuitMap) trimAllOpenCircuits() error { func (cm *circuitMap) trimAllOpenCircuits() error {
activeChannels, err := cm.db.FetchAllChannels() activeChannels, err := cm.cfg.DB.FetchAllChannels()
if err != nil { if err != nil {
return err return err
} }
@ -385,7 +410,7 @@ func (cm *circuitMap) TrimOpenCircuits(chanID lnwire.ShortChannelID,
return nil return nil
} }
return cm.db.Update(func(tx *bolt.Tx) error { return cm.cfg.DB.Update(func(tx *bolt.Tx) error {
keystoneBkt := tx.Bucket(circuitKeystoneKey) keystoneBkt := tx.Bucket(circuitKeystoneKey)
if keystoneBkt == nil { if keystoneBkt == nil {
return ErrCorruptedCircuitMap return ErrCorruptedCircuitMap
@ -533,7 +558,7 @@ func (cm *circuitMap) CommitCircuits(circuits ...*PaymentCircuit) (
// Write the entire batch of circuits to the persistent circuit bucket // Write the entire batch of circuits to the persistent circuit bucket
// using bolt's Batch write. This method must be called from multiple, // using bolt's Batch write. This method must be called from multiple,
// distinct goroutines to have any impact on performance. // distinct goroutines to have any impact on performance.
err := cm.db.Batch(func(tx *bolt.Tx) error { err := cm.cfg.DB.Batch(func(tx *bolt.Tx) error {
circuitBkt := tx.Bucket(circuitAddKey) circuitBkt := tx.Bucket(circuitAddKey)
if circuitBkt == nil { if circuitBkt == nil {
return ErrCorruptedCircuitMap return ErrCorruptedCircuitMap
@ -623,7 +648,7 @@ func (cm *circuitMap) OpenCircuits(keystones ...Keystone) error {
} }
cm.mtx.RUnlock() cm.mtx.RUnlock()
err := cm.db.Update(func(tx *bolt.Tx) error { err := cm.cfg.DB.Update(func(tx *bolt.Tx) error {
// Now, load the circuit bucket to which we will write the // Now, load the circuit bucket to which we will write the
// already serialized circuit. // already serialized circuit.
keystoneBkt := tx.Bucket(circuitKeystoneKey) keystoneBkt := tx.Bucket(circuitKeystoneKey)
@ -769,7 +794,7 @@ func (cm *circuitMap) DeleteCircuits(inKeys ...CircuitKey) error {
} }
cm.mtx.Unlock() cm.mtx.Unlock()
err := cm.db.Batch(func(tx *bolt.Tx) error { err := cm.cfg.DB.Batch(func(tx *bolt.Tx) error {
for _, circuit := range removedCircuits { for _, circuit := range removedCircuits {
// If this htlc made it to an outgoing link, load the // If this htlc made it to an outgoing link, load the
// keystone bucket from which we will remove the // keystone bucket from which we will remove the

@ -6,9 +6,12 @@ import (
"reflect" "reflect"
"testing" "testing"
"github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/htlcswitch"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/roasbeef/btcd/btcec"
bitcoinCfg "github.com/roasbeef/btcd/chaincfg"
"github.com/roasbeef/btcutil" "github.com/roasbeef/btcutil"
) )
@ -16,19 +19,114 @@ var (
hash1 = [32]byte{0x01} hash1 = [32]byte{0x01}
hash2 = [32]byte{0x02} hash2 = [32]byte{0x02}
hash3 = [32]byte{0x03} hash3 = [32]byte{0x03}
// sphinxPrivKey is the private key given to freshly created sphinx
// routers.
sphinxPrivKey *btcec.PrivateKey
// testEphemeralKey is the ephemeral key that will be extracted to
// create onion obfuscators.
testEphemeralKey *btcec.PublicKey
// testExtracter is a precomputed extraction of testEphemeralKey, using
// the sphinxPrivKey.
testExtracter *htlcswitch.SphinxErrorEncrypter
) )
func TestCircuitMapInit(t *testing.T) { func init() {
t.Parallel() // Generate a fresh key for our sphinx router.
var err error
sphinxPrivKey, err = btcec.NewPrivateKey(btcec.S256())
if err != nil {
panic(err)
}
// Initialize new database for circuit map. // And another, whose public key will serve as the test ephemeral key.
cdb := makeCircuitDB(t, "") testEphemeralPriv, err := btcec.NewPrivateKey(btcec.S256())
_, err := htlcswitch.NewCircuitMap(cdb) if err != nil {
panic(err)
}
testEphemeralKey = testEphemeralPriv.PubKey()
// Finally, properly initialize the test extracter
initTestExtracter()
}
// initTestExtracter spins up a new onion processor specifically for the purpose
// of generating our testExtracter, which should be derived from the
// testEphemeralKey, and which randomly-generated key is used to init the sphinx
// router.
//
// NOTE: This should be called in init(), after testEphemeralKey has been
// properly initialized.
func initTestExtracter() {
onionProcessor := newOnionProcessor(nil)
defer onionProcessor.Stop()
obfuscator, _ := onionProcessor.ExtractErrorEncrypter(
testEphemeralKey,
)
sphinxExtracter, ok := obfuscator.(*htlcswitch.SphinxErrorEncrypter)
if !ok {
panic("did not extract sphinx error encrypter")
}
testExtracter = sphinxExtracter
// We also set this error extracter on startup, otherwise it will be nil
// at compile-time.
halfCircuitTests[2].encrypter = testExtracter
}
// newOnionProcessor creates starts a new htlcswitch.OnionProcessor using a temp
// db and no garbage collection.
func newOnionProcessor(t *testing.T) *htlcswitch.OnionProcessor {
sharedSecretFile, err := ioutil.TempFile("", "sphinxreplay.db")
if err != nil {
t.Fatalf("unable to create temp path: %v", err)
}
sharedSecretPath := sharedSecretFile.Name()
sphinxRouter := sphinx.NewRouter(
sharedSecretPath, sphinxPrivKey, &bitcoinCfg.SimNetParams, nil,
)
if err := sphinxRouter.Start(); err != nil {
t.Fatalf("unable to start sphinx router: %v", err)
}
return htlcswitch.NewOnionProcessor(sphinxRouter)
}
// newCircuitMap creates a new htlcswitch.CircuitMap using a temp db and a
// fresh sphinx router.
func newCircuitMap(t *testing.T) (*htlcswitch.CircuitMapConfig,
htlcswitch.CircuitMap) {
onionProcessor := newOnionProcessor(t)
circuitMapCfg := &htlcswitch.CircuitMapConfig{
DB: makeCircuitDB(t, ""),
ExtractErrorEncrypter: onionProcessor.ExtractErrorEncrypter,
}
circuitMap, err := htlcswitch.NewCircuitMap(circuitMapCfg)
if err != nil { if err != nil {
t.Fatalf("unable to create persistent circuit map: %v", err) t.Fatalf("unable to create persistent circuit map: %v", err)
} }
restartCircuitMap(t, cdb) return circuitMapCfg, circuitMap
}
// TestCircuitMapInit is a quick check to ensure that we can start and restore
// the circuit map, as this will be used extensively in this suite.
func TestCircuitMapInit(t *testing.T) {
t.Parallel()
cfg, _ := newCircuitMap(t)
restartCircuitMap(t, cfg)
} }
var halfCircuitTests = []struct { var halfCircuitTests = []struct {
@ -56,12 +154,15 @@ var halfCircuitTests = []struct {
encrypter: htlcswitch.NewMockObfuscator(), encrypter: htlcswitch.NewMockObfuscator(),
}, },
{ {
hash: hash3, hash: hash3,
inValue: 10000, inValue: 10000,
outValue: 9000, outValue: 9000,
chanID: lnwire.NewShortChanIDFromInt(3), chanID: lnwire.NewShortChanIDFromInt(3),
htlcID: 3, htlcID: 3,
encrypter: htlcswitch.NewSphinxErrorEncrypter(), // NOTE: The value of testExtracter is nil at compile-time, it
// is fully-initialized in initTestExtracter, which should
// repopulate this encrypter.
encrypter: testExtracter,
}, },
} }
@ -72,6 +173,8 @@ var halfCircuitTests = []struct {
func TestHalfCircuitSerialization(t *testing.T) { func TestHalfCircuitSerialization(t *testing.T) {
t.Parallel() t.Parallel()
onionProcessor := newOnionProcessor(t)
for i, test := range halfCircuitTests { for i, test := range halfCircuitTests {
circuit := &htlcswitch.PaymentCircuit{ circuit := &htlcswitch.PaymentCircuit{
PaymentHash: test.hash, PaymentHash: test.hash,
@ -97,6 +200,20 @@ func TestHalfCircuitSerialization(t *testing.T) {
t.Fatalf("unable to decode half payment circuit test=%d: %v", i, err) t.Fatalf("unable to decode half payment circuit test=%d: %v", i, err)
} }
// If the error encrypter is initialized, we will need to
// reextract it from it's decoded state, as this requires an
// ECDH with the onion processor's private key. For mock error
// encrypters, this will be a NOP.
if circuit2.ErrorEncrypter != nil {
err := circuit2.ErrorEncrypter.Reextract(
onionProcessor.ExtractErrorEncrypter,
)
if err != nil {
t.Fatalf("unable to reextract sphinx error "+
"encrypter: %v", err)
}
}
// Reconstructed half circuit should match the original. // Reconstructed half circuit should match the original.
if !equalIgnoreLFD(circuit, &circuit2) { if !equalIgnoreLFD(circuit, &circuit2) {
t.Fatalf("unexpected half circuit test=%d, want %v, got %v", t.Fatalf("unexpected half circuit test=%d, want %v, got %v",
@ -115,11 +232,7 @@ func TestCircuitMapPersistence(t *testing.T) {
err error err error
) )
cdb := makeCircuitDB(t, "") cfg, circuitMap := newCircuitMap(t)
circuitMap, err = htlcswitch.NewCircuitMap(cdb)
if err != nil {
t.Fatalf("unable to create persistent circuit map: %v", err)
}
circuit := circuitMap.LookupCircuit(htlcswitch.CircuitKey{chan1, 0}) circuit := circuitMap.LookupCircuit(htlcswitch.CircuitKey{chan1, 0})
if circuit != nil { if circuit != nil {
@ -143,7 +256,7 @@ func TestCircuitMapPersistence(t *testing.T) {
assertNumCircuitsWithHash(t, circuitMap, hash1, 0) assertNumCircuitsWithHash(t, circuitMap, hash1, 0)
assertHasCircuit(t, circuitMap, circuit1) assertHasCircuit(t, circuitMap, circuit1)
cdb, circuitMap = restartCircuitMap(t, cdb) cfg, circuitMap = restartCircuitMap(t, cfg)
assertNumCircuitsWithHash(t, circuitMap, hash1, 0) assertNumCircuitsWithHash(t, circuitMap, hash1, 0)
assertHasCircuit(t, circuitMap, circuit1) assertHasCircuit(t, circuitMap, circuit1)
@ -168,7 +281,7 @@ func TestCircuitMapPersistence(t *testing.T) {
assertHasCircuit(t, circuitMap, circuit1) assertHasCircuit(t, circuitMap, circuit1)
assertHasKeystone(t, circuitMap, keystone1.OutKey, circuit1) assertHasKeystone(t, circuitMap, keystone1.OutKey, circuit1)
cdb, circuitMap = restartCircuitMap(t, cdb) cfg, circuitMap = restartCircuitMap(t, cfg)
assertNumCircuitsWithHash(t, circuitMap, hash1, 1) assertNumCircuitsWithHash(t, circuitMap, hash1, 1)
assertHasCircuit(t, circuitMap, circuit1) assertHasCircuit(t, circuitMap, circuit1)
@ -213,7 +326,7 @@ func TestCircuitMapPersistence(t *testing.T) {
assertNumCircuitsWithHash(t, circuitMap, hash3, 0) assertNumCircuitsWithHash(t, circuitMap, hash3, 0)
cdb, circuitMap = restartCircuitMap(t, cdb) cfg, circuitMap = restartCircuitMap(t, cfg)
assertNumCircuitsWithHash(t, circuitMap, hash1, 1) assertNumCircuitsWithHash(t, circuitMap, hash1, 1)
assertHasCircuit(t, circuitMap, circuit1) assertHasCircuit(t, circuitMap, circuit1)
@ -238,7 +351,7 @@ func TestCircuitMapPersistence(t *testing.T) {
} }
assertHasCircuit(t, circuitMap, circuit3) assertHasCircuit(t, circuitMap, circuit3)
cdb, circuitMap = restartCircuitMap(t, cdb) cfg, circuitMap = restartCircuitMap(t, cfg)
assertHasCircuit(t, circuitMap, circuit3) assertHasCircuit(t, circuitMap, circuit3)
// Add another circuit with an already-used HTLC ID but different // Add another circuit with an already-used HTLC ID but different
@ -260,7 +373,7 @@ func TestCircuitMapPersistence(t *testing.T) {
assertHasKeystone(t, circuitMap, keystone1.OutKey, circuit1) assertHasKeystone(t, circuitMap, keystone1.OutKey, circuit1)
assertHasKeystone(t, circuitMap, keystone2.OutKey, circuit2) assertHasKeystone(t, circuitMap, keystone2.OutKey, circuit2)
assertHasKeystone(t, circuitMap, keystone3.OutKey, circuit3) assertHasKeystone(t, circuitMap, keystone3.OutKey, circuit3)
cdb, circuitMap = restartCircuitMap(t, cdb) cfg, circuitMap = restartCircuitMap(t, cfg)
assertHasKeystone(t, circuitMap, keystone1.OutKey, circuit1) assertHasKeystone(t, circuitMap, keystone1.OutKey, circuit1)
assertHasKeystone(t, circuitMap, keystone2.OutKey, circuit2) assertHasKeystone(t, circuitMap, keystone2.OutKey, circuit2)
assertHasKeystone(t, circuitMap, keystone3.OutKey, circuit3) assertHasKeystone(t, circuitMap, keystone3.OutKey, circuit3)
@ -294,7 +407,7 @@ func TestCircuitMapPersistence(t *testing.T) {
assertNumCircuitsWithHash(t, circuitMap, hash1, 1) assertNumCircuitsWithHash(t, circuitMap, hash1, 1)
assertHasCircuit(t, circuitMap, circuit4) assertHasCircuit(t, circuitMap, circuit4)
cdb, circuitMap = restartCircuitMap(t, cdb) cfg, circuitMap = restartCircuitMap(t, cfg)
assertNumCircuitsWithHash(t, circuitMap, hash1, 1) assertNumCircuitsWithHash(t, circuitMap, hash1, 1)
assertHasCircuit(t, circuitMap, circuit4) assertHasCircuit(t, circuitMap, circuit4)
@ -335,7 +448,7 @@ func TestCircuitMapPersistence(t *testing.T) {
assertHasCircuitForHash(t, circuitMap, hash3, circuit3) assertHasCircuitForHash(t, circuitMap, hash3, circuit3)
// Restart, then run checks again. // Restart, then run checks again.
cdb, circuitMap = restartCircuitMap(t, cdb) cfg, circuitMap = restartCircuitMap(t, cfg)
// Verify that all circuits have been fully added. // Verify that all circuits have been fully added.
assertHasCircuit(t, circuitMap, circuit1) assertHasCircuit(t, circuitMap, circuit1)
@ -368,7 +481,7 @@ func TestCircuitMapPersistence(t *testing.T) {
// should be circuit4. // should be circuit4.
assertNumCircuitsWithHash(t, circuitMap, hash1, 1) assertNumCircuitsWithHash(t, circuitMap, hash1, 1)
assertHasCircuitForHash(t, circuitMap, hash1, circuit4) assertHasCircuitForHash(t, circuitMap, hash1, circuit4)
cdb, circuitMap = restartCircuitMap(t, cdb) cfg, circuitMap = restartCircuitMap(t, cfg)
assertNumCircuitsWithHash(t, circuitMap, hash1, 1) assertNumCircuitsWithHash(t, circuitMap, hash1, 1)
assertHasCircuitForHash(t, circuitMap, hash1, circuit4) assertHasCircuitForHash(t, circuitMap, hash1, circuit4)
@ -391,7 +504,7 @@ func TestCircuitMapPersistence(t *testing.T) {
assertNumCircuitsWithHash(t, circuitMap, hash1, 0) assertNumCircuitsWithHash(t, circuitMap, hash1, 0)
assertNumCircuitsWithHash(t, circuitMap, hash2, 1) assertNumCircuitsWithHash(t, circuitMap, hash2, 1)
assertNumCircuitsWithHash(t, circuitMap, hash3, 1) assertNumCircuitsWithHash(t, circuitMap, hash3, 1)
cdb, circuitMap = restartCircuitMap(t, cdb) cfg, circuitMap = restartCircuitMap(t, cfg)
assertNumCircuitsWithHash(t, circuitMap, hash1, 0) assertNumCircuitsWithHash(t, circuitMap, hash1, 0)
assertNumCircuitsWithHash(t, circuitMap, hash2, 1) assertNumCircuitsWithHash(t, circuitMap, hash2, 1)
assertNumCircuitsWithHash(t, circuitMap, hash3, 1) assertNumCircuitsWithHash(t, circuitMap, hash3, 1)
@ -405,7 +518,7 @@ func TestCircuitMapPersistence(t *testing.T) {
// There should now only be one remaining circuit, with hash3. // There should now only be one remaining circuit, with hash3.
assertNumCircuitsWithHash(t, circuitMap, hash2, 0) assertNumCircuitsWithHash(t, circuitMap, hash2, 0)
assertNumCircuitsWithHash(t, circuitMap, hash3, 1) assertNumCircuitsWithHash(t, circuitMap, hash3, 1)
cdb, circuitMap = restartCircuitMap(t, cdb) cfg, circuitMap = restartCircuitMap(t, cfg)
assertNumCircuitsWithHash(t, circuitMap, hash2, 0) assertNumCircuitsWithHash(t, circuitMap, hash2, 0)
assertNumCircuitsWithHash(t, circuitMap, hash3, 1) assertNumCircuitsWithHash(t, circuitMap, hash3, 1)
@ -417,7 +530,7 @@ func TestCircuitMapPersistence(t *testing.T) {
// Check that the circuit map is empty, even after restarting. // Check that the circuit map is empty, even after restarting.
assertNumCircuitsWithHash(t, circuitMap, hash3, 0) assertNumCircuitsWithHash(t, circuitMap, hash3, 0)
cdb, circuitMap = restartCircuitMap(t, cdb) cfg, circuitMap = restartCircuitMap(t, cfg)
assertNumCircuitsWithHash(t, circuitMap, hash3, 0) assertNumCircuitsWithHash(t, circuitMap, hash3, 0)
} }
@ -534,21 +647,24 @@ func makeCircuitDB(t *testing.T, path string) *channeldb.DB {
// Creates a new circuit map, backed by a freshly opened channeldb. The existing // Creates a new circuit map, backed by a freshly opened channeldb. The existing
// channeldb is closed in order to simulate a complete restart. // channeldb is closed in order to simulate a complete restart.
func restartCircuitMap(t *testing.T, cdb *channeldb.DB) (*channeldb.DB, func restartCircuitMap(t *testing.T, cfg *htlcswitch.CircuitMapConfig) (
htlcswitch.CircuitMap) { *htlcswitch.CircuitMapConfig, htlcswitch.CircuitMap) {
// Record the current temp path and close current db. // Record the current temp path and close current db.
dbPath := cdb.Path() dbPath := cfg.DB.Path()
cdb.Close() cfg.DB.Close()
// Reinitialize circuit map with same db path. // Reinitialize circuit map with same db path.
cdb2 := makeCircuitDB(t, dbPath) cfg2 := &htlcswitch.CircuitMapConfig{
cm2, err := htlcswitch.NewCircuitMap(cdb2) DB: makeCircuitDB(t, dbPath),
ExtractErrorEncrypter: cfg.ExtractErrorEncrypter,
}
cm2, err := htlcswitch.NewCircuitMap(cfg2)
if err != nil { if err != nil {
t.Fatalf("unable to recreate persistent circuit map: %v", err) t.Fatalf("unable to recreate persistent circuit map: %v", err)
} }
return cdb2, cm2 return cfg2, cm2
} }
// TestCircuitMapCommitCircuits tests the following behavior of CommitCircuits: // TestCircuitMapCommitCircuits tests the following behavior of CommitCircuits:
@ -564,18 +680,14 @@ func TestCircuitMapCommitCircuits(t *testing.T) {
err error err error
) )
cdb := makeCircuitDB(t, "") cfg, circuitMap := newCircuitMap(t)
circuitMap, err = htlcswitch.NewCircuitMap(cdb)
if err != nil {
t.Fatalf("unable to create persistent circuit map: %v", err)
}
circuit := &htlcswitch.PaymentCircuit{ circuit := &htlcswitch.PaymentCircuit{
Incoming: htlcswitch.CircuitKey{ Incoming: htlcswitch.CircuitKey{
ChanID: chan1, ChanID: chan1,
HtlcID: 3, HtlcID: 3,
}, },
ErrorEncrypter: htlcswitch.NewSphinxErrorEncrypter(), ErrorEncrypter: testExtracter,
} }
// First we will try to add an new circuit to the circuit map, this // First we will try to add an new circuit to the circuit map, this
@ -623,7 +735,7 @@ func TestCircuitMapCommitCircuits(t *testing.T) {
// to be loaded from disk. Since the keystone was never set, subsequent // to be loaded from disk. Since the keystone was never set, subsequent
// attempts to commit the circuit should cause the circuit map to // attempts to commit the circuit should cause the circuit map to
// indicate that that the HTLC should be failed back. // indicate that that the HTLC should be failed back.
cdb, circuitMap = restartCircuitMap(t, cdb) cfg, circuitMap = restartCircuitMap(t, cfg)
actions, err = circuitMap.CommitCircuits(circuit) actions, err = circuitMap.CommitCircuits(circuit)
if err != nil { if err != nil {
@ -664,18 +776,14 @@ func TestCircuitMapOpenCircuits(t *testing.T) {
err error err error
) )
cdb := makeCircuitDB(t, "") cfg, circuitMap := newCircuitMap(t)
circuitMap, err = htlcswitch.NewCircuitMap(cdb)
if err != nil {
t.Fatalf("unable to create persistent circuit map: %v", err)
}
circuit := &htlcswitch.PaymentCircuit{ circuit := &htlcswitch.PaymentCircuit{
Incoming: htlcswitch.CircuitKey{ Incoming: htlcswitch.CircuitKey{
ChanID: chan1, ChanID: chan1,
HtlcID: 3, HtlcID: 3,
}, },
ErrorEncrypter: htlcswitch.NewSphinxErrorEncrypter(), ErrorEncrypter: testExtracter,
} }
// First we will try to add an new circuit to the circuit map, this // First we will try to add an new circuit to the circuit map, this
@ -747,7 +855,7 @@ func TestCircuitMapOpenCircuits(t *testing.T) {
// //
// NOTE: The channel db doesn't have any channel data, so no keystones // NOTE: The channel db doesn't have any channel data, so no keystones
// will be trimmed. // will be trimmed.
cdb, circuitMap = restartCircuitMap(t, cdb) cfg, circuitMap = restartCircuitMap(t, cfg)
// Check that we can still query for the open circuit. // Check that we can still query for the open circuit.
circuit2 = circuitMap.LookupOpenCircuit(keystone.OutKey) circuit2 = circuitMap.LookupOpenCircuit(keystone.OutKey)
@ -874,11 +982,7 @@ func TestCircuitMapTrimOpenCircuits(t *testing.T) {
err error err error
) )
cdb := makeCircuitDB(t, "") cfg, circuitMap := newCircuitMap(t)
circuitMap, err = htlcswitch.NewCircuitMap(cdb)
if err != nil {
t.Fatalf("unable to create persistent circuit map: %v", err)
}
const nCircuits = 10 const nCircuits = 10
const firstTrimIndex = 7 const firstTrimIndex = 7
@ -895,7 +999,7 @@ func TestCircuitMapTrimOpenCircuits(t *testing.T) {
ChanID: chan1, ChanID: chan1,
HtlcID: uint64(i + 3), HtlcID: uint64(i + 3),
}, },
ErrorEncrypter: htlcswitch.NewSphinxErrorEncrypter(), ErrorEncrypter: htlcswitch.NewMockObfuscator(),
} }
} }
@ -953,7 +1057,7 @@ func TestCircuitMapTrimOpenCircuits(t *testing.T) {
// Restart the circuit map, verify that that the trim is reflected on // Restart the circuit map, verify that that the trim is reflected on
// startup. // startup.
cdb, circuitMap = restartCircuitMap(t, cdb) cfg, circuitMap = restartCircuitMap(t, cfg)
assertCircuitsOpenedPostRestart( assertCircuitsOpenedPostRestart(
t, t,
@ -995,7 +1099,7 @@ func TestCircuitMapTrimOpenCircuits(t *testing.T) {
// Restart the circuit map one last time to make sure the changes are // Restart the circuit map one last time to make sure the changes are
// persisted. // persisted.
cdb, circuitMap = restartCircuitMap(t, cdb) cfg, circuitMap = restartCircuitMap(t, cfg)
assertCircuitsOpenedPostRestart( assertCircuitsOpenedPostRestart(
t, t,
@ -1027,18 +1131,16 @@ func TestCircuitMapCloseOpenCircuits(t *testing.T) {
err error err error
) )
cdb := makeCircuitDB(t, "") cfg, circuitMap := newCircuitMap(t)
circuitMap, err = htlcswitch.NewCircuitMap(cdb)
if err != nil {
t.Fatalf("unable to create persistent circuit map: %v", err)
}
circuit := &htlcswitch.PaymentCircuit{ circuit := &htlcswitch.PaymentCircuit{
Incoming: htlcswitch.CircuitKey{ Incoming: htlcswitch.CircuitKey{
ChanID: chan1, ChanID: chan1,
HtlcID: 3, HtlcID: 3,
}, },
ErrorEncrypter: htlcswitch.NewSphinxErrorEncrypter(), ErrorEncrypter: &htlcswitch.SphinxErrorEncrypter{
EphemeralKey: testEphemeralKey,
},
} }
// First we will try to add an new circuit to the circuit map, this // First we will try to add an new circuit to the circuit map, this
@ -1095,7 +1197,7 @@ func TestCircuitMapCloseOpenCircuits(t *testing.T) {
// //
// NOTE: The channel db doesn't have any channel data, so no keystones // NOTE: The channel db doesn't have any channel data, so no keystones
// will be trimmed. // will be trimmed.
cdb, circuitMap = restartCircuitMap(t, cdb) cfg, circuitMap = restartCircuitMap(t, cfg)
// Close the open circuit for the first time, which should succeed. // Close the open circuit for the first time, which should succeed.
_, err = circuitMap.FailCircuit(circuit.Incoming) _, err = circuitMap.FailCircuit(circuit.Incoming)
@ -1122,18 +1224,14 @@ func TestCircuitMapCloseUnopenedCircuit(t *testing.T) {
err error err error
) )
cdb := makeCircuitDB(t, "") cfg, circuitMap := newCircuitMap(t)
circuitMap, err = htlcswitch.NewCircuitMap(cdb)
if err != nil {
t.Fatalf("unable to create persistent circuit map: %v", err)
}
circuit := &htlcswitch.PaymentCircuit{ circuit := &htlcswitch.PaymentCircuit{
Incoming: htlcswitch.CircuitKey{ Incoming: htlcswitch.CircuitKey{
ChanID: chan1, ChanID: chan1,
HtlcID: 3, HtlcID: 3,
}, },
ErrorEncrypter: htlcswitch.NewSphinxErrorEncrypter(), ErrorEncrypter: testExtracter,
} }
// First we will try to add an new circuit to the circuit map, this // First we will try to add an new circuit to the circuit map, this
@ -1157,7 +1255,7 @@ func TestCircuitMapCloseUnopenedCircuit(t *testing.T) {
// Now, restart the circuit map, which will result in the circuit being // Now, restart the circuit map, which will result in the circuit being
// reopened, since no attempt to delete the circuit was made. // reopened, since no attempt to delete the circuit was made.
cdb, circuitMap = restartCircuitMap(t, cdb) cfg, circuitMap = restartCircuitMap(t, cfg)
// Close the open circuit for the first time, which should succeed. // Close the open circuit for the first time, which should succeed.
_, err = circuitMap.FailCircuit(circuit.Incoming) _, err = circuitMap.FailCircuit(circuit.Incoming)
@ -1183,18 +1281,14 @@ func TestCircuitMapDeleteUnopenedCircuit(t *testing.T) {
err error err error
) )
cdb := makeCircuitDB(t, "") cfg, circuitMap := newCircuitMap(t)
circuitMap, err = htlcswitch.NewCircuitMap(cdb)
if err != nil {
t.Fatalf("unable to create persistent circuit map: %v", err)
}
circuit := &htlcswitch.PaymentCircuit{ circuit := &htlcswitch.PaymentCircuit{
Incoming: htlcswitch.CircuitKey{ Incoming: htlcswitch.CircuitKey{
ChanID: chan1, ChanID: chan1,
HtlcID: 3, HtlcID: 3,
}, },
ErrorEncrypter: htlcswitch.NewSphinxErrorEncrypter(), ErrorEncrypter: testExtracter,
} }
// First we will try to add an new circuit to the circuit map, this // First we will try to add an new circuit to the circuit map, this
@ -1225,7 +1319,7 @@ func TestCircuitMapDeleteUnopenedCircuit(t *testing.T) {
// Now, restart the circuit map, and check that the deletion survived // Now, restart the circuit map, and check that the deletion survived
// the restart. // the restart.
cdb, circuitMap = restartCircuitMap(t, cdb) cfg, circuitMap = restartCircuitMap(t, cfg)
circuit2 = circuitMap.LookupCircuit(circuit.Incoming) circuit2 = circuitMap.LookupCircuit(circuit.Incoming)
if circuit2 != nil { if circuit2 != nil {
@ -1246,18 +1340,14 @@ func TestCircuitMapDeleteOpenCircuit(t *testing.T) {
err error err error
) )
cdb := makeCircuitDB(t, "") cfg, circuitMap := newCircuitMap(t)
circuitMap, err = htlcswitch.NewCircuitMap(cdb)
if err != nil {
t.Fatalf("unable to create persistent circuit map: %v", err)
}
circuit := &htlcswitch.PaymentCircuit{ circuit := &htlcswitch.PaymentCircuit{
Incoming: htlcswitch.CircuitKey{ Incoming: htlcswitch.CircuitKey{
ChanID: chan1, ChanID: chan1,
HtlcID: 3, HtlcID: 3,
}, },
ErrorEncrypter: htlcswitch.NewSphinxErrorEncrypter(), ErrorEncrypter: testExtracter,
} }
// First we will try to add an new circuit to the circuit map, this // First we will try to add an new circuit to the circuit map, this
@ -1302,7 +1392,7 @@ func TestCircuitMapDeleteOpenCircuit(t *testing.T) {
// Now, restart the circuit map, and check that the deletion survived // Now, restart the circuit map, and check that the deletion survived
// the restart. // the restart.
cdb, circuitMap = restartCircuitMap(t, cdb) cfg, circuitMap = restartCircuitMap(t, cfg)
circuit2 = circuitMap.LookupOpenCircuit(keystone.OutKey) circuit2 = circuitMap.LookupOpenCircuit(keystone.OutKey)
if circuit2 != nil { if circuit2 != nil {

@ -74,7 +74,7 @@ func (e UnknownEncrypterType) Error() string {
// ErrorEncrypterExtracter defines a function signature that extracts an // ErrorEncrypterExtracter defines a function signature that extracts an
// ErrorEncrypter from an sphinx OnionPacket. // ErrorEncrypter from an sphinx OnionPacket.
type ErrorEncrypterExtracter func(*sphinx.OnionPacket) (ErrorEncrypter, type ErrorEncrypterExtracter func(*btcec.PublicKey) (ErrorEncrypter,
lnwire.FailCode) lnwire.FailCode)
// ErrorEncrypter is an interface that is used to encrypt HTLC related errors // ErrorEncrypter is an interface that is used to encrypt HTLC related errors
@ -96,11 +96,20 @@ type ErrorEncrypter interface {
// backing this interface. // backing this interface.
Type() EncrypterType Type() EncrypterType
// Encode serializes the encrypter to the given io.Writer. // Encode serializes the encrypter's ephemeral public key to the given
// io.Writer.
Encode(io.Writer) error Encode(io.Writer) error
// Decode deserializes the encrypter from the given io.Reader. // Decode deserializes the encrypter' ephemeral public key from the
// given io.Reader.
Decode(io.Reader) error Decode(io.Reader) error
// Reextract rederives the encrypter using the extracter, performing an
// ECDH with the sphinx router's key and the ephemeral public key.
//
// NOTE: This should be called shortly after Decode to properly
// reinitialize the error encrypter.
Reextract(ErrorEncrypterExtracter) error
} }
// SphinxErrorEncrypter is a concrete implementation of both the ErrorEncrypter // SphinxErrorEncrypter is a concrete implementation of both the ErrorEncrypter
@ -110,14 +119,20 @@ type ErrorEncrypter interface {
type SphinxErrorEncrypter struct { type SphinxErrorEncrypter struct {
*sphinx.OnionErrorEncrypter *sphinx.OnionErrorEncrypter
ogPacket *sphinx.OnionPacket EphemeralKey *btcec.PublicKey
} }
// NewSphinxErrorEncrypter initializes a new sphinx error encrypter as well as // NewSphinxErrorEncrypter initializes a blank sphinx error encrypter, that
// the embedded onion error encrypter. // should be used to deserialize an encoded SphinxErrorEncrypter. Since the
// actual encrypter is not stored in plaintext while at rest, reconstructing the
// error encrypter requires:
// 1) Decode: to deserialize the ephemeral public key.
// 2) Reextract: to "unlock" the actual error encrypter using an active
// OnionProcessor.
func NewSphinxErrorEncrypter() *SphinxErrorEncrypter { func NewSphinxErrorEncrypter() *SphinxErrorEncrypter {
return &SphinxErrorEncrypter{ return &SphinxErrorEncrypter{
OnionErrorEncrypter: &sphinx.OnionErrorEncrypter{}, OnionErrorEncrypter: nil,
EphemeralKey: &btcec.PublicKey{},
} }
} }
@ -154,17 +169,56 @@ func (s *SphinxErrorEncrypter) Type() EncrypterType {
return EncrypterTypeSphinx return EncrypterTypeSphinx
} }
// Encode serializes the error encrypter to the provided io.Writer. // Encode serializes the error encrypter' ephemeral public key to the provided
// io.Writer.
func (s *SphinxErrorEncrypter) Encode(w io.Writer) error { func (s *SphinxErrorEncrypter) Encode(w io.Writer) error {
return s.OnionErrorEncrypter.Encode(w) ephemeral := s.EphemeralKey.SerializeCompressed()
_, err := w.Write(ephemeral)
return err
} }
// Decode reconstructs the error encrypter from the provided io.Reader. // Decode reconstructs the error encrypter's ephemeral public key from the
// provided io.Reader.
func (s *SphinxErrorEncrypter) Decode(r io.Reader) error { func (s *SphinxErrorEncrypter) Decode(r io.Reader) error {
if s.OnionErrorEncrypter == nil { var ephemeral [33]byte
s.OnionErrorEncrypter = &sphinx.OnionErrorEncrypter{} if _, err := io.ReadFull(r, ephemeral[:]); err != nil {
return err
} }
return s.OnionErrorEncrypter.Decode(r)
var err error
s.EphemeralKey, err = btcec.ParsePubKey(ephemeral[:], btcec.S256())
if err != nil {
return err
}
return nil
}
// Reextract rederives the error encrypter from the currently held EphemeralKey.
// This intended to be used shortly after Decode, to fully initialize a
// SphinxErrorEncrypter.
func (s *SphinxErrorEncrypter) Reextract(
extract ErrorEncrypterExtracter) error {
obfuscator, failcode := extract(s.EphemeralKey)
if failcode != lnwire.CodeNone {
// This should never happen, since we already validated that
// this obfuscator can be extracted when it was received in the
// link.
return fmt.Errorf("unable to reconstruct onion "+
"obfuscator, got failcode: %d", failcode)
}
sphinxEncrypter, ok := obfuscator.(*SphinxErrorEncrypter)
if !ok {
return fmt.Errorf("incorrect onion error extracter")
}
// Copy the freshly extracted encrypter.
s.OnionErrorEncrypter = sphinxEncrypter.OnionErrorEncrypter
return nil
} }
// A compile time check to ensure SphinxErrorEncrypter implements the // A compile time check to ensure SphinxErrorEncrypter implements the

@ -6,6 +6,7 @@ import (
"github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/roasbeef/btcd/btcec"
) )
// NetworkHop indicates the blockchain network that is intended to be the next // NetworkHop indicates the blockchain network that is intended to be the next
@ -167,7 +168,7 @@ func (r *sphinxHopIterator) ForwardingInstructions() ForwardingInfo {
func (r *sphinxHopIterator) ExtractErrorEncrypter( func (r *sphinxHopIterator) ExtractErrorEncrypter(
extracter ErrorEncrypterExtracter) (ErrorEncrypter, lnwire.FailCode) { extracter ErrorEncrypterExtracter) (ErrorEncrypter, lnwire.FailCode) {
return extracter(r.ogPacket) return extracter(r.ogPacket.EphemeralKey)
} }
// OnionProcessor is responsible for keeping all sphinx dependent parts inside // OnionProcessor is responsible for keeping all sphinx dependent parts inside
@ -401,11 +402,12 @@ func (p *OnionProcessor) DecodeHopIterators(id []byte,
// ErrorEncrypter instance using the derived shared secret. In the case that en // ErrorEncrypter instance using the derived shared secret. In the case that en
// error occurs, a lnwire failure code detailing the parsing failure will be // error occurs, a lnwire failure code detailing the parsing failure will be
// returned. // returned.
func (p *OnionProcessor) ExtractErrorEncrypter(onionPkt *sphinx.OnionPacket) ( func (p *OnionProcessor) ExtractErrorEncrypter(ephemeralKey *btcec.PublicKey) (
ErrorEncrypter, lnwire.FailCode) { ErrorEncrypter, lnwire.FailCode) {
onionObfuscator, err := sphinx.NewOnionErrorEncrypter(p.router, onionObfuscator, err := sphinx.NewOnionErrorEncrypter(
onionPkt.EphemeralKey) p.router, ephemeralKey,
)
if err != nil { if err != nil {
switch err { switch err {
case sphinx.ErrInvalidOnionVersion: case sphinx.ErrInvalidOnionVersion:
@ -422,6 +424,6 @@ func (p *OnionProcessor) ExtractErrorEncrypter(onionPkt *sphinx.OnionPacket) (
return &SphinxErrorEncrypter{ return &SphinxErrorEncrypter{
OnionErrorEncrypter: onionObfuscator, OnionErrorEncrypter: onionObfuscator,
ogPacket: onionPkt, EphemeralKey: ephemeralKey,
}, lnwire.CodeNone }, lnwire.CodeNone
} }

@ -147,9 +147,9 @@ type ChannelLinkConfig struct {
DecodeHopIterators func([]byte, []DecodeHopIteratorRequest) ( DecodeHopIterators func([]byte, []DecodeHopIteratorRequest) (
[]DecodeHopIteratorResponse, error) []DecodeHopIteratorResponse, error)
// DecodeOnionObfuscator function is responsible for decoding HTLC // ExtractErrorEncrypter function is responsible for decoding HTLC
// Sphinx onion blob, and creating onion failure obfuscator. // Sphinx onion blob, and creating onion failure obfuscator.
DecodeOnionObfuscator ErrorEncrypterExtracter ExtractErrorEncrypter ErrorEncrypterExtracter
// GetLastChannelUpdate retrieves the latest routing policy for this // GetLastChannelUpdate retrieves the latest routing policy for this
// particular channel. This will be used to provide payment senders our // particular channel. This will be used to provide payment senders our
@ -1843,7 +1843,7 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,
// Retrieve onion obfuscator from onion blob in order to // Retrieve onion obfuscator from onion blob in order to
// produce initial obfuscation of the onion failureCode. // produce initial obfuscation of the onion failureCode.
obfuscator, failureCode := chanIterator.ExtractErrorEncrypter( obfuscator, failureCode := chanIterator.ExtractErrorEncrypter(
l.cfg.DecodeOnionObfuscator, l.cfg.ExtractErrorEncrypter,
) )
if failureCode != lnwire.CodeNone { if failureCode != lnwire.CodeNone {
// If we're unable to process the onion blob than we // If we're unable to process the onion blob than we

@ -16,12 +16,12 @@ import (
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
"github.com/go-errors/errors" "github.com/go-errors/errors"
"github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/contractcourt"
"github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/roasbeef/btcd/btcec"
"github.com/roasbeef/btcd/chaincfg/chainhash" "github.com/roasbeef/btcd/chaincfg/chainhash"
"github.com/roasbeef/btcd/wire" "github.com/roasbeef/btcd/wire"
"github.com/roasbeef/btcutil" "github.com/roasbeef/btcutil"
@ -1124,8 +1124,8 @@ func TestChannelLinkMultiHopDecodeError(t *testing.T) {
defer n.stop() defer n.stop()
// Replace decode function with another which throws an error. // Replace decode function with another which throws an error.
n.carolChannelLink.cfg.DecodeOnionObfuscator = func( n.carolChannelLink.cfg.ExtractErrorEncrypter = func(
*sphinx.OnionPacket) (ErrorEncrypter, lnwire.FailCode) { *btcec.PublicKey) (ErrorEncrypter, lnwire.FailCode) {
return nil, lnwire.CodeInvalidOnionVersion return nil, lnwire.CodeInvalidOnionVersion
} }
@ -1472,7 +1472,7 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) (
Circuits: aliceSwitch.CircuitModifier(), Circuits: aliceSwitch.CircuitModifier(),
ForwardPackets: aliceSwitch.ForwardPackets, ForwardPackets: aliceSwitch.ForwardPackets,
DecodeHopIterators: decoder.DecodeHopIterators, DecodeHopIterators: decoder.DecodeHopIterators,
DecodeOnionObfuscator: func(*sphinx.OnionPacket) ( ExtractErrorEncrypter: func(*btcec.PublicKey) (
ErrorEncrypter, lnwire.FailCode) { ErrorEncrypter, lnwire.FailCode) {
return obfuscator, lnwire.CodeNone return obfuscator, lnwire.CodeNone
}, },

@ -302,6 +302,10 @@ func (o *mockObfuscator) Decode(r io.Reader) error {
return nil return nil
} }
func (o *mockObfuscator) Reextract(extracter ErrorEncrypterExtracter) error {
return nil
}
func (o *mockObfuscator) EncryptFirstHop(failure lnwire.FailureMessage) ( func (o *mockObfuscator) EncryptFirstHop(failure lnwire.FailureMessage) (
lnwire.OpaqueReason, error) { lnwire.OpaqueReason, error) {

@ -130,6 +130,11 @@ type Config struct {
// active channels. This gives the switch the ability to read arbitrary // active channels. This gives the switch the ability to read arbitrary
// forwarding packages, and ack settles and fails contained within them. // forwarding packages, and ack settles and fails contained within them.
SwitchPackager channeldb.FwdOperator SwitchPackager channeldb.FwdOperator
// ExtractErrorEncrypter is an interface allowing switch to reextract
// error encrypters stored in the circuit map on restarts, since they
// are not stored directly within the database.
ExtractErrorEncrypter ErrorEncrypterExtracter
} }
// Switch is the central messaging bus for all incoming/outgoing HTLCs. // Switch is the central messaging bus for all incoming/outgoing HTLCs.
@ -214,7 +219,10 @@ type Switch struct {
// New creates the new instance of htlc switch. // New creates the new instance of htlc switch.
func New(cfg Config) (*Switch, error) { func New(cfg Config) (*Switch, error) {
circuitMap, err := NewCircuitMap(cfg.DB) circuitMap, err := NewCircuitMap(&CircuitMapConfig{
DB: cfg.DB,
ExtractErrorEncrypter: cfg.ExtractErrorEncrypter,
})
if err != nil { if err != nil {
return nil, err return nil, err
} }

@ -17,7 +17,6 @@ import (
"github.com/btcsuite/fastsha256" "github.com/btcsuite/fastsha256"
"github.com/coreos/bbolt" "github.com/coreos/bbolt"
"github.com/go-errors/errors" "github.com/go-errors/errors"
"github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/contractcourt"
@ -908,7 +907,7 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel,
Circuits: aliceServer.htlcSwitch.CircuitModifier(), Circuits: aliceServer.htlcSwitch.CircuitModifier(),
ForwardPackets: aliceServer.htlcSwitch.ForwardPackets, ForwardPackets: aliceServer.htlcSwitch.ForwardPackets,
DecodeHopIterators: aliceDecoder.DecodeHopIterators, DecodeHopIterators: aliceDecoder.DecodeHopIterators,
DecodeOnionObfuscator: func(*sphinx.OnionPacket) ( ExtractErrorEncrypter: func(*btcec.PublicKey) (
ErrorEncrypter, lnwire.FailCode) { ErrorEncrypter, lnwire.FailCode) {
return obfuscator, lnwire.CodeNone return obfuscator, lnwire.CodeNone
}, },
@ -956,7 +955,7 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel,
Circuits: bobServer.htlcSwitch.CircuitModifier(), Circuits: bobServer.htlcSwitch.CircuitModifier(),
ForwardPackets: bobServer.htlcSwitch.ForwardPackets, ForwardPackets: bobServer.htlcSwitch.ForwardPackets,
DecodeHopIterators: bobDecoder.DecodeHopIterators, DecodeHopIterators: bobDecoder.DecodeHopIterators,
DecodeOnionObfuscator: func(*sphinx.OnionPacket) ( ExtractErrorEncrypter: func(*btcec.PublicKey) (
ErrorEncrypter, lnwire.FailCode) { ErrorEncrypter, lnwire.FailCode) {
return obfuscator, lnwire.CodeNone return obfuscator, lnwire.CodeNone
}, },
@ -1004,7 +1003,7 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel,
Circuits: bobServer.htlcSwitch.CircuitModifier(), Circuits: bobServer.htlcSwitch.CircuitModifier(),
ForwardPackets: bobServer.htlcSwitch.ForwardPackets, ForwardPackets: bobServer.htlcSwitch.ForwardPackets,
DecodeHopIterators: bobDecoder.DecodeHopIterators, DecodeHopIterators: bobDecoder.DecodeHopIterators,
DecodeOnionObfuscator: func(*sphinx.OnionPacket) ( ExtractErrorEncrypter: func(*btcec.PublicKey) (
ErrorEncrypter, lnwire.FailCode) { ErrorEncrypter, lnwire.FailCode) {
return obfuscator, lnwire.CodeNone return obfuscator, lnwire.CodeNone
}, },
@ -1052,7 +1051,7 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel,
Circuits: carolServer.htlcSwitch.CircuitModifier(), Circuits: carolServer.htlcSwitch.CircuitModifier(),
ForwardPackets: carolServer.htlcSwitch.ForwardPackets, ForwardPackets: carolServer.htlcSwitch.ForwardPackets,
DecodeHopIterators: carolDecoder.DecodeHopIterators, DecodeHopIterators: carolDecoder.DecodeHopIterators,
DecodeOnionObfuscator: func(*sphinx.OnionPacket) ( ExtractErrorEncrypter: func(*btcec.PublicKey) (
ErrorEncrypter, lnwire.FailCode) { ErrorEncrypter, lnwire.FailCode) {
return obfuscator, lnwire.CodeNone return obfuscator, lnwire.CodeNone
}, },

@ -399,7 +399,7 @@ func (p *peer) loadActiveChannels(chans []*channeldb.OpenChannel) error {
linkCfg := htlcswitch.ChannelLinkConfig{ linkCfg := htlcswitch.ChannelLinkConfig{
Peer: p, Peer: p,
DecodeHopIterators: p.server.sphinx.DecodeHopIterators, DecodeHopIterators: p.server.sphinx.DecodeHopIterators,
DecodeOnionObfuscator: p.server.sphinx.ExtractErrorEncrypter, ExtractErrorEncrypter: p.server.sphinx.ExtractErrorEncrypter,
GetLastChannelUpdate: createGetLastUpdate(p.server.chanRouter, GetLastChannelUpdate: createGetLastUpdate(p.server.chanRouter,
p.PubKey(), lnChan.ShortChanID()), p.PubKey(), lnChan.ShortChanID()),
DebugHTLC: cfg.DebugHTLC, DebugHTLC: cfg.DebugHTLC,
@ -1366,7 +1366,7 @@ out:
linkConfig := htlcswitch.ChannelLinkConfig{ linkConfig := htlcswitch.ChannelLinkConfig{
Peer: p, Peer: p,
DecodeHopIterators: p.server.sphinx.DecodeHopIterators, DecodeHopIterators: p.server.sphinx.DecodeHopIterators,
DecodeOnionObfuscator: p.server.sphinx.ExtractErrorEncrypter, ExtractErrorEncrypter: p.server.sphinx.ExtractErrorEncrypter,
GetLastChannelUpdate: createGetLastUpdate(p.server.chanRouter, GetLastChannelUpdate: createGetLastUpdate(p.server.chanRouter,
p.PubKey(), newChanReq.channel.ShortChanID()), p.PubKey(), newChanReq.channel.ShortChanID()),
DebugHTLC: cfg.DebugHTLC, DebugHTLC: cfg.DebugHTLC,

@ -232,8 +232,9 @@ func newServer(listenAddrs []string, chanDB *channeldb.DB, cc *chainControl,
pubKey[:], err) pubKey[:], err)
} }
}, },
FwdingLog: chanDB.ForwardingLog(), FwdingLog: chanDB.ForwardingLog(),
SwitchPackager: channeldb.NewSwitchPackager(), SwitchPackager: channeldb.NewSwitchPackager(),
ExtractErrorEncrypter: s.sphinx.ExtractErrorEncrypter,
}) })
if err != nil { if err != nil {
return nil, err return nil, err