Merge pull request #826 from cfromknecht/reextract-circuit-encrypters

Reextract Circuit Error Encrypters
This commit is contained in:
Olaoluwa Osuntokun 2018-03-13 18:10:12 -07:00 committed by GitHub
commit 0fb7804e4a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 320 additions and 137 deletions

@ -144,11 +144,7 @@ var (
// always identifiable by their incoming CircuitKey, in addition to their
// outgoing CircuitKey if the circuit is fully-opened.
type circuitMap struct {
// db provides the persistent storage engine for the circuit map.
//
// TODO(conner): create abstraction to allow for the substitution of
// other persistence engines.
db *channeldb.DB
cfg *CircuitMapConfig
mtx sync.RWMutex
@ -172,10 +168,23 @@ type circuitMap struct {
hashIndex map[[32]byte]map[CircuitKey]struct{}
}
// CircuitMapConfig houses the critical interfaces and references necessary to
// parameterize an instance of circuitMap.
type CircuitMapConfig struct {
// DB provides the persistent storage engine for the circuit map.
// TODO(conner): create abstraction to allow for the substitution of
// other persistence engines.
DB *channeldb.DB
// ExtractErrorEncrypter derives the shared secret used to encrypt
// errors from the obfuscator's ephemeral public key.
ExtractErrorEncrypter ErrorEncrypterExtracter
}
// NewCircuitMap creates a new instance of the circuitMap.
func NewCircuitMap(db *channeldb.DB) (CircuitMap, error) {
func NewCircuitMap(cfg *CircuitMapConfig) (CircuitMap, error) {
cm := &circuitMap{
db: db,
cfg: cfg,
}
// Initialize the on-disk buckets used by the circuit map.
@ -203,7 +212,7 @@ func NewCircuitMap(db *channeldb.DB) (CircuitMap, error) {
// initBuckets ensures that the primary buckets used by the circuit are
// initialized so that we can assume their existence after startup.
func (cm *circuitMap) initBuckets() error {
return cm.db.Update(func(tx *bolt.Tx) error {
return cm.cfg.DB.Update(func(tx *bolt.Tx) error {
_, err := tx.CreateBucketIfNotExists(circuitKeystoneKey)
if err != nil {
return err
@ -226,7 +235,7 @@ func (cm *circuitMap) restoreMemState() error {
pending = make(map[CircuitKey]*PaymentCircuit)
)
if err := cm.db.View(func(tx *bolt.Tx) error {
if err := cm.cfg.DB.View(func(tx *bolt.Tx) error {
// Restore any of the circuits persisted in the circuit bucket
// back into memory.
circuitBkt := tx.Bucket(circuitAddKey)
@ -235,7 +244,7 @@ func (cm *circuitMap) restoreMemState() error {
}
if err := circuitBkt.ForEach(func(_, v []byte) error {
circuit, err := decodeCircuit(v)
circuit, err := cm.decodeCircuit(v)
if err != nil {
return err
}
@ -305,8 +314,9 @@ func (cm *circuitMap) restoreMemState() error {
// decodeCircuit reconstructs an in-memory payment circuit from a byte slice.
// The byte slice is assumed to have been generated by the circuit's Encode
// method.
func decodeCircuit(v []byte) (*PaymentCircuit, error) {
// method. If the decoding is successful, the onion obfuscator will be
// reextracted, since it is not stored in plaintext on disk.
func (cm *circuitMap) decodeCircuit(v []byte) (*PaymentCircuit, error) {
var circuit = &PaymentCircuit{}
circuitReader := bytes.NewReader(v)
@ -314,6 +324,21 @@ func decodeCircuit(v []byte) (*PaymentCircuit, error) {
return nil, err
}
// If the error encrypter is nil, this is locally-source payment so
// there is no encrypter.
if circuit.ErrorEncrypter == nil {
return circuit, nil
}
// Otherwise, we need to reextract the encrypter, so that the shared
// secret is rederived from what was decoded.
err := circuit.ErrorEncrypter.Reextract(
cm.cfg.ExtractErrorEncrypter,
)
if err != nil {
return nil, err
}
return circuit, nil
}
@ -325,7 +350,7 @@ func decodeCircuit(v []byte) (*PaymentCircuit, error) {
// channels. Therefore, it must be called before any links are created to avoid
// interfering with normal operation.
func (cm *circuitMap) trimAllOpenCircuits() error {
activeChannels, err := cm.db.FetchAllChannels()
activeChannels, err := cm.cfg.DB.FetchAllChannels()
if err != nil {
return err
}
@ -385,7 +410,7 @@ func (cm *circuitMap) TrimOpenCircuits(chanID lnwire.ShortChannelID,
return nil
}
return cm.db.Update(func(tx *bolt.Tx) error {
return cm.cfg.DB.Update(func(tx *bolt.Tx) error {
keystoneBkt := tx.Bucket(circuitKeystoneKey)
if keystoneBkt == nil {
return ErrCorruptedCircuitMap
@ -533,7 +558,7 @@ func (cm *circuitMap) CommitCircuits(circuits ...*PaymentCircuit) (
// Write the entire batch of circuits to the persistent circuit bucket
// using bolt's Batch write. This method must be called from multiple,
// distinct goroutines to have any impact on performance.
err := cm.db.Batch(func(tx *bolt.Tx) error {
err := cm.cfg.DB.Batch(func(tx *bolt.Tx) error {
circuitBkt := tx.Bucket(circuitAddKey)
if circuitBkt == nil {
return ErrCorruptedCircuitMap
@ -623,7 +648,7 @@ func (cm *circuitMap) OpenCircuits(keystones ...Keystone) error {
}
cm.mtx.RUnlock()
err := cm.db.Update(func(tx *bolt.Tx) error {
err := cm.cfg.DB.Update(func(tx *bolt.Tx) error {
// Now, load the circuit bucket to which we will write the
// already serialized circuit.
keystoneBkt := tx.Bucket(circuitKeystoneKey)
@ -769,7 +794,7 @@ func (cm *circuitMap) DeleteCircuits(inKeys ...CircuitKey) error {
}
cm.mtx.Unlock()
err := cm.db.Batch(func(tx *bolt.Tx) error {
err := cm.cfg.DB.Batch(func(tx *bolt.Tx) error {
for _, circuit := range removedCircuits {
// If this htlc made it to an outgoing link, load the
// keystone bucket from which we will remove the

@ -6,9 +6,12 @@ import (
"reflect"
"testing"
"github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/htlcswitch"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/roasbeef/btcd/btcec"
bitcoinCfg "github.com/roasbeef/btcd/chaincfg"
"github.com/roasbeef/btcutil"
)
@ -16,19 +19,114 @@ var (
hash1 = [32]byte{0x01}
hash2 = [32]byte{0x02}
hash3 = [32]byte{0x03}
// sphinxPrivKey is the private key given to freshly created sphinx
// routers.
sphinxPrivKey *btcec.PrivateKey
// testEphemeralKey is the ephemeral key that will be extracted to
// create onion obfuscators.
testEphemeralKey *btcec.PublicKey
// testExtracter is a precomputed extraction of testEphemeralKey, using
// the sphinxPrivKey.
testExtracter *htlcswitch.SphinxErrorEncrypter
)
func TestCircuitMapInit(t *testing.T) {
t.Parallel()
func init() {
// Generate a fresh key for our sphinx router.
var err error
sphinxPrivKey, err = btcec.NewPrivateKey(btcec.S256())
if err != nil {
panic(err)
}
// Initialize new database for circuit map.
cdb := makeCircuitDB(t, "")
_, err := htlcswitch.NewCircuitMap(cdb)
// And another, whose public key will serve as the test ephemeral key.
testEphemeralPriv, err := btcec.NewPrivateKey(btcec.S256())
if err != nil {
panic(err)
}
testEphemeralKey = testEphemeralPriv.PubKey()
// Finally, properly initialize the test extracter
initTestExtracter()
}
// initTestExtracter spins up a new onion processor specifically for the purpose
// of generating our testExtracter, which should be derived from the
// testEphemeralKey, and which randomly-generated key is used to init the sphinx
// router.
//
// NOTE: This should be called in init(), after testEphemeralKey has been
// properly initialized.
func initTestExtracter() {
onionProcessor := newOnionProcessor(nil)
defer onionProcessor.Stop()
obfuscator, _ := onionProcessor.ExtractErrorEncrypter(
testEphemeralKey,
)
sphinxExtracter, ok := obfuscator.(*htlcswitch.SphinxErrorEncrypter)
if !ok {
panic("did not extract sphinx error encrypter")
}
testExtracter = sphinxExtracter
// We also set this error extracter on startup, otherwise it will be nil
// at compile-time.
halfCircuitTests[2].encrypter = testExtracter
}
// newOnionProcessor creates starts a new htlcswitch.OnionProcessor using a temp
// db and no garbage collection.
func newOnionProcessor(t *testing.T) *htlcswitch.OnionProcessor {
sharedSecretFile, err := ioutil.TempFile("", "sphinxreplay.db")
if err != nil {
t.Fatalf("unable to create temp path: %v", err)
}
sharedSecretPath := sharedSecretFile.Name()
sphinxRouter := sphinx.NewRouter(
sharedSecretPath, sphinxPrivKey, &bitcoinCfg.SimNetParams, nil,
)
if err := sphinxRouter.Start(); err != nil {
t.Fatalf("unable to start sphinx router: %v", err)
}
return htlcswitch.NewOnionProcessor(sphinxRouter)
}
// newCircuitMap creates a new htlcswitch.CircuitMap using a temp db and a
// fresh sphinx router.
func newCircuitMap(t *testing.T) (*htlcswitch.CircuitMapConfig,
htlcswitch.CircuitMap) {
onionProcessor := newOnionProcessor(t)
circuitMapCfg := &htlcswitch.CircuitMapConfig{
DB: makeCircuitDB(t, ""),
ExtractErrorEncrypter: onionProcessor.ExtractErrorEncrypter,
}
circuitMap, err := htlcswitch.NewCircuitMap(circuitMapCfg)
if err != nil {
t.Fatalf("unable to create persistent circuit map: %v", err)
}
restartCircuitMap(t, cdb)
return circuitMapCfg, circuitMap
}
// TestCircuitMapInit is a quick check to ensure that we can start and restore
// the circuit map, as this will be used extensively in this suite.
func TestCircuitMapInit(t *testing.T) {
t.Parallel()
cfg, _ := newCircuitMap(t)
restartCircuitMap(t, cfg)
}
var halfCircuitTests = []struct {
@ -56,12 +154,15 @@ var halfCircuitTests = []struct {
encrypter: htlcswitch.NewMockObfuscator(),
},
{
hash: hash3,
inValue: 10000,
outValue: 9000,
chanID: lnwire.NewShortChanIDFromInt(3),
htlcID: 3,
encrypter: htlcswitch.NewSphinxErrorEncrypter(),
hash: hash3,
inValue: 10000,
outValue: 9000,
chanID: lnwire.NewShortChanIDFromInt(3),
htlcID: 3,
// NOTE: The value of testExtracter is nil at compile-time, it
// is fully-initialized in initTestExtracter, which should
// repopulate this encrypter.
encrypter: testExtracter,
},
}
@ -72,6 +173,8 @@ var halfCircuitTests = []struct {
func TestHalfCircuitSerialization(t *testing.T) {
t.Parallel()
onionProcessor := newOnionProcessor(t)
for i, test := range halfCircuitTests {
circuit := &htlcswitch.PaymentCircuit{
PaymentHash: test.hash,
@ -97,6 +200,20 @@ func TestHalfCircuitSerialization(t *testing.T) {
t.Fatalf("unable to decode half payment circuit test=%d: %v", i, err)
}
// If the error encrypter is initialized, we will need to
// reextract it from it's decoded state, as this requires an
// ECDH with the onion processor's private key. For mock error
// encrypters, this will be a NOP.
if circuit2.ErrorEncrypter != nil {
err := circuit2.ErrorEncrypter.Reextract(
onionProcessor.ExtractErrorEncrypter,
)
if err != nil {
t.Fatalf("unable to reextract sphinx error "+
"encrypter: %v", err)
}
}
// Reconstructed half circuit should match the original.
if !equalIgnoreLFD(circuit, &circuit2) {
t.Fatalf("unexpected half circuit test=%d, want %v, got %v",
@ -115,11 +232,7 @@ func TestCircuitMapPersistence(t *testing.T) {
err error
)
cdb := makeCircuitDB(t, "")
circuitMap, err = htlcswitch.NewCircuitMap(cdb)
if err != nil {
t.Fatalf("unable to create persistent circuit map: %v", err)
}
cfg, circuitMap := newCircuitMap(t)
circuit := circuitMap.LookupCircuit(htlcswitch.CircuitKey{chan1, 0})
if circuit != nil {
@ -143,7 +256,7 @@ func TestCircuitMapPersistence(t *testing.T) {
assertNumCircuitsWithHash(t, circuitMap, hash1, 0)
assertHasCircuit(t, circuitMap, circuit1)
cdb, circuitMap = restartCircuitMap(t, cdb)
cfg, circuitMap = restartCircuitMap(t, cfg)
assertNumCircuitsWithHash(t, circuitMap, hash1, 0)
assertHasCircuit(t, circuitMap, circuit1)
@ -168,7 +281,7 @@ func TestCircuitMapPersistence(t *testing.T) {
assertHasCircuit(t, circuitMap, circuit1)
assertHasKeystone(t, circuitMap, keystone1.OutKey, circuit1)
cdb, circuitMap = restartCircuitMap(t, cdb)
cfg, circuitMap = restartCircuitMap(t, cfg)
assertNumCircuitsWithHash(t, circuitMap, hash1, 1)
assertHasCircuit(t, circuitMap, circuit1)
@ -213,7 +326,7 @@ func TestCircuitMapPersistence(t *testing.T) {
assertNumCircuitsWithHash(t, circuitMap, hash3, 0)
cdb, circuitMap = restartCircuitMap(t, cdb)
cfg, circuitMap = restartCircuitMap(t, cfg)
assertNumCircuitsWithHash(t, circuitMap, hash1, 1)
assertHasCircuit(t, circuitMap, circuit1)
@ -238,7 +351,7 @@ func TestCircuitMapPersistence(t *testing.T) {
}
assertHasCircuit(t, circuitMap, circuit3)
cdb, circuitMap = restartCircuitMap(t, cdb)
cfg, circuitMap = restartCircuitMap(t, cfg)
assertHasCircuit(t, circuitMap, circuit3)
// Add another circuit with an already-used HTLC ID but different
@ -260,7 +373,7 @@ func TestCircuitMapPersistence(t *testing.T) {
assertHasKeystone(t, circuitMap, keystone1.OutKey, circuit1)
assertHasKeystone(t, circuitMap, keystone2.OutKey, circuit2)
assertHasKeystone(t, circuitMap, keystone3.OutKey, circuit3)
cdb, circuitMap = restartCircuitMap(t, cdb)
cfg, circuitMap = restartCircuitMap(t, cfg)
assertHasKeystone(t, circuitMap, keystone1.OutKey, circuit1)
assertHasKeystone(t, circuitMap, keystone2.OutKey, circuit2)
assertHasKeystone(t, circuitMap, keystone3.OutKey, circuit3)
@ -294,7 +407,7 @@ func TestCircuitMapPersistence(t *testing.T) {
assertNumCircuitsWithHash(t, circuitMap, hash1, 1)
assertHasCircuit(t, circuitMap, circuit4)
cdb, circuitMap = restartCircuitMap(t, cdb)
cfg, circuitMap = restartCircuitMap(t, cfg)
assertNumCircuitsWithHash(t, circuitMap, hash1, 1)
assertHasCircuit(t, circuitMap, circuit4)
@ -335,7 +448,7 @@ func TestCircuitMapPersistence(t *testing.T) {
assertHasCircuitForHash(t, circuitMap, hash3, circuit3)
// Restart, then run checks again.
cdb, circuitMap = restartCircuitMap(t, cdb)
cfg, circuitMap = restartCircuitMap(t, cfg)
// Verify that all circuits have been fully added.
assertHasCircuit(t, circuitMap, circuit1)
@ -368,7 +481,7 @@ func TestCircuitMapPersistence(t *testing.T) {
// should be circuit4.
assertNumCircuitsWithHash(t, circuitMap, hash1, 1)
assertHasCircuitForHash(t, circuitMap, hash1, circuit4)
cdb, circuitMap = restartCircuitMap(t, cdb)
cfg, circuitMap = restartCircuitMap(t, cfg)
assertNumCircuitsWithHash(t, circuitMap, hash1, 1)
assertHasCircuitForHash(t, circuitMap, hash1, circuit4)
@ -391,7 +504,7 @@ func TestCircuitMapPersistence(t *testing.T) {
assertNumCircuitsWithHash(t, circuitMap, hash1, 0)
assertNumCircuitsWithHash(t, circuitMap, hash2, 1)
assertNumCircuitsWithHash(t, circuitMap, hash3, 1)
cdb, circuitMap = restartCircuitMap(t, cdb)
cfg, circuitMap = restartCircuitMap(t, cfg)
assertNumCircuitsWithHash(t, circuitMap, hash1, 0)
assertNumCircuitsWithHash(t, circuitMap, hash2, 1)
assertNumCircuitsWithHash(t, circuitMap, hash3, 1)
@ -405,7 +518,7 @@ func TestCircuitMapPersistence(t *testing.T) {
// There should now only be one remaining circuit, with hash3.
assertNumCircuitsWithHash(t, circuitMap, hash2, 0)
assertNumCircuitsWithHash(t, circuitMap, hash3, 1)
cdb, circuitMap = restartCircuitMap(t, cdb)
cfg, circuitMap = restartCircuitMap(t, cfg)
assertNumCircuitsWithHash(t, circuitMap, hash2, 0)
assertNumCircuitsWithHash(t, circuitMap, hash3, 1)
@ -417,7 +530,7 @@ func TestCircuitMapPersistence(t *testing.T) {
// Check that the circuit map is empty, even after restarting.
assertNumCircuitsWithHash(t, circuitMap, hash3, 0)
cdb, circuitMap = restartCircuitMap(t, cdb)
cfg, circuitMap = restartCircuitMap(t, cfg)
assertNumCircuitsWithHash(t, circuitMap, hash3, 0)
}
@ -534,21 +647,24 @@ func makeCircuitDB(t *testing.T, path string) *channeldb.DB {
// Creates a new circuit map, backed by a freshly opened channeldb. The existing
// channeldb is closed in order to simulate a complete restart.
func restartCircuitMap(t *testing.T, cdb *channeldb.DB) (*channeldb.DB,
htlcswitch.CircuitMap) {
func restartCircuitMap(t *testing.T, cfg *htlcswitch.CircuitMapConfig) (
*htlcswitch.CircuitMapConfig, htlcswitch.CircuitMap) {
// Record the current temp path and close current db.
dbPath := cdb.Path()
cdb.Close()
dbPath := cfg.DB.Path()
cfg.DB.Close()
// Reinitialize circuit map with same db path.
cdb2 := makeCircuitDB(t, dbPath)
cm2, err := htlcswitch.NewCircuitMap(cdb2)
cfg2 := &htlcswitch.CircuitMapConfig{
DB: makeCircuitDB(t, dbPath),
ExtractErrorEncrypter: cfg.ExtractErrorEncrypter,
}
cm2, err := htlcswitch.NewCircuitMap(cfg2)
if err != nil {
t.Fatalf("unable to recreate persistent circuit map: %v", err)
}
return cdb2, cm2
return cfg2, cm2
}
// TestCircuitMapCommitCircuits tests the following behavior of CommitCircuits:
@ -564,18 +680,14 @@ func TestCircuitMapCommitCircuits(t *testing.T) {
err error
)
cdb := makeCircuitDB(t, "")
circuitMap, err = htlcswitch.NewCircuitMap(cdb)
if err != nil {
t.Fatalf("unable to create persistent circuit map: %v", err)
}
cfg, circuitMap := newCircuitMap(t)
circuit := &htlcswitch.PaymentCircuit{
Incoming: htlcswitch.CircuitKey{
ChanID: chan1,
HtlcID: 3,
},
ErrorEncrypter: htlcswitch.NewSphinxErrorEncrypter(),
ErrorEncrypter: testExtracter,
}
// First we will try to add an new circuit to the circuit map, this
@ -623,7 +735,7 @@ func TestCircuitMapCommitCircuits(t *testing.T) {
// to be loaded from disk. Since the keystone was never set, subsequent
// attempts to commit the circuit should cause the circuit map to
// indicate that that the HTLC should be failed back.
cdb, circuitMap = restartCircuitMap(t, cdb)
cfg, circuitMap = restartCircuitMap(t, cfg)
actions, err = circuitMap.CommitCircuits(circuit)
if err != nil {
@ -664,18 +776,14 @@ func TestCircuitMapOpenCircuits(t *testing.T) {
err error
)
cdb := makeCircuitDB(t, "")
circuitMap, err = htlcswitch.NewCircuitMap(cdb)
if err != nil {
t.Fatalf("unable to create persistent circuit map: %v", err)
}
cfg, circuitMap := newCircuitMap(t)
circuit := &htlcswitch.PaymentCircuit{
Incoming: htlcswitch.CircuitKey{
ChanID: chan1,
HtlcID: 3,
},
ErrorEncrypter: htlcswitch.NewSphinxErrorEncrypter(),
ErrorEncrypter: testExtracter,
}
// First we will try to add an new circuit to the circuit map, this
@ -747,7 +855,7 @@ func TestCircuitMapOpenCircuits(t *testing.T) {
//
// NOTE: The channel db doesn't have any channel data, so no keystones
// will be trimmed.
cdb, circuitMap = restartCircuitMap(t, cdb)
cfg, circuitMap = restartCircuitMap(t, cfg)
// Check that we can still query for the open circuit.
circuit2 = circuitMap.LookupOpenCircuit(keystone.OutKey)
@ -874,11 +982,7 @@ func TestCircuitMapTrimOpenCircuits(t *testing.T) {
err error
)
cdb := makeCircuitDB(t, "")
circuitMap, err = htlcswitch.NewCircuitMap(cdb)
if err != nil {
t.Fatalf("unable to create persistent circuit map: %v", err)
}
cfg, circuitMap := newCircuitMap(t)
const nCircuits = 10
const firstTrimIndex = 7
@ -895,7 +999,7 @@ func TestCircuitMapTrimOpenCircuits(t *testing.T) {
ChanID: chan1,
HtlcID: uint64(i + 3),
},
ErrorEncrypter: htlcswitch.NewSphinxErrorEncrypter(),
ErrorEncrypter: htlcswitch.NewMockObfuscator(),
}
}
@ -953,7 +1057,7 @@ func TestCircuitMapTrimOpenCircuits(t *testing.T) {
// Restart the circuit map, verify that that the trim is reflected on
// startup.
cdb, circuitMap = restartCircuitMap(t, cdb)
cfg, circuitMap = restartCircuitMap(t, cfg)
assertCircuitsOpenedPostRestart(
t,
@ -995,7 +1099,7 @@ func TestCircuitMapTrimOpenCircuits(t *testing.T) {
// Restart the circuit map one last time to make sure the changes are
// persisted.
cdb, circuitMap = restartCircuitMap(t, cdb)
cfg, circuitMap = restartCircuitMap(t, cfg)
assertCircuitsOpenedPostRestart(
t,
@ -1027,18 +1131,16 @@ func TestCircuitMapCloseOpenCircuits(t *testing.T) {
err error
)
cdb := makeCircuitDB(t, "")
circuitMap, err = htlcswitch.NewCircuitMap(cdb)
if err != nil {
t.Fatalf("unable to create persistent circuit map: %v", err)
}
cfg, circuitMap := newCircuitMap(t)
circuit := &htlcswitch.PaymentCircuit{
Incoming: htlcswitch.CircuitKey{
ChanID: chan1,
HtlcID: 3,
},
ErrorEncrypter: htlcswitch.NewSphinxErrorEncrypter(),
ErrorEncrypter: &htlcswitch.SphinxErrorEncrypter{
EphemeralKey: testEphemeralKey,
},
}
// First we will try to add an new circuit to the circuit map, this
@ -1095,7 +1197,7 @@ func TestCircuitMapCloseOpenCircuits(t *testing.T) {
//
// NOTE: The channel db doesn't have any channel data, so no keystones
// will be trimmed.
cdb, circuitMap = restartCircuitMap(t, cdb)
cfg, circuitMap = restartCircuitMap(t, cfg)
// Close the open circuit for the first time, which should succeed.
_, err = circuitMap.FailCircuit(circuit.Incoming)
@ -1122,18 +1224,14 @@ func TestCircuitMapCloseUnopenedCircuit(t *testing.T) {
err error
)
cdb := makeCircuitDB(t, "")
circuitMap, err = htlcswitch.NewCircuitMap(cdb)
if err != nil {
t.Fatalf("unable to create persistent circuit map: %v", err)
}
cfg, circuitMap := newCircuitMap(t)
circuit := &htlcswitch.PaymentCircuit{
Incoming: htlcswitch.CircuitKey{
ChanID: chan1,
HtlcID: 3,
},
ErrorEncrypter: htlcswitch.NewSphinxErrorEncrypter(),
ErrorEncrypter: testExtracter,
}
// First we will try to add an new circuit to the circuit map, this
@ -1157,7 +1255,7 @@ func TestCircuitMapCloseUnopenedCircuit(t *testing.T) {
// Now, restart the circuit map, which will result in the circuit being
// reopened, since no attempt to delete the circuit was made.
cdb, circuitMap = restartCircuitMap(t, cdb)
cfg, circuitMap = restartCircuitMap(t, cfg)
// Close the open circuit for the first time, which should succeed.
_, err = circuitMap.FailCircuit(circuit.Incoming)
@ -1183,18 +1281,14 @@ func TestCircuitMapDeleteUnopenedCircuit(t *testing.T) {
err error
)
cdb := makeCircuitDB(t, "")
circuitMap, err = htlcswitch.NewCircuitMap(cdb)
if err != nil {
t.Fatalf("unable to create persistent circuit map: %v", err)
}
cfg, circuitMap := newCircuitMap(t)
circuit := &htlcswitch.PaymentCircuit{
Incoming: htlcswitch.CircuitKey{
ChanID: chan1,
HtlcID: 3,
},
ErrorEncrypter: htlcswitch.NewSphinxErrorEncrypter(),
ErrorEncrypter: testExtracter,
}
// First we will try to add an new circuit to the circuit map, this
@ -1225,7 +1319,7 @@ func TestCircuitMapDeleteUnopenedCircuit(t *testing.T) {
// Now, restart the circuit map, and check that the deletion survived
// the restart.
cdb, circuitMap = restartCircuitMap(t, cdb)
cfg, circuitMap = restartCircuitMap(t, cfg)
circuit2 = circuitMap.LookupCircuit(circuit.Incoming)
if circuit2 != nil {
@ -1246,18 +1340,14 @@ func TestCircuitMapDeleteOpenCircuit(t *testing.T) {
err error
)
cdb := makeCircuitDB(t, "")
circuitMap, err = htlcswitch.NewCircuitMap(cdb)
if err != nil {
t.Fatalf("unable to create persistent circuit map: %v", err)
}
cfg, circuitMap := newCircuitMap(t)
circuit := &htlcswitch.PaymentCircuit{
Incoming: htlcswitch.CircuitKey{
ChanID: chan1,
HtlcID: 3,
},
ErrorEncrypter: htlcswitch.NewSphinxErrorEncrypter(),
ErrorEncrypter: testExtracter,
}
// First we will try to add an new circuit to the circuit map, this
@ -1302,7 +1392,7 @@ func TestCircuitMapDeleteOpenCircuit(t *testing.T) {
// Now, restart the circuit map, and check that the deletion survived
// the restart.
cdb, circuitMap = restartCircuitMap(t, cdb)
cfg, circuitMap = restartCircuitMap(t, cfg)
circuit2 = circuitMap.LookupOpenCircuit(keystone.OutKey)
if circuit2 != nil {

@ -74,7 +74,7 @@ func (e UnknownEncrypterType) Error() string {
// ErrorEncrypterExtracter defines a function signature that extracts an
// ErrorEncrypter from an sphinx OnionPacket.
type ErrorEncrypterExtracter func(*sphinx.OnionPacket) (ErrorEncrypter,
type ErrorEncrypterExtracter func(*btcec.PublicKey) (ErrorEncrypter,
lnwire.FailCode)
// ErrorEncrypter is an interface that is used to encrypt HTLC related errors
@ -96,11 +96,20 @@ type ErrorEncrypter interface {
// backing this interface.
Type() EncrypterType
// Encode serializes the encrypter to the given io.Writer.
// Encode serializes the encrypter's ephemeral public key to the given
// io.Writer.
Encode(io.Writer) error
// Decode deserializes the encrypter from the given io.Reader.
// Decode deserializes the encrypter' ephemeral public key from the
// given io.Reader.
Decode(io.Reader) error
// Reextract rederives the encrypter using the extracter, performing an
// ECDH with the sphinx router's key and the ephemeral public key.
//
// NOTE: This should be called shortly after Decode to properly
// reinitialize the error encrypter.
Reextract(ErrorEncrypterExtracter) error
}
// SphinxErrorEncrypter is a concrete implementation of both the ErrorEncrypter
@ -110,14 +119,20 @@ type ErrorEncrypter interface {
type SphinxErrorEncrypter struct {
*sphinx.OnionErrorEncrypter
ogPacket *sphinx.OnionPacket
EphemeralKey *btcec.PublicKey
}
// NewSphinxErrorEncrypter initializes a new sphinx error encrypter as well as
// the embedded onion error encrypter.
// NewSphinxErrorEncrypter initializes a blank sphinx error encrypter, that
// should be used to deserialize an encoded SphinxErrorEncrypter. Since the
// actual encrypter is not stored in plaintext while at rest, reconstructing the
// error encrypter requires:
// 1) Decode: to deserialize the ephemeral public key.
// 2) Reextract: to "unlock" the actual error encrypter using an active
// OnionProcessor.
func NewSphinxErrorEncrypter() *SphinxErrorEncrypter {
return &SphinxErrorEncrypter{
OnionErrorEncrypter: &sphinx.OnionErrorEncrypter{},
OnionErrorEncrypter: nil,
EphemeralKey: &btcec.PublicKey{},
}
}
@ -154,17 +169,56 @@ func (s *SphinxErrorEncrypter) Type() EncrypterType {
return EncrypterTypeSphinx
}
// Encode serializes the error encrypter to the provided io.Writer.
// Encode serializes the error encrypter' ephemeral public key to the provided
// io.Writer.
func (s *SphinxErrorEncrypter) Encode(w io.Writer) error {
return s.OnionErrorEncrypter.Encode(w)
ephemeral := s.EphemeralKey.SerializeCompressed()
_, err := w.Write(ephemeral)
return err
}
// Decode reconstructs the error encrypter from the provided io.Reader.
// Decode reconstructs the error encrypter's ephemeral public key from the
// provided io.Reader.
func (s *SphinxErrorEncrypter) Decode(r io.Reader) error {
if s.OnionErrorEncrypter == nil {
s.OnionErrorEncrypter = &sphinx.OnionErrorEncrypter{}
var ephemeral [33]byte
if _, err := io.ReadFull(r, ephemeral[:]); err != nil {
return err
}
return s.OnionErrorEncrypter.Decode(r)
var err error
s.EphemeralKey, err = btcec.ParsePubKey(ephemeral[:], btcec.S256())
if err != nil {
return err
}
return nil
}
// Reextract rederives the error encrypter from the currently held EphemeralKey.
// This intended to be used shortly after Decode, to fully initialize a
// SphinxErrorEncrypter.
func (s *SphinxErrorEncrypter) Reextract(
extract ErrorEncrypterExtracter) error {
obfuscator, failcode := extract(s.EphemeralKey)
if failcode != lnwire.CodeNone {
// This should never happen, since we already validated that
// this obfuscator can be extracted when it was received in the
// link.
return fmt.Errorf("unable to reconstruct onion "+
"obfuscator, got failcode: %d", failcode)
}
sphinxEncrypter, ok := obfuscator.(*SphinxErrorEncrypter)
if !ok {
return fmt.Errorf("incorrect onion error extracter")
}
// Copy the freshly extracted encrypter.
s.OnionErrorEncrypter = sphinxEncrypter.OnionErrorEncrypter
return nil
}
// A compile time check to ensure SphinxErrorEncrypter implements the

@ -6,6 +6,7 @@ import (
"github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/roasbeef/btcd/btcec"
)
// NetworkHop indicates the blockchain network that is intended to be the next
@ -167,7 +168,7 @@ func (r *sphinxHopIterator) ForwardingInstructions() ForwardingInfo {
func (r *sphinxHopIterator) ExtractErrorEncrypter(
extracter ErrorEncrypterExtracter) (ErrorEncrypter, lnwire.FailCode) {
return extracter(r.ogPacket)
return extracter(r.ogPacket.EphemeralKey)
}
// OnionProcessor is responsible for keeping all sphinx dependent parts inside
@ -401,11 +402,12 @@ func (p *OnionProcessor) DecodeHopIterators(id []byte,
// ErrorEncrypter instance using the derived shared secret. In the case that en
// error occurs, a lnwire failure code detailing the parsing failure will be
// returned.
func (p *OnionProcessor) ExtractErrorEncrypter(onionPkt *sphinx.OnionPacket) (
func (p *OnionProcessor) ExtractErrorEncrypter(ephemeralKey *btcec.PublicKey) (
ErrorEncrypter, lnwire.FailCode) {
onionObfuscator, err := sphinx.NewOnionErrorEncrypter(p.router,
onionPkt.EphemeralKey)
onionObfuscator, err := sphinx.NewOnionErrorEncrypter(
p.router, ephemeralKey,
)
if err != nil {
switch err {
case sphinx.ErrInvalidOnionVersion:
@ -422,6 +424,6 @@ func (p *OnionProcessor) ExtractErrorEncrypter(onionPkt *sphinx.OnionPacket) (
return &SphinxErrorEncrypter{
OnionErrorEncrypter: onionObfuscator,
ogPacket: onionPkt,
EphemeralKey: ephemeralKey,
}, lnwire.CodeNone
}

@ -147,9 +147,9 @@ type ChannelLinkConfig struct {
DecodeHopIterators func([]byte, []DecodeHopIteratorRequest) (
[]DecodeHopIteratorResponse, error)
// DecodeOnionObfuscator function is responsible for decoding HTLC
// ExtractErrorEncrypter function is responsible for decoding HTLC
// Sphinx onion blob, and creating onion failure obfuscator.
DecodeOnionObfuscator ErrorEncrypterExtracter
ExtractErrorEncrypter ErrorEncrypterExtracter
// GetLastChannelUpdate retrieves the latest routing policy for this
// particular channel. This will be used to provide payment senders our
@ -1843,7 +1843,7 @@ func (l *channelLink) processRemoteAdds(fwdPkg *channeldb.FwdPkg,
// Retrieve onion obfuscator from onion blob in order to
// produce initial obfuscation of the onion failureCode.
obfuscator, failureCode := chanIterator.ExtractErrorEncrypter(
l.cfg.DecodeOnionObfuscator,
l.cfg.ExtractErrorEncrypter,
)
if failureCode != lnwire.CodeNone {
// If we're unable to process the onion blob than we

@ -16,12 +16,12 @@ import (
"github.com/davecgh/go-spew/spew"
"github.com/go-errors/errors"
"github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/contractcourt"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/roasbeef/btcd/btcec"
"github.com/roasbeef/btcd/chaincfg/chainhash"
"github.com/roasbeef/btcd/wire"
"github.com/roasbeef/btcutil"
@ -1124,8 +1124,8 @@ func TestChannelLinkMultiHopDecodeError(t *testing.T) {
defer n.stop()
// Replace decode function with another which throws an error.
n.carolChannelLink.cfg.DecodeOnionObfuscator = func(
*sphinx.OnionPacket) (ErrorEncrypter, lnwire.FailCode) {
n.carolChannelLink.cfg.ExtractErrorEncrypter = func(
*btcec.PublicKey) (ErrorEncrypter, lnwire.FailCode) {
return nil, lnwire.CodeInvalidOnionVersion
}
@ -1472,7 +1472,7 @@ func newSingleLinkTestHarness(chanAmt, chanReserve btcutil.Amount) (
Circuits: aliceSwitch.CircuitModifier(),
ForwardPackets: aliceSwitch.ForwardPackets,
DecodeHopIterators: decoder.DecodeHopIterators,
DecodeOnionObfuscator: func(*sphinx.OnionPacket) (
ExtractErrorEncrypter: func(*btcec.PublicKey) (
ErrorEncrypter, lnwire.FailCode) {
return obfuscator, lnwire.CodeNone
},

@ -302,6 +302,10 @@ func (o *mockObfuscator) Decode(r io.Reader) error {
return nil
}
func (o *mockObfuscator) Reextract(extracter ErrorEncrypterExtracter) error {
return nil
}
func (o *mockObfuscator) EncryptFirstHop(failure lnwire.FailureMessage) (
lnwire.OpaqueReason, error) {

@ -130,6 +130,11 @@ type Config struct {
// active channels. This gives the switch the ability to read arbitrary
// forwarding packages, and ack settles and fails contained within them.
SwitchPackager channeldb.FwdOperator
// ExtractErrorEncrypter is an interface allowing switch to reextract
// error encrypters stored in the circuit map on restarts, since they
// are not stored directly within the database.
ExtractErrorEncrypter ErrorEncrypterExtracter
}
// Switch is the central messaging bus for all incoming/outgoing HTLCs.
@ -214,7 +219,10 @@ type Switch struct {
// New creates the new instance of htlc switch.
func New(cfg Config) (*Switch, error) {
circuitMap, err := NewCircuitMap(cfg.DB)
circuitMap, err := NewCircuitMap(&CircuitMapConfig{
DB: cfg.DB,
ExtractErrorEncrypter: cfg.ExtractErrorEncrypter,
})
if err != nil {
return nil, err
}

@ -17,7 +17,6 @@ import (
"github.com/btcsuite/fastsha256"
"github.com/coreos/bbolt"
"github.com/go-errors/errors"
"github.com/lightningnetwork/lightning-onion"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/contractcourt"
@ -908,7 +907,7 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel,
Circuits: aliceServer.htlcSwitch.CircuitModifier(),
ForwardPackets: aliceServer.htlcSwitch.ForwardPackets,
DecodeHopIterators: aliceDecoder.DecodeHopIterators,
DecodeOnionObfuscator: func(*sphinx.OnionPacket) (
ExtractErrorEncrypter: func(*btcec.PublicKey) (
ErrorEncrypter, lnwire.FailCode) {
return obfuscator, lnwire.CodeNone
},
@ -956,7 +955,7 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel,
Circuits: bobServer.htlcSwitch.CircuitModifier(),
ForwardPackets: bobServer.htlcSwitch.ForwardPackets,
DecodeHopIterators: bobDecoder.DecodeHopIterators,
DecodeOnionObfuscator: func(*sphinx.OnionPacket) (
ExtractErrorEncrypter: func(*btcec.PublicKey) (
ErrorEncrypter, lnwire.FailCode) {
return obfuscator, lnwire.CodeNone
},
@ -1004,7 +1003,7 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel,
Circuits: bobServer.htlcSwitch.CircuitModifier(),
ForwardPackets: bobServer.htlcSwitch.ForwardPackets,
DecodeHopIterators: bobDecoder.DecodeHopIterators,
DecodeOnionObfuscator: func(*sphinx.OnionPacket) (
ExtractErrorEncrypter: func(*btcec.PublicKey) (
ErrorEncrypter, lnwire.FailCode) {
return obfuscator, lnwire.CodeNone
},
@ -1052,7 +1051,7 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel,
Circuits: carolServer.htlcSwitch.CircuitModifier(),
ForwardPackets: carolServer.htlcSwitch.ForwardPackets,
DecodeHopIterators: carolDecoder.DecodeHopIterators,
DecodeOnionObfuscator: func(*sphinx.OnionPacket) (
ExtractErrorEncrypter: func(*btcec.PublicKey) (
ErrorEncrypter, lnwire.FailCode) {
return obfuscator, lnwire.CodeNone
},

@ -399,7 +399,7 @@ func (p *peer) loadActiveChannels(chans []*channeldb.OpenChannel) error {
linkCfg := htlcswitch.ChannelLinkConfig{
Peer: p,
DecodeHopIterators: p.server.sphinx.DecodeHopIterators,
DecodeOnionObfuscator: p.server.sphinx.ExtractErrorEncrypter,
ExtractErrorEncrypter: p.server.sphinx.ExtractErrorEncrypter,
GetLastChannelUpdate: createGetLastUpdate(p.server.chanRouter,
p.PubKey(), lnChan.ShortChanID()),
DebugHTLC: cfg.DebugHTLC,
@ -1366,7 +1366,7 @@ out:
linkConfig := htlcswitch.ChannelLinkConfig{
Peer: p,
DecodeHopIterators: p.server.sphinx.DecodeHopIterators,
DecodeOnionObfuscator: p.server.sphinx.ExtractErrorEncrypter,
ExtractErrorEncrypter: p.server.sphinx.ExtractErrorEncrypter,
GetLastChannelUpdate: createGetLastUpdate(p.server.chanRouter,
p.PubKey(), newChanReq.channel.ShortChanID()),
DebugHTLC: cfg.DebugHTLC,

@ -232,8 +232,9 @@ func newServer(listenAddrs []string, chanDB *channeldb.DB, cc *chainControl,
pubKey[:], err)
}
},
FwdingLog: chanDB.ForwardingLog(),
SwitchPackager: channeldb.NewSwitchPackager(),
FwdingLog: chanDB.ForwardingLog(),
SwitchPackager: channeldb.NewSwitchPackager(),
ExtractErrorEncrypter: s.sphinx.ExtractErrorEncrypter,
})
if err != nil {
return nil, err