diff --git a/htlcswitch/test_utils.go b/htlcswitch/test_utils.go index b79fe6f5..897d4448 100644 --- a/htlcswitch/test_utils.go +++ b/htlcswitch/test_utils.go @@ -585,9 +585,7 @@ type threeHopNetwork struct { carolServer *mockServer carolChannelLink *channelLink - feeEstimator *mockFeeEstimator - - globalPolicy ForwardingPolicy + hopNetwork } // generateHops creates the per hop payload, the total amount to be sent, and @@ -872,23 +870,23 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel, bobDb := firstBobChannel.State().Db carolDb := carolChannel.State().Db - defaultDelta := uint32(6) + hopNetwork := newHopNetwork() // Create three peers/servers. aliceServer, err := newMockServer( - t, "alice", startingHeight, aliceDb, defaultDelta, + t, "alice", startingHeight, aliceDb, hopNetwork.defaultDelta, ) if err != nil { t.Fatalf("unable to create alice server: %v", err) } bobServer, err := newMockServer( - t, "bob", startingHeight, bobDb, defaultDelta, + t, "bob", startingHeight, bobDb, hopNetwork.defaultDelta, ) if err != nil { t.Fatalf("unable to create bob server: %v", err) } carolServer, err := newMockServer( - t, "carol", startingHeight, carolDb, defaultDelta, + t, "carol", startingHeight, carolDb, hopNetwork.defaultDelta, ) if err != nil { t.Fatalf("unable to create carol server: %v", err) @@ -900,17 +898,78 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel, bobDecoder := newMockIteratorDecoder() carolDecoder := newMockIteratorDecoder() - feeEstimator := &mockFeeEstimator{ - byteFeeIn: make(chan lnwallet.SatPerKWeight), - quit: make(chan struct{}), + aliceChannelLink, err := hopNetwork.createChannelLink(aliceServer, + bobServer, aliceChannel, aliceDecoder, + ) + if err != nil { + t.Fatal(err) } - const ( - batchTimeout = 50 * time.Millisecond - fwdPkgTimeout = 15 * time.Second - minFeeUpdateTimeout = 30 * time.Minute - maxFeeUpdateTimeout = 40 * time.Minute - ) + firstBobChannelLink, err := hopNetwork.createChannelLink(bobServer, + aliceServer, firstBobChannel, bobDecoder) + if err != nil { + t.Fatal(err) + } + + secondBobChannelLink, err := hopNetwork.createChannelLink(bobServer, + carolServer, secondBobChannel, bobDecoder) + if err != nil { + t.Fatal(err) + } + + carolChannelLink, err := hopNetwork.createChannelLink(carolServer, + bobServer, carolChannel, carolDecoder) + if err != nil { + t.Fatal(err) + } + + return &threeHopNetwork{ + aliceServer: aliceServer, + aliceChannelLink: aliceChannelLink.(*channelLink), + + bobServer: bobServer, + firstBobChannelLink: firstBobChannelLink.(*channelLink), + secondBobChannelLink: secondBobChannelLink.(*channelLink), + + carolServer: carolServer, + carolChannelLink: carolChannelLink.(*channelLink), + + hopNetwork: *hopNetwork, + } +} + +// createTwoClusterChannels creates lightning channels which are needed for +// a 2 hop network cluster to be initialized. +func createTwoClusterChannels(aliceToBob, bobToCarol btcutil.Amount) ( + *lnwallet.LightningChannel, *lnwallet.LightningChannel, + func(), error) { + + _, _, firstChanID, _ := genIDs() + + // Create lightning channels between Alice<->Bob and Bob<->Carol + aliceChannel, firstBobChannel, cleanAliceBob, _, err := + createTestChannel(alicePrivKey, bobPrivKey, aliceToBob, + aliceToBob, 0, 0, firstChanID) + if err != nil { + return nil, nil, nil, errors.Errorf("unable to create "+ + "alice<->bob channel: %v", err) + } + + return aliceChannel, firstBobChannel, cleanAliceBob, nil +} + +// hopNetwork is the base struct for two and three hop networks +type hopNetwork struct { + feeEstimator *mockFeeEstimator + globalPolicy ForwardingPolicy + obfuscator ErrorEncrypter + pCache *mockPreimageCache + + defaultDelta uint32 +} + +func newHopNetwork() *hopNetwork { + defaultDelta := uint32(6) pCache := &mockPreimageCache{ // hash -> preimage @@ -924,190 +983,73 @@ func newThreeHopNetwork(t testing.TB, aliceChannel, firstBobChannel, } obfuscator := NewMockObfuscator() - aliceChannelLink := NewChannelLink( - ChannelLinkConfig{ - Switch: aliceServer.htlcSwitch, - FwrdingPolicy: globalPolicy, - Peer: bobServer, - Circuits: aliceServer.htlcSwitch.CircuitModifier(), - ForwardPackets: aliceServer.htlcSwitch.ForwardPackets, - DecodeHopIterators: aliceDecoder.DecodeHopIterators, - ExtractErrorEncrypter: func(*btcec.PublicKey) ( - ErrorEncrypter, lnwire.FailCode) { - return obfuscator, lnwire.CodeNone - }, - FetchLastChannelUpdate: mockGetChanUpdateMessage, - Registry: aliceServer.registry, - FeeEstimator: feeEstimator, - PreimageCache: pCache, - UpdateContractSignals: func(*contractcourt.ContractSignals) error { - return nil - }, - ChainEvents: &contractcourt.ChainEventSubscription{}, - SyncStates: true, - BatchSize: 10, - BatchTicker: ticker.MockNew(batchTimeout), - FwdPkgGCTicker: ticker.MockNew(fwdPkgTimeout), - MinFeeUpdateTimeout: minFeeUpdateTimeout, - MaxFeeUpdateTimeout: maxFeeUpdateTimeout, - OnChannelFailure: func(lnwire.ChannelID, lnwire.ShortChannelID, LinkFailureError) {}, - }, - aliceChannel, - ) - if err := aliceServer.htlcSwitch.AddLink(aliceChannelLink); err != nil { - t.Fatalf("unable to add alice channel link: %v", err) + feeEstimator := &mockFeeEstimator{ + byteFeeIn: make(chan lnwallet.SatPerKWeight), + quit: make(chan struct{}), } - go func() { - for { - select { - case <-aliceChannelLink.(*channelLink).htlcUpdates: - case <-aliceChannelLink.(*channelLink).quit: - return - } - } - }() - - firstBobChannelLink := NewChannelLink( - ChannelLinkConfig{ - Switch: bobServer.htlcSwitch, - FwrdingPolicy: globalPolicy, - Peer: aliceServer, - Circuits: bobServer.htlcSwitch.CircuitModifier(), - ForwardPackets: bobServer.htlcSwitch.ForwardPackets, - DecodeHopIterators: bobDecoder.DecodeHopIterators, - ExtractErrorEncrypter: func(*btcec.PublicKey) ( - ErrorEncrypter, lnwire.FailCode) { - return obfuscator, lnwire.CodeNone - }, - FetchLastChannelUpdate: mockGetChanUpdateMessage, - Registry: bobServer.registry, - FeeEstimator: feeEstimator, - PreimageCache: pCache, - UpdateContractSignals: func(*contractcourt.ContractSignals) error { - return nil - }, - ChainEvents: &contractcourt.ChainEventSubscription{}, - SyncStates: true, - BatchSize: 10, - BatchTicker: ticker.MockNew(batchTimeout), - FwdPkgGCTicker: ticker.MockNew(fwdPkgTimeout), - MinFeeUpdateTimeout: minFeeUpdateTimeout, - MaxFeeUpdateTimeout: maxFeeUpdateTimeout, - OnChannelFailure: func(lnwire.ChannelID, lnwire.ShortChannelID, LinkFailureError) {}, - }, - firstBobChannel, - ) - if err := bobServer.htlcSwitch.AddLink(firstBobChannelLink); err != nil { - t.Fatalf("unable to add first bob channel link: %v", err) - } - go func() { - for { - select { - case <-firstBobChannelLink.(*channelLink).htlcUpdates: - case <-firstBobChannelLink.(*channelLink).quit: - return - } - } - }() - - secondBobChannelLink := NewChannelLink( - ChannelLinkConfig{ - Switch: bobServer.htlcSwitch, - FwrdingPolicy: globalPolicy, - Peer: carolServer, - Circuits: bobServer.htlcSwitch.CircuitModifier(), - ForwardPackets: bobServer.htlcSwitch.ForwardPackets, - DecodeHopIterators: bobDecoder.DecodeHopIterators, - ExtractErrorEncrypter: func(*btcec.PublicKey) ( - ErrorEncrypter, lnwire.FailCode) { - return obfuscator, lnwire.CodeNone - }, - FetchLastChannelUpdate: mockGetChanUpdateMessage, - Registry: bobServer.registry, - FeeEstimator: feeEstimator, - PreimageCache: pCache, - UpdateContractSignals: func(*contractcourt.ContractSignals) error { - return nil - }, - ChainEvents: &contractcourt.ChainEventSubscription{}, - SyncStates: true, - BatchSize: 10, - BatchTicker: ticker.MockNew(batchTimeout), - FwdPkgGCTicker: ticker.MockNew(fwdPkgTimeout), - MinFeeUpdateTimeout: minFeeUpdateTimeout, - MaxFeeUpdateTimeout: maxFeeUpdateTimeout, - OnChannelFailure: func(lnwire.ChannelID, lnwire.ShortChannelID, LinkFailureError) {}, - }, - secondBobChannel, - ) - if err := bobServer.htlcSwitch.AddLink(secondBobChannelLink); err != nil { - t.Fatalf("unable to add second bob channel link: %v", err) - } - go func() { - for { - select { - case <-secondBobChannelLink.(*channelLink).htlcUpdates: - case <-secondBobChannelLink.(*channelLink).quit: - return - } - } - }() - - carolChannelLink := NewChannelLink( - ChannelLinkConfig{ - Switch: carolServer.htlcSwitch, - FwrdingPolicy: globalPolicy, - Peer: bobServer, - Circuits: carolServer.htlcSwitch.CircuitModifier(), - ForwardPackets: carolServer.htlcSwitch.ForwardPackets, - DecodeHopIterators: carolDecoder.DecodeHopIterators, - ExtractErrorEncrypter: func(*btcec.PublicKey) ( - ErrorEncrypter, lnwire.FailCode) { - return obfuscator, lnwire.CodeNone - }, - FetchLastChannelUpdate: mockGetChanUpdateMessage, - Registry: carolServer.registry, - FeeEstimator: feeEstimator, - PreimageCache: pCache, - UpdateContractSignals: func(*contractcourt.ContractSignals) error { - return nil - }, - ChainEvents: &contractcourt.ChainEventSubscription{}, - SyncStates: true, - BatchSize: 10, - BatchTicker: ticker.MockNew(batchTimeout), - FwdPkgGCTicker: ticker.MockNew(fwdPkgTimeout), - MinFeeUpdateTimeout: minFeeUpdateTimeout, - MaxFeeUpdateTimeout: maxFeeUpdateTimeout, - OnChannelFailure: func(lnwire.ChannelID, lnwire.ShortChannelID, LinkFailureError) {}, - }, - carolChannel, - ) - if err := carolServer.htlcSwitch.AddLink(carolChannelLink); err != nil { - t.Fatalf("unable to add carol channel link: %v", err) - } - go func() { - for { - select { - case <-carolChannelLink.(*channelLink).htlcUpdates: - case <-carolChannelLink.(*channelLink).quit: - return - } - } - }() - - return &threeHopNetwork{ - aliceServer: aliceServer, - aliceChannelLink: aliceChannelLink.(*channelLink), - - bobServer: bobServer, - firstBobChannelLink: firstBobChannelLink.(*channelLink), - secondBobChannelLink: secondBobChannelLink.(*channelLink), - - carolServer: carolServer, - carolChannelLink: carolChannelLink.(*channelLink), + return &hopNetwork{ feeEstimator: feeEstimator, globalPolicy: globalPolicy, + obfuscator: obfuscator, + pCache: pCache, + defaultDelta: defaultDelta, } } + +func (h *hopNetwork) createChannelLink(server, peer *mockServer, + channel *lnwallet.LightningChannel, + decoder *mockIteratorDecoder) (ChannelLink, error) { + + const ( + batchTimeout = 50 * time.Millisecond + fwdPkgTimeout = 15 * time.Second + minFeeUpdateTimeout = 30 * time.Minute + maxFeeUpdateTimeout = 40 * time.Minute + ) + + link := NewChannelLink( + ChannelLinkConfig{ + Switch: server.htlcSwitch, + FwrdingPolicy: h.globalPolicy, + Peer: peer, + Circuits: server.htlcSwitch.CircuitModifier(), + ForwardPackets: server.htlcSwitch.ForwardPackets, + DecodeHopIterators: decoder.DecodeHopIterators, + ExtractErrorEncrypter: func(*btcec.PublicKey) ( + ErrorEncrypter, lnwire.FailCode) { + return h.obfuscator, lnwire.CodeNone + }, + FetchLastChannelUpdate: mockGetChanUpdateMessage, + Registry: server.registry, + FeeEstimator: h.feeEstimator, + PreimageCache: h.pCache, + UpdateContractSignals: func(*contractcourt.ContractSignals) error { + return nil + }, + ChainEvents: &contractcourt.ChainEventSubscription{}, + SyncStates: true, + BatchSize: 10, + BatchTicker: ticker.MockNew(batchTimeout), + FwdPkgGCTicker: ticker.MockNew(fwdPkgTimeout), + MinFeeUpdateTimeout: minFeeUpdateTimeout, + MaxFeeUpdateTimeout: maxFeeUpdateTimeout, + OnChannelFailure: func(lnwire.ChannelID, lnwire.ShortChannelID, LinkFailureError) {}, + }, + channel, + ) + if err := server.htlcSwitch.AddLink(link); err != nil { + return nil, fmt.Errorf("unable to add channel link: %v", err) + } + go func() { + for { + select { + case <-link.(*channelLink).htlcUpdates: + case <-link.(*channelLink).quit: + return + } + } + }() + + return link, nil +}