diff --git a/htlcswitch/circuit_test.go b/htlcswitch/circuit_test.go index 27bc7e58..2c777524 100644 --- a/htlcswitch/circuit_test.go +++ b/htlcswitch/circuit_test.go @@ -1,155 +1,1312 @@ package htlcswitch_test import ( + "bytes" + "io/ioutil" + "reflect" "testing" + "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lnwire" + "github.com/roasbeef/btcutil" ) -func TestCircuitMap(t *testing.T) { +var ( + hash1 = [32]byte{0x01} + hash2 = [32]byte{0x02} + hash3 = [32]byte{0x03} +) + +func TestCircuitMapInit(t *testing.T) { t.Parallel() - var hash1, hash2, hash3 [32]byte - hash1[0] = 1 - hash2[0] = 2 - hash3[0] = 3 + // Initialize new database for circuit map. + cdb := makeCircuitDB(t, "") + _, err := htlcswitch.NewCircuitMap(cdb) + if err != nil { + t.Fatalf("unable to create persistent circuit map: %v", err) + } + + restartCircuitMap(t, cdb) +} + +var halfCircuitTests = []struct { + hash [32]byte + inValue btcutil.Amount + outValue btcutil.Amount + chanID lnwire.ShortChannelID + htlcID uint64 + encrypter htlcswitch.ErrorEncrypter +}{ + { + hash: hash1, + inValue: 0, + outValue: 1000, + chanID: lnwire.NewShortChanIDFromInt(1), + htlcID: 1, + encrypter: nil, + }, + { + hash: hash2, + inValue: 2100, + outValue: 2000, + chanID: lnwire.NewShortChanIDFromInt(2), + htlcID: 2, + encrypter: htlcswitch.NewMockObfuscator(), + }, + { + hash: hash3, + inValue: 10000, + outValue: 9000, + chanID: lnwire.NewShortChanIDFromInt(3), + htlcID: 3, + encrypter: htlcswitch.NewSphinxErrorEncrypter(), + }, +} + +// TestHalfCircuitSerialization checks that the half circuits can be properly +// encoded and decoded properly. A critical responsibility of this test is to +// verify that the various ErrorEncrypter implementations can be properly +// reconstructed from a serialized half circuit. +func TestHalfCircuitSerialization(t *testing.T) { + t.Parallel() + + for i, test := range halfCircuitTests { + circuit := &htlcswitch.PaymentCircuit{ + PaymentHash: test.hash, + IncomingAmount: lnwire.NewMSatFromSatoshis(test.inValue), + OutgoingAmount: lnwire.NewMSatFromSatoshis(test.outValue), + Incoming: htlcswitch.CircuitKey{ + ChanID: test.chanID, + HtlcID: test.htlcID, + }, + ErrorEncrypter: test.encrypter, + } + + // Write the half circuit to our buffer. + var b bytes.Buffer + if err := circuit.Encode(&b); err != nil { + t.Fatalf("unable to encode half payment circuit test=%d: %v", i, err) + } + + // Then try to decode the serialized bytes. + var circuit2 htlcswitch.PaymentCircuit + circuitReader := bytes.NewReader(b.Bytes()) + if err := circuit2.Decode(circuitReader); err != nil { + t.Fatalf("unable to decode half payment circuit test=%d: %v", i, err) + } + + // Reconstructed half circuit should match the original. + if !equalIgnoreLFD(circuit, &circuit2) { + t.Fatalf("unexpected half circuit test=%d, want %v, got %v", + i, circuit, circuit2) + } + } +} + +func TestCircuitMapPersistence(t *testing.T) { + t.Parallel() var ( - chan1 = lnwire.NewShortChanIDFromInt(1) - chan2 = lnwire.NewShortChanIDFromInt(2) + chan1 = lnwire.NewShortChanIDFromInt(1) + chan2 = lnwire.NewShortChanIDFromInt(2) + circuitMap htlcswitch.CircuitMap + err error ) - circuitMap := htlcswitch.NewCircuitMap() + cdb := makeCircuitDB(t, "") + circuitMap, err = htlcswitch.NewCircuitMap(cdb) + if err != nil { + t.Fatalf("unable to create persistent circuit map: %v", err) + } - circuit := circuitMap.LookupByHTLC(chan1, 0) + circuit := circuitMap.LookupCircuit(htlcswitch.CircuitKey{chan1, 0}) if circuit != nil { t.Fatalf("LookupByHTLC returned a circuit before any were added: %v", circuit) } + circuit1 := &htlcswitch.PaymentCircuit{ + Incoming: htlcswitch.CircuitKey{ + ChanID: chan2, + HtlcID: 1, + }, + PaymentHash: hash1, + ErrorEncrypter: htlcswitch.NewMockObfuscator(), + } + if _, err := circuitMap.CommitCircuits(circuit1); err != nil { + t.Fatalf("unable to add half circuit: %v", err) + } + + // Circuit map should have one circuit that has not been fully opened. + assertNumCircuitsWithHash(t, circuitMap, hash1, 0) + assertHasCircuit(t, circuitMap, circuit1) + + cdb, circuitMap = restartCircuitMap(t, cdb) + + assertNumCircuitsWithHash(t, circuitMap, hash1, 0) + assertHasCircuit(t, circuitMap, circuit1) + // Add multiple circuits with same destination channel but different HTLC // IDs and payment hashes. - circuitMap.Add(&htlcswitch.PaymentCircuit{ - PaymentHash: hash1, - IncomingChanID: chan2, - IncomingHTLCID: 1, - OutgoingChanID: chan1, - OutgoingHTLCID: 0, - }) + keystone1 := htlcswitch.Keystone{ + InKey: circuit1.Incoming, + OutKey: htlcswitch.CircuitKey{ + ChanID: chan1, + HtlcID: 0, + }, + } + circuit1.Outgoing = &keystone1.OutKey + if err := circuitMap.OpenCircuits(keystone1); err != nil { + t.Fatalf("unable to add full circuit: %v", err) + } - circuitMap.Add(&htlcswitch.PaymentCircuit{ + // Circuit map should reflect addition of circuit1, and the change + // should survive a restart. + assertNumCircuitsWithHash(t, circuitMap, hash1, 1) + assertHasCircuit(t, circuitMap, circuit1) + assertHasKeystone(t, circuitMap, keystone1.OutKey, circuit1) + + cdb, circuitMap = restartCircuitMap(t, cdb) + + assertNumCircuitsWithHash(t, circuitMap, hash1, 1) + assertHasCircuit(t, circuitMap, circuit1) + assertHasKeystone(t, circuitMap, keystone1.OutKey, circuit1) + + circuit2 := &htlcswitch.PaymentCircuit{ + Incoming: htlcswitch.CircuitKey{ + ChanID: chan2, + HtlcID: 2, + }, PaymentHash: hash2, - IncomingChanID: chan2, - IncomingHTLCID: 2, - OutgoingChanID: chan1, - OutgoingHTLCID: 1, - }) + ErrorEncrypter: htlcswitch.NewMockObfuscator(), + } + if _, err := circuitMap.CommitCircuits(circuit2); err != nil { + t.Fatalf("unable to add half circuit: %v", err) + } + + assertHasCircuit(t, circuitMap, circuit2) + + keystone2 := htlcswitch.Keystone{ + InKey: circuit2.Incoming, + OutKey: htlcswitch.CircuitKey{ + ChanID: chan1, + HtlcID: 1, + }, + } + circuit2.Outgoing = &keystone2.OutKey + if err := circuitMap.OpenCircuits(keystone2); err != nil { + t.Fatalf("unable to add full circuit: %v", err) + } + + // Should have two full circuits, one under hash1 and another under + // hash2. Both half payment circuits should have been removed when the + // full circuits were added. + assertNumCircuitsWithHash(t, circuitMap, hash1, 1) + assertHasCircuit(t, circuitMap, circuit1) + assertHasKeystone(t, circuitMap, keystone1.OutKey, circuit1) + + assertNumCircuitsWithHash(t, circuitMap, hash2, 1) + assertHasCircuit(t, circuitMap, circuit2) + assertHasKeystone(t, circuitMap, keystone2.OutKey, circuit2) + + assertNumCircuitsWithHash(t, circuitMap, hash3, 0) + + cdb, circuitMap = restartCircuitMap(t, cdb) + + assertNumCircuitsWithHash(t, circuitMap, hash1, 1) + assertHasCircuit(t, circuitMap, circuit1) + assertHasKeystone(t, circuitMap, keystone1.OutKey, circuit1) + + assertNumCircuitsWithHash(t, circuitMap, hash2, 1) + assertHasCircuit(t, circuitMap, circuit2) + assertHasKeystone(t, circuitMap, keystone2.OutKey, circuit2) + + assertNumCircuitsWithHash(t, circuitMap, hash3, 0) + + circuit3 := &htlcswitch.PaymentCircuit{ + Incoming: htlcswitch.CircuitKey{ + ChanID: chan1, + HtlcID: 2, + }, + PaymentHash: hash3, + ErrorEncrypter: htlcswitch.NewMockObfuscator(), + } + if _, err := circuitMap.CommitCircuits(circuit3); err != nil { + t.Fatalf("unable to add half circuit: %v", err) + } + + assertHasCircuit(t, circuitMap, circuit3) + cdb, circuitMap = restartCircuitMap(t, cdb) + assertHasCircuit(t, circuitMap, circuit3) // Add another circuit with an already-used HTLC ID but different // destination channel. - circuitMap.Add(&htlcswitch.PaymentCircuit{ - PaymentHash: hash3, - IncomingChanID: chan1, - IncomingHTLCID: 2, - OutgoingChanID: chan2, - OutgoingHTLCID: 0, - }) - - circuit = circuitMap.LookupByHTLC(chan1, 0) - if circuit == nil { - t.Fatal("LookupByHTLC failed to find circuit") + keystone3 := htlcswitch.Keystone{ + InKey: circuit3.Incoming, + OutKey: htlcswitch.CircuitKey{ + ChanID: chan2, + HtlcID: 0, + }, } - if circuit.PaymentHash != hash1 || circuit.IncomingHTLCID != 1 { - t.Fatalf("LookupByHTLC found unexpected circuit: %v", circuit) + circuit3.Outgoing = &keystone3.OutKey + if err := circuitMap.OpenCircuits(keystone3); err != nil { + t.Fatalf("unable to add full circuit: %v", err) } - circuit = circuitMap.LookupByHTLC(chan1, 1) - if circuit == nil { - t.Fatal("LookupByHTLC failed to find circuit") - } - if circuit.PaymentHash != hash2 || circuit.IncomingHTLCID != 2 { - t.Fatalf("LookupByHTLC found unexpected circuit: %v", circuit) - } + // Check that all have been marked as full circuits, and that no half + // circuits are currently being tracked. + assertHasKeystone(t, circuitMap, keystone1.OutKey, circuit1) + assertHasKeystone(t, circuitMap, keystone2.OutKey, circuit2) + assertHasKeystone(t, circuitMap, keystone3.OutKey, circuit3) + cdb, circuitMap = restartCircuitMap(t, cdb) + assertHasKeystone(t, circuitMap, keystone1.OutKey, circuit1) + assertHasKeystone(t, circuitMap, keystone2.OutKey, circuit2) + assertHasKeystone(t, circuitMap, keystone3.OutKey, circuit3) - circuit = circuitMap.LookupByHTLC(chan2, 0) - if circuit == nil { - t.Fatal("LookupByHTLC failed to find circuit") + // Even though a circuit was added with chan1, HTLC ID 2 as the source, + // the lookup should go by destination channel, HTLC ID. + invalidKeystone := htlcswitch.CircuitKey{ + ChanID: chan1, + HtlcID: 2, } - if circuit.PaymentHash != hash3 || circuit.IncomingHTLCID != 2 { - t.Fatalf("LookupByHTLC found unexpected circuit: %v", circuit) - } - - // Even though a circuit was added with chan1, HTLC ID 2 as the source, the - // lookup should go by destination channel, HTLC ID. - circuit = circuitMap.LookupByHTLC(chan1, 2) + circuit = circuitMap.LookupOpenCircuit(invalidKeystone) if circuit != nil { t.Fatalf("LookupByHTLC returned a circuit without being added: %v", circuit) } + circuit4 := &htlcswitch.PaymentCircuit{ + Incoming: htlcswitch.CircuitKey{ + ChanID: chan2, + HtlcID: 3, + }, + PaymentHash: hash1, + ErrorEncrypter: htlcswitch.NewMockObfuscator(), + } + if _, err := circuitMap.CommitCircuits(circuit4); err != nil { + t.Fatalf("unable to add half circuit: %v", err) + } + + // Circuit map should still only show one circuit with hash1, since we + // have not set the keystone for circuit4. + assertNumCircuitsWithHash(t, circuitMap, hash1, 1) + assertHasCircuit(t, circuitMap, circuit4) + + cdb, circuitMap = restartCircuitMap(t, cdb) + + assertNumCircuitsWithHash(t, circuitMap, hash1, 1) + assertHasCircuit(t, circuitMap, circuit4) + // Add a circuit with a destination channel and payment hash that are // already added but a different HTLC ID. - circuitMap.Add(&htlcswitch.PaymentCircuit{ - PaymentHash: hash1, - IncomingChanID: chan2, - IncomingHTLCID: 3, - OutgoingChanID: chan1, - OutgoingHTLCID: 3, - }) - - circuit = circuitMap.LookupByHTLC(chan1, 3) - if circuit == nil { - t.Fatal("LookupByHTLC failed to find circuit") + keystone4 := htlcswitch.Keystone{ + InKey: circuit4.Incoming, + OutKey: htlcswitch.CircuitKey{ + ChanID: chan1, + HtlcID: 3, + }, } - if circuit.PaymentHash != hash1 || circuit.IncomingHTLCID != 3 { - t.Fatalf("LookupByHTLC found unexpected circuit: %v", circuit) + circuit4.Outgoing = &keystone4.OutKey + if err := circuitMap.OpenCircuits(keystone4); err != nil { + t.Fatalf("unable to add full circuit: %v", err) } - // Check lookups by payment hash. - circuits := circuitMap.LookupByPaymentHash(hash1) - if len(circuits) != 2 { - t.Fatalf("LookupByPaymentHash returned wrong number of circuits for "+ - "hash1: expected %d, got %d", 2, len(circuits)) - } + // Verify that all circuits have been fully added. + assertHasCircuit(t, circuitMap, circuit1) + assertHasKeystone(t, circuitMap, keystone1.OutKey, circuit1) + assertHasCircuit(t, circuitMap, circuit2) + assertHasKeystone(t, circuitMap, keystone2.OutKey, circuit2) + assertHasCircuit(t, circuitMap, circuit3) + assertHasKeystone(t, circuitMap, keystone3.OutKey, circuit3) + assertHasCircuit(t, circuitMap, circuit4) + assertHasKeystone(t, circuitMap, keystone4.OutKey, circuit4) - circuits = circuitMap.LookupByPaymentHash(hash2) - if len(circuits) != 1 { - t.Fatalf("LookupByPaymentHash returned wrong number of circuits for "+ - "hash2: expected %d, got %d", 1, len(circuits)) - } + // Verify that each circuit is exposed via the proper hash bucketing. + assertNumCircuitsWithHash(t, circuitMap, hash1, 2) + assertHasCircuitForHash(t, circuitMap, hash1, circuit1) + assertHasCircuitForHash(t, circuitMap, hash1, circuit4) + + assertNumCircuitsWithHash(t, circuitMap, hash2, 1) + assertHasCircuitForHash(t, circuitMap, hash2, circuit2) + + assertNumCircuitsWithHash(t, circuitMap, hash3, 1) + assertHasCircuitForHash(t, circuitMap, hash3, circuit3) + + // Restart, then run checks again. + cdb, circuitMap = restartCircuitMap(t, cdb) + + // Verify that all circuits have been fully added. + assertHasCircuit(t, circuitMap, circuit1) + assertHasKeystone(t, circuitMap, keystone1.OutKey, circuit1) + assertHasCircuit(t, circuitMap, circuit2) + assertHasKeystone(t, circuitMap, keystone2.OutKey, circuit2) + assertHasCircuit(t, circuitMap, circuit3) + assertHasKeystone(t, circuitMap, keystone3.OutKey, circuit3) + assertHasCircuit(t, circuitMap, circuit4) + assertHasKeystone(t, circuitMap, keystone4.OutKey, circuit4) + + // Verify that each circuit is exposed via the proper hash bucketing. + assertNumCircuitsWithHash(t, circuitMap, hash1, 2) + assertHasCircuitForHash(t, circuitMap, hash1, circuit1) + assertHasCircuitForHash(t, circuitMap, hash1, circuit4) + + assertNumCircuitsWithHash(t, circuitMap, hash2, 1) + assertHasCircuitForHash(t, circuitMap, hash2, circuit2) + + assertNumCircuitsWithHash(t, circuitMap, hash3, 1) + assertHasCircuitForHash(t, circuitMap, hash3, circuit3) // Test removing circuits and the subsequent lookups. - err := circuitMap.Remove(chan1, 0) + err = circuitMap.DeleteCircuits(circuit1.Incoming) if err != nil { t.Fatalf("Remove returned unexpected error: %v", err) } - circuits = circuitMap.LookupByPaymentHash(hash1) - if len(circuits) != 1 { - t.Fatalf("LookupByPaymentHash returned wrong number of circuits for "+ - "hash1: expecected %d, got %d", 1, len(circuits)) - } - if circuits[0].OutgoingHTLCID != 3 { - t.Fatalf("LookupByPaymentHash returned wrong circuit for hash1: %v", - circuits[0]) - } + // There should be exactly one remaining circuit with hash1, and it + // should be circuit4. + assertNumCircuitsWithHash(t, circuitMap, hash1, 1) + assertHasCircuitForHash(t, circuitMap, hash1, circuit4) + cdb, circuitMap = restartCircuitMap(t, cdb) + assertNumCircuitsWithHash(t, circuitMap, hash1, 1) + assertHasCircuitForHash(t, circuitMap, hash1, circuit4) // Removing already-removed circuit should return an error. - err = circuitMap.Remove(chan1, 0) + err = circuitMap.DeleteCircuits(circuit1.Incoming) if err == nil { t.Fatal("Remove did not return expected not found error") } + // Verify that nothing related to hash1 has changed + assertNumCircuitsWithHash(t, circuitMap, hash1, 1) + assertHasCircuitForHash(t, circuitMap, hash1, circuit4) + // Remove last remaining circuit with payment hash hash1. - err = circuitMap.Remove(chan1, 3) + err = circuitMap.DeleteCircuits(circuit4.Incoming) if err != nil { t.Fatalf("Remove returned unexpected error: %v", err) } - circuits = circuitMap.LookupByPaymentHash(hash1) - if len(circuits) != 0 { - t.Fatalf("LookupByPaymentHash returned wrong number of circuits for "+ - "hash1: expecected %d, got %d", 0, len(circuits)) + assertNumCircuitsWithHash(t, circuitMap, hash1, 0) + assertNumCircuitsWithHash(t, circuitMap, hash2, 1) + assertNumCircuitsWithHash(t, circuitMap, hash3, 1) + cdb, circuitMap = restartCircuitMap(t, cdb) + assertNumCircuitsWithHash(t, circuitMap, hash1, 0) + assertNumCircuitsWithHash(t, circuitMap, hash2, 1) + assertNumCircuitsWithHash(t, circuitMap, hash3, 1) + + // Remove last remaining circuit with payment hash hash2. + err = circuitMap.DeleteCircuits(circuit2.Incoming) + if err != nil { + t.Fatalf("Remove returned unexpected error: %v", err) + } + + // There should now only be one remaining circuit, with hash3. + assertNumCircuitsWithHash(t, circuitMap, hash2, 0) + assertNumCircuitsWithHash(t, circuitMap, hash3, 1) + cdb, circuitMap = restartCircuitMap(t, cdb) + assertNumCircuitsWithHash(t, circuitMap, hash2, 0) + assertNumCircuitsWithHash(t, circuitMap, hash3, 1) + + // Remove last remaining circuit with payment hash hash3. + err = circuitMap.DeleteCircuits(circuit3.Incoming) + if err != nil { + t.Fatalf("Remove returned unexpected error: %v", err) + } + + // Check that the circuit map is empty, even after restarting. + assertNumCircuitsWithHash(t, circuitMap, hash3, 0) + cdb, circuitMap = restartCircuitMap(t, cdb) + assertNumCircuitsWithHash(t, circuitMap, hash3, 0) +} + +// assertHasKeystone tests that the circuit map contains the provided payment +// circuit. +func assertHasKeystone(t *testing.T, cm htlcswitch.CircuitMap, + outKey htlcswitch.CircuitKey, c *htlcswitch.PaymentCircuit) { + + circuit := cm.LookupOpenCircuit(outKey) + if !equalIgnoreLFD(circuit, c) { + t.Fatalf("unexpected circuit, want: %v, got %v", c, circuit) + } +} + +// assertDoesNotHaveKeystone tests that the circuit map does not contain a +// circuit for the provided outgoing circuit key. +func assertDoesNotHaveKeystone(t *testing.T, cm htlcswitch.CircuitMap, + outKey htlcswitch.CircuitKey) { + + circuit := cm.LookupOpenCircuit(outKey) + if circuit != nil { + t.Fatalf("expected no circuit for keystone %s, found %v", + outKey, circuit) + } +} + +// assertHasCircuitForHash tests that the provided circuit appears in the list +// of circuits for the given hash. +func assertHasCircuitForHash(t *testing.T, cm htlcswitch.CircuitMap, hash [32]byte, + circuit *htlcswitch.PaymentCircuit) { + + circuits := cm.LookupByPaymentHash(hash) + for _, c := range circuits { + if equalIgnoreLFD(c, circuit) { + return + } + } + + t.Fatalf("unable to find circuit: %v by hash: %v", circuit, hash) +} + +// assertNumCircuitsWithHash tests that the circuit has the right number of full +// circuits, indexed by the given hash. +func assertNumCircuitsWithHash(t *testing.T, cm htlcswitch.CircuitMap, + hash [32]byte, expectedNum int) { + + circuits := cm.LookupByPaymentHash(hash) + if len(circuits) != expectedNum { + t.Fatalf("LookupByPaymentHash returned wrong number of circuits for "+ + "hash=%v: expecected %d, got %d", hash, expectedNum, + len(circuits)) + } +} + +// assertHasCircuit queries the circuit map using the half-circuit's half +// key, and fails if the returned half-circuit differs from the provided one. +func assertHasCircuit(t *testing.T, cm htlcswitch.CircuitMap, + c *htlcswitch.PaymentCircuit) { + + c2 := cm.LookupCircuit(c.Incoming) + if !equalIgnoreLFD(c, c2) { + t.Fatalf("expected circuit: %v, got %v", c, c2) + } +} + +// equalIgnoreLFD compares two payment circuits, but ignores the current value +// of LoadedFromDisk. The value is temporarily set to false for the comparison +// and then restored. +func equalIgnoreLFD(c, c2 *htlcswitch.PaymentCircuit) bool { + ogLFD := c.LoadedFromDisk + ogLFD2 := c2.LoadedFromDisk + + c.LoadedFromDisk = false + c2.LoadedFromDisk = false + + isEqual := reflect.DeepEqual(c, c2) + + c.LoadedFromDisk = ogLFD + c2.LoadedFromDisk = ogLFD2 + + return isEqual +} + +// assertDoesNotHaveCircuit queries the circuit map using the circuit's +// incoming circuit key, and fails if it is found. +func assertDoesNotHaveCircuit(t *testing.T, cm htlcswitch.CircuitMap, + c *htlcswitch.PaymentCircuit) { + + c2 := cm.LookupCircuit(c.Incoming) + if c2 != nil { + t.Fatalf("expected no circuit for %v, got %v", c, c2) + } +} + +// makeCircuitDB initializes a new test channeldb for testing the persistence of +// the circuit map. If an empty string is provided as a path, a temp directory +// will be created. +func makeCircuitDB(t *testing.T, path string) *channeldb.DB { + if path == "" { + var err error + path, err = ioutil.TempDir("", "circuitdb") + if err != nil { + t.Fatalf("unable to create temp path: %v", err) + } + } + + db, err := channeldb.Open(path) + if err != nil { + t.Fatalf("unable to open channel db: %v", err) + } + + return db +} + +// Creates a new circuit map, backed by a freshly opened channeldb. The existing +// channeldb is closed in order to simulate a complete restart. +func restartCircuitMap(t *testing.T, cdb *channeldb.DB) (*channeldb.DB, + htlcswitch.CircuitMap) { + + // Record the current temp path and close current db. + dbPath := cdb.Path() + cdb.Close() + + // Reinitialize circuit map with same db path. + cdb2 := makeCircuitDB(t, dbPath) + cm2, err := htlcswitch.NewCircuitMap(cdb2) + if err != nil { + t.Fatalf("unable to recreate persistent circuit map: %v", err) + } + + return cdb2, cm2 +} + +// TestCircuitMapCommitCircuits tests the following behavior of CommitCircuits: +// 1. New circuits are successfully added. +// 2. Duplicate circuits are dropped anytime before circuit map shutsdown. +// 3. Duplicate circuits are failed anytime after circuit map restarts. +func TestCircuitMapCommitCircuits(t *testing.T) { + t.Parallel() + + var ( + chan1 = lnwire.NewShortChanIDFromInt(1) + circuitMap htlcswitch.CircuitMap + err error + ) + + cdb := makeCircuitDB(t, "") + circuitMap, err = htlcswitch.NewCircuitMap(cdb) + if err != nil { + t.Fatalf("unable to create persistent circuit map: %v", err) + } + + circuit := &htlcswitch.PaymentCircuit{ + Incoming: htlcswitch.CircuitKey{ + ChanID: chan1, + HtlcID: 3, + }, + ErrorEncrypter: htlcswitch.NewSphinxErrorEncrypter(), + } + + // First we will try to add an new circuit to the circuit map, this + // should succeed. + actions, err := circuitMap.CommitCircuits(circuit) + if err != nil { + t.Fatalf("failed to commit circuits: %v", err) + } + if len(actions.Drops) > 0 { + t.Fatalf("new circuit should not have been dropped") + } + if len(actions.Fails) > 0 { + t.Fatalf("new circuit should not have failed") + } + if len(actions.Adds) != 1 { + t.Fatalf("only one circuit should have been added, found %d", + len(actions.Adds)) + } + + circuit2 := circuitMap.LookupCircuit(circuit.Incoming) + if !reflect.DeepEqual(circuit, circuit2) { + t.Fatalf("unexpected committed circuit: got %v, want %v", + circuit2, circuit) + } + + // Then we will try to readd the same circuit again, this should result + // in the circuit being dropped. This can happen if the incoming link + // flaps. + actions, err = circuitMap.CommitCircuits(circuit) + if err != nil { + t.Fatalf("failed to commit circuits: %v", err) + } + if len(actions.Adds) > 0 { + t.Fatalf("duplicate circuit should not have been added to circuit map") + } + if len(actions.Fails) > 0 { + t.Fatalf("duplicate circuit should not have failed") + } + if len(actions.Drops) != 1 { + t.Fatalf("only one circuit should have been dropped, found %d", + len(actions.Drops)) + } + + // Finally, restart the circuit map, which will cause the added circuit + // to be loaded from disk. Since the keystone was never set, subsequent + // attempts to commit the circuit should cause the circuit map to + // indicate that that the HTLC should be failed back. + cdb, circuitMap = restartCircuitMap(t, cdb) + + actions, err = circuitMap.CommitCircuits(circuit) + if err != nil { + t.Fatalf("failed to commit circuits: %v", err) + } + if len(actions.Adds) > 0 { + t.Fatalf("duplicate circuit with incomplete forwarding " + + "decision should not have been added to circuit map") + } + if len(actions.Drops) > 0 { + t.Fatalf("duplicate circuit with incomplete forwarding " + + "decision should not have been dropped by circuit map") + } + if len(actions.Fails) != 1 { + t.Fatalf("only one duplicate circuit with incomplete "+ + "forwarding decision should have been failed, found: "+ + "%d", len(actions.Fails)) + } + + // Lookup the committed circuit again, it should be identical apart from + // the loaded from disk flag. + circuit2 = circuitMap.LookupCircuit(circuit.Incoming) + if !equalIgnoreLFD(circuit, circuit2) { + t.Fatalf("unexpected committed circuit: got %v, want %v", + circuit2, circuit) + } +} + +// TestCircuitMapOpenCircuits checks that circuits are properly opened, and that +// duplicate attempts to open a circuit will result in an error. +func TestCircuitMapOpenCircuits(t *testing.T) { + t.Parallel() + + var ( + chan1 = lnwire.NewShortChanIDFromInt(1) + chan2 = lnwire.NewShortChanIDFromInt(2) + circuitMap htlcswitch.CircuitMap + err error + ) + + cdb := makeCircuitDB(t, "") + circuitMap, err = htlcswitch.NewCircuitMap(cdb) + if err != nil { + t.Fatalf("unable to create persistent circuit map: %v", err) + } + + circuit := &htlcswitch.PaymentCircuit{ + Incoming: htlcswitch.CircuitKey{ + ChanID: chan1, + HtlcID: 3, + }, + ErrorEncrypter: htlcswitch.NewSphinxErrorEncrypter(), + } + + // First we will try to add an new circuit to the circuit map, this + // should succeed. + _, err = circuitMap.CommitCircuits(circuit) + if err != nil { + t.Fatalf("failed to commit circuits: %v", err) + } + + keystone := htlcswitch.Keystone{ + InKey: circuit.Incoming, + OutKey: htlcswitch.CircuitKey{ + ChanID: chan2, + HtlcID: 2, + }, + } + + // Open the circuit for the first time. + err = circuitMap.OpenCircuits(keystone) + if err != nil { + t.Fatalf("failed to open circuits: %v", err) + } + + // Check that we can retrieve the open circuit if the circuit map before + // the circuit map is restarted. + circuit2 := circuitMap.LookupOpenCircuit(keystone.OutKey) + if !reflect.DeepEqual(circuit, circuit2) { + t.Fatalf("unexpected open circuit: got %v, want %v", + circuit2, circuit) + } + + if !circuit2.HasKeystone() { + t.Fatalf("open circuit should have keystone") + } + if !reflect.DeepEqual(&keystone.OutKey, circuit2.Outgoing) { + t.Fatalf("expected open circuit to have outgoing key: %v, found %v", + &keystone.OutKey, circuit2.Outgoing) + } + + // Open the circuit for a second time, which should fail due to a + // duplicate keystone + err = circuitMap.OpenCircuits(keystone) + if err != htlcswitch.ErrDuplicateKeystone { + t.Fatalf("failed to open circuits: %v", err) + } + + // Then we will try to readd the same circuit again, this should result + // in the circuit being dropped. This can happen if the incoming link + // flaps OR the switch is entirely restarted and the outgoing link has + // not received a response. + actions, err := circuitMap.CommitCircuits(circuit) + if err != nil { + t.Fatalf("failed to commit circuits: %v", err) + } + if len(actions.Adds) > 0 { + t.Fatalf("duplicate circuit should not have been added to circuit map") + } + if len(actions.Fails) > 0 { + t.Fatalf("duplicate circuit should not have failed") + } + if len(actions.Drops) != 1 { + t.Fatalf("only one circuit should have been dropped, found %d", + len(actions.Drops)) + } + + // Now, restart the circuit map, which will cause the opened circuit to + // be loaded from disk. Since we set the keystone on this circuit, it + // should be restored as such in memory. + // + // NOTE: The channel db doesn't have any channel data, so no keystones + // will be trimmed. + cdb, circuitMap = restartCircuitMap(t, cdb) + + // Check that we can still query for the open circuit. + circuit2 = circuitMap.LookupOpenCircuit(keystone.OutKey) + if !equalIgnoreLFD(circuit, circuit2) { + t.Fatalf("unexpected open circuit: got %v, want %v", + circuit2, circuit) + } + + // Try to open the circuit again, we expect this to fail since the open + // circuit was restored. + err = circuitMap.OpenCircuits(keystone) + if err != htlcswitch.ErrDuplicateKeystone { + t.Fatalf("failed to open circuits: %v", err) + } + + // Lastly, with the circuit map restarted, try one more time to recommit + // the open circuit. This should be dropped, and is expected to happen + // if the incoming link flaps OR the switch is entirely restarted and + // the outgoing link has not received a response. + actions, err = circuitMap.CommitCircuits(circuit) + if err != nil { + t.Fatalf("failed to commit circuits: %v", err) + } + if len(actions.Adds) > 0 { + t.Fatalf("duplicate circuit should not have been added to circuit map") + } + if len(actions.Fails) > 0 { + t.Fatalf("duplicate circuit should not have failed") + } + if len(actions.Drops) != 1 { + t.Fatalf("only one circuit should have been dropped, found %d", + len(actions.Drops)) + } +} + +func assertCircuitsOpenedPreRestart(t *testing.T, + circuitMap htlcswitch.CircuitMap, + circuits []*htlcswitch.PaymentCircuit, + keystones []htlcswitch.Keystone) { + + for i, circuit := range circuits { + keystone := keystones[i] + + openCircuit := circuitMap.LookupOpenCircuit(keystone.OutKey) + if !reflect.DeepEqual(circuit, openCircuit) { + t.Fatalf("unexpected open circuit %d: got %v, want %v", + i, openCircuit, circuit) + } + + if !openCircuit.HasKeystone() { + t.Fatalf("open circuit %d should have keystone", i) + } + if !reflect.DeepEqual(&keystone.OutKey, openCircuit.Outgoing) { + t.Fatalf("expected open circuit %d to have outgoing "+ + "key: %v, found %v", i, + &keystone.OutKey, openCircuit.Outgoing) + } + } +} + +func assertCircuitsOpenedPostRestart(t *testing.T, + circuitMap htlcswitch.CircuitMap, + circuits []*htlcswitch.PaymentCircuit, + keystones []htlcswitch.Keystone) { + + for i, circuit := range circuits { + keystone := keystones[i] + + openCircuit := circuitMap.LookupOpenCircuit(keystone.OutKey) + if !equalIgnoreLFD(circuit, openCircuit) { + t.Fatalf("unexpected open circuit %d: got %v, want %v", + i, openCircuit, circuit) + } + + if !openCircuit.HasKeystone() { + t.Fatalf("open circuit %d should have keystone", i) + } + if !reflect.DeepEqual(&keystone.OutKey, openCircuit.Outgoing) { + t.Fatalf("expected open circuit %d to have outgoing "+ + "key: %v, found %v", i, + &keystone.OutKey, openCircuit.Outgoing) + } + } +} + +func assertCircuitsNotOpenedPreRestart(t *testing.T, + circuitMap htlcswitch.CircuitMap, + circuits []*htlcswitch.PaymentCircuit, + keystones []htlcswitch.Keystone, + offset int) { + + for i := range circuits { + keystone := keystones[i] + + openCircuit := circuitMap.LookupOpenCircuit(keystone.OutKey) + if openCircuit != nil { + t.Fatalf("expected circuit %d not to be open", + offset+i) + } + + circuit := circuitMap.LookupCircuit(keystone.InKey) + if circuit == nil { + t.Fatalf("expected to find unopened circuit %d", + offset+i) + } + if circuit.HasKeystone() { + t.Fatalf("circuit %d should not have keystone", + offset+i) + } + } +} + +// TestCircuitMapTrimOpenCircuits verifies that the circuit map properly removes +// circuits from disk and the in-memory state when TrimOpenCircuits is used. +// This test checks that a successful trim survives a restart, and that circuits +// added before the restart can also be trimmed. +func TestCircuitMapTrimOpenCircuits(t *testing.T) { + t.Parallel() + + var ( + chan1 = lnwire.NewShortChanIDFromInt(1) + chan2 = lnwire.NewShortChanIDFromInt(2) + circuitMap htlcswitch.CircuitMap + err error + ) + + cdb := makeCircuitDB(t, "") + circuitMap, err = htlcswitch.NewCircuitMap(cdb) + if err != nil { + t.Fatalf("unable to create persistent circuit map: %v", err) + } + + const nCircuits = 10 + const firstTrimIndex = 7 + const secondTrimIndex = 3 + + // Create a list of all circuits that will be committed in the circuit + // map. The incoming HtlcIDs are chosen so that there is overlap with + // the outgoing HtlcIDs, but ensures that the test is not dependent on + // them being equal. + circuits := make([]*htlcswitch.PaymentCircuit, nCircuits) + for i := range circuits { + circuits[i] = &htlcswitch.PaymentCircuit{ + Incoming: htlcswitch.CircuitKey{ + ChanID: chan1, + HtlcID: uint64(i + 3), + }, + ErrorEncrypter: htlcswitch.NewSphinxErrorEncrypter(), + } + } + + // First we will try to add an new circuit to the circuit map, this + // should succeed. + _, err = circuitMap.CommitCircuits(circuits...) + if err != nil { + t.Fatalf("failed to commit circuits: %v", err) + } + + // Now create a list of the keystones that we will use to preemptively + // open the circuits. We set the index as the outgoing HtlcID to i + // simplify the indexing logic of the test. + keystones := make([]htlcswitch.Keystone, nCircuits) + for i := range keystones { + keystones[i] = htlcswitch.Keystone{ + InKey: circuits[i].Incoming, + OutKey: htlcswitch.CircuitKey{ + ChanID: chan2, + HtlcID: uint64(i), + }, + } + } + + // Open the circuits for the first time. + err = circuitMap.OpenCircuits(keystones...) + if err != nil { + t.Fatalf("failed to open circuits: %v", err) + } + + // Check that all circuits are marked open. + assertCircuitsOpenedPreRestart(t, circuitMap, circuits, keystones) + + // Now trim up above outgoing htlcid `firstTrimIndex` (7). This should + // leave the first 7 circuits open, and the rest should be reverted to + // an unopened state. + err = circuitMap.TrimOpenCircuits(chan2, firstTrimIndex) + if err != nil { + t.Fatalf("unable to trim circuits") + } + + assertCircuitsOpenedPreRestart(t, + circuitMap, + circuits[:firstTrimIndex], + keystones[:firstTrimIndex], + ) + + assertCircuitsNotOpenedPreRestart( + t, + circuitMap, + circuits[firstTrimIndex:], + keystones[firstTrimIndex:], + firstTrimIndex, + ) + + // Restart the circuit map, verify that that the trim is reflected on + // startup. + cdb, circuitMap = restartCircuitMap(t, cdb) + + assertCircuitsOpenedPostRestart( + t, + circuitMap, + circuits[:firstTrimIndex], + keystones[:firstTrimIndex], + ) + + assertCircuitsNotOpenedPreRestart( + t, + circuitMap, + circuits[firstTrimIndex:], + keystones[firstTrimIndex:], + firstTrimIndex, + ) + + // Now, trim above outgoing htlcid `secondTrimIndex` (3). Only the first + // three circuits should be open, with any others being reverted back to + // unopened. + err = circuitMap.TrimOpenCircuits(chan2, secondTrimIndex) + if err != nil { + t.Fatalf("unable to trim circuits") + } + + assertCircuitsOpenedPostRestart( + t, + circuitMap, + circuits[:secondTrimIndex], + keystones[:secondTrimIndex], + ) + + assertCircuitsNotOpenedPreRestart( + t, + circuitMap, + circuits[secondTrimIndex:], + keystones[secondTrimIndex:], + secondTrimIndex, + ) + + // Restart the circuit map one last time to make sure the changes are + // persisted. + cdb, circuitMap = restartCircuitMap(t, cdb) + + assertCircuitsOpenedPostRestart( + t, + circuitMap, + circuits[:secondTrimIndex], + keystones[:secondTrimIndex], + ) + + assertCircuitsNotOpenedPreRestart( + t, + circuitMap, + circuits[secondTrimIndex:], + keystones[secondTrimIndex:], + secondTrimIndex, + ) +} + +// TestCircuitMapCloseOpenCircuits asserts that the circuit map can properly +// close open circuits, and that it allows at most one response to do so +// successfully. It also checks that a circuit is reopened if the close was not +// persisted via DeleteCircuits, and can again be closed. +func TestCircuitMapCloseOpenCircuits(t *testing.T) { + t.Parallel() + + var ( + chan1 = lnwire.NewShortChanIDFromInt(1) + chan2 = lnwire.NewShortChanIDFromInt(2) + circuitMap htlcswitch.CircuitMap + err error + ) + + cdb := makeCircuitDB(t, "") + circuitMap, err = htlcswitch.NewCircuitMap(cdb) + if err != nil { + t.Fatalf("unable to create persistent circuit map: %v", err) + } + + circuit := &htlcswitch.PaymentCircuit{ + Incoming: htlcswitch.CircuitKey{ + ChanID: chan1, + HtlcID: 3, + }, + ErrorEncrypter: htlcswitch.NewSphinxErrorEncrypter(), + } + + // First we will try to add an new circuit to the circuit map, this + // should succeed. + _, err = circuitMap.CommitCircuits(circuit) + if err != nil { + t.Fatalf("failed to commit circuits: %v", err) + } + + keystone := htlcswitch.Keystone{ + InKey: circuit.Incoming, + OutKey: htlcswitch.CircuitKey{ + ChanID: chan2, + HtlcID: 2, + }, + } + + // Open the circuit for the first time. + err = circuitMap.OpenCircuits(keystone) + if err != nil { + t.Fatalf("failed to open circuits: %v", err) + } + + // Check that we can retrieve the open circuit if the circuit map before + // the circuit map is restarted. + circuit2 := circuitMap.LookupOpenCircuit(keystone.OutKey) + if !reflect.DeepEqual(circuit, circuit2) { + t.Fatalf("unexpected open circuit: got %v, want %v", + circuit2, circuit) + } + + // Open the circuit for a second time, which should fail due to a + // duplicate keystone + err = circuitMap.OpenCircuits(keystone) + if err != htlcswitch.ErrDuplicateKeystone { + t.Fatalf("failed to open circuits: %v", err) + } + + // Close the open circuit for the first time, which should succeed. + _, err = circuitMap.FailCircuit(circuit.Incoming) + if err != nil { + t.Fatalf("unable to close unopened circuit") + } + + // Closing the circuit a second time should result in a failure. + _, err = circuitMap.FailCircuit(circuit.Incoming) + if err != htlcswitch.ErrCircuitClosing { + t.Fatalf("unable to close unopened circuit") + } + + // Now, restart the circuit map, which will cause the opened circuit to + // be loaded from disk. Since we set the keystone on this circuit, it + // should be restored as such in memory. + // + // NOTE: The channel db doesn't have any channel data, so no keystones + // will be trimmed. + cdb, circuitMap = restartCircuitMap(t, cdb) + + // Close the open circuit for the first time, which should succeed. + _, err = circuitMap.FailCircuit(circuit.Incoming) + if err != nil { + t.Fatalf("unable to close unopened circuit") + } + + // Closing the circuit a second time should result in a failure. + _, err = circuitMap.FailCircuit(circuit.Incoming) + if err != htlcswitch.ErrCircuitClosing { + t.Fatalf("unable to close unopened circuit") + } +} + +// TestCircuitMapCloseUnopenedCircuit tests that closing an unopened circuit +// allows at most semantics, and that the close is not persisted across +// restarts. +func TestCircuitMapCloseUnopenedCircuit(t *testing.T) { + t.Parallel() + + var ( + chan1 = lnwire.NewShortChanIDFromInt(1) + circuitMap htlcswitch.CircuitMap + err error + ) + + cdb := makeCircuitDB(t, "") + circuitMap, err = htlcswitch.NewCircuitMap(cdb) + if err != nil { + t.Fatalf("unable to create persistent circuit map: %v", err) + } + + circuit := &htlcswitch.PaymentCircuit{ + Incoming: htlcswitch.CircuitKey{ + ChanID: chan1, + HtlcID: 3, + }, + ErrorEncrypter: htlcswitch.NewSphinxErrorEncrypter(), + } + + // First we will try to add an new circuit to the circuit map, this + // should succeed. + _, err = circuitMap.CommitCircuits(circuit) + if err != nil { + t.Fatalf("failed to commit circuits: %v", err) + } + + // Close the open circuit for the first time, which should succeed. + _, err = circuitMap.FailCircuit(circuit.Incoming) + if err != nil { + t.Fatalf("unable to close unopened circuit") + } + + // Closing the circuit a second time should result in a failure. + _, err = circuitMap.FailCircuit(circuit.Incoming) + if err != htlcswitch.ErrCircuitClosing { + t.Fatalf("unable to close unopened circuit") + } + + // Now, restart the circuit map, which will result in the circuit being + // reopened, since no attempt to delete the circuit was made. + cdb, circuitMap = restartCircuitMap(t, cdb) + + // Close the open circuit for the first time, which should succeed. + _, err = circuitMap.FailCircuit(circuit.Incoming) + if err != nil { + t.Fatalf("unable to close unopened circuit") + } + + // Closing the circuit a second time should result in a failure. + _, err = circuitMap.FailCircuit(circuit.Incoming) + if err != htlcswitch.ErrCircuitClosing { + t.Fatalf("unable to close unopened circuit") + } +} + +// TestCircuitMapDeleteUnopenedCircuit checks that an unopened circuit can be +// removed persistently from the circuit map. +func TestCircuitMapDeleteUnopenedCircuit(t *testing.T) { + t.Parallel() + + var ( + chan1 = lnwire.NewShortChanIDFromInt(1) + circuitMap htlcswitch.CircuitMap + err error + ) + + cdb := makeCircuitDB(t, "") + circuitMap, err = htlcswitch.NewCircuitMap(cdb) + if err != nil { + t.Fatalf("unable to create persistent circuit map: %v", err) + } + + circuit := &htlcswitch.PaymentCircuit{ + Incoming: htlcswitch.CircuitKey{ + ChanID: chan1, + HtlcID: 3, + }, + ErrorEncrypter: htlcswitch.NewSphinxErrorEncrypter(), + } + + // First we will try to add an new circuit to the circuit map, this + // should succeed. + _, err = circuitMap.CommitCircuits(circuit) + if err != nil { + t.Fatalf("failed to commit circuits: %v", err) + } + + // Close the open circuit for the first time, which should succeed. + _, err = circuitMap.FailCircuit(circuit.Incoming) + if err != nil { + t.Fatalf("unable to close unopened circuit") + } + + err = circuitMap.DeleteCircuits(circuit.Incoming) + if err != nil { + t.Fatalf("unable to close unopened circuit") + } + + // Check that we can retrieve the open circuit if the circuit map before + // the circuit map is restarted. + circuit2 := circuitMap.LookupCircuit(circuit.Incoming) + if circuit2 != nil { + t.Fatalf("unexpected open circuit: got %v, want %v", + circuit2, nil) + } + + // Now, restart the circuit map, and check that the deletion survived + // the restart. + cdb, circuitMap = restartCircuitMap(t, cdb) + + circuit2 = circuitMap.LookupCircuit(circuit.Incoming) + if circuit2 != nil { + t.Fatalf("unexpected open circuit: got %v, want %v", + circuit2, nil) + } +} + +// TestCircuitMapDeleteUnopenedCircuit checks that an open circuit can be +// removed persistently from the circuit map. +func TestCircuitMapDeleteOpenCircuit(t *testing.T) { + t.Parallel() + + var ( + chan1 = lnwire.NewShortChanIDFromInt(1) + chan2 = lnwire.NewShortChanIDFromInt(2) + circuitMap htlcswitch.CircuitMap + err error + ) + + cdb := makeCircuitDB(t, "") + circuitMap, err = htlcswitch.NewCircuitMap(cdb) + if err != nil { + t.Fatalf("unable to create persistent circuit map: %v", err) + } + + circuit := &htlcswitch.PaymentCircuit{ + Incoming: htlcswitch.CircuitKey{ + ChanID: chan1, + HtlcID: 3, + }, + ErrorEncrypter: htlcswitch.NewSphinxErrorEncrypter(), + } + + // First we will try to add an new circuit to the circuit map, this + // should succeed. + _, err = circuitMap.CommitCircuits(circuit) + if err != nil { + t.Fatalf("failed to commit circuits: %v", err) + } + + keystone := htlcswitch.Keystone{ + InKey: circuit.Incoming, + OutKey: htlcswitch.CircuitKey{ + ChanID: chan2, + HtlcID: 2, + }, + } + + // Open the circuit for the first time. + err = circuitMap.OpenCircuits(keystone) + if err != nil { + t.Fatalf("failed to open circuits: %v", err) + } + + // Close the open circuit for the first time, which should succeed. + _, err = circuitMap.FailCircuit(circuit.Incoming) + if err != nil { + t.Fatalf("unable to close unopened circuit") + } + + // Persistently remove the circuit identified by incoming chan id. + err = circuitMap.DeleteCircuits(circuit.Incoming) + if err != nil { + t.Fatalf("unable to close unopened circuit") + } + + // Check that we can no longer retrieve the open circuit. + circuit2 := circuitMap.LookupOpenCircuit(keystone.OutKey) + if circuit2 != nil { + t.Fatalf("unexpected open circuit: got %v, want %v", + circuit2, nil) + } + + // Now, restart the circuit map, and check that the deletion survived + // the restart. + cdb, circuitMap = restartCircuitMap(t, cdb) + + circuit2 = circuitMap.LookupOpenCircuit(keystone.OutKey) + if circuit2 != nil { + t.Fatalf("unexpected open circuit: got %v, want %v", + circuit2, nil) } }