package chanacceptor import ( "errors" "math/big" "testing" "time" "github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcutil" "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnwallet/chancloser" "github.com/lightningnetwork/lnd/lnwire" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) const testTimeout = time.Second type channelAcceptorCtx struct { t *testing.T // extRequests is the channel that we send our channel accept requests // into, this channel mocks sending of a request to the rpc acceptor. // This channel should be buffered with the number of requests we want // to send so that it does not block (like a rpc stream). extRequests chan []byte // responses is a map of pending channel IDs to the response which we // wish to mock the remote channel acceptor sending. responses map[[32]byte]*lnrpc.ChannelAcceptResponse // acceptor is the channel acceptor we create for the test. acceptor *RPCAcceptor // errChan is a channel that the error the channel acceptor exits with // is sent into. errChan chan error // quit is a channel that can be used to shutdown the channel acceptor // and return errShuttingDown. quit chan struct{} } func newChanAcceptorCtx(t *testing.T, acceptCallCount int, responses map[[32]byte]*lnrpc.ChannelAcceptResponse) *channelAcceptorCtx { testCtx := &channelAcceptorCtx{ t: t, extRequests: make(chan []byte, acceptCallCount), responses: responses, errChan: make(chan error), quit: make(chan struct{}), } testCtx.acceptor = NewRPCAcceptor( testCtx.receiveResponse, testCtx.sendRequest, testTimeout*5, &chaincfg.TestNet3Params, testCtx.quit, ) return testCtx } // sendRequest mocks sending a request to the channel acceptor. func (c *channelAcceptorCtx) sendRequest(request *lnrpc.ChannelAcceptRequest) error { select { case c.extRequests <- request.PendingChanId: case <-time.After(testTimeout): c.t.Fatalf("timeout sending request: %v", request.PendingChanId) } return nil } // receiveResponse mocks sending of a response from the channel acceptor. func (c *channelAcceptorCtx) receiveResponse() (*lnrpc.ChannelAcceptResponse, error) { select { case id := <-c.extRequests: scratch := [32]byte{} copy(scratch[:], id) resp, ok := c.responses[scratch] assert.True(c.t, ok) return resp, nil case <-time.After(testTimeout): c.t.Fatalf("timeout receiving request") return nil, errors.New("receiveResponse timeout") // Exit if our test acceptor closes the done channel, which indicates // that the acceptor is shutting down. case <-c.acceptor.done: return nil, errors.New("acceptor shutting down") } } // start runs our channel acceptor in a goroutine which sends its exit error // into our test error channel. func (c *channelAcceptorCtx) start() { go func() { c.errChan <- c.acceptor.Run() }() } // stop shuts down the test's channel acceptor and asserts that it exits with // our expected error. func (c *channelAcceptorCtx) stop() { close(c.quit) select { case actual := <-c.errChan: assert.Equal(c.t, errShuttingDown, actual) case <-time.After(testTimeout): c.t.Fatal("timeout waiting for acceptor to exit") } } // queryAndAssert takes a map of open channel requests which we want to call // Accept for to the outcome we expect from the acceptor, dispatches each // request in a goroutine and then asserts that we get the outcome we expect. func (c *channelAcceptorCtx) queryAndAssert(queries map[*lnwire.OpenChannel]*ChannelAcceptResponse) { var ( node = &btcec.PublicKey{ X: big.NewInt(1), Y: big.NewInt(1), } responses = make(chan struct{}) ) for request, expected := range queries { request := request expected := expected go func() { resp := c.acceptor.Accept(&ChannelAcceptRequest{ Node: node, OpenChanMsg: request, }) assert.Equal(c.t, expected, resp) responses <- struct{}{} }() } // Wait for each of our requests to return a response before we exit. for i := 0; i < len(queries); i++ { select { case <-responses: case <-time.After(testTimeout): c.t.Fatalf("did not receive response") } } } // TestMultipleAcceptClients tests that the RPC acceptor is capable of handling // multiple requests to its Accept function and responding to them correctly. func TestMultipleAcceptClients(t *testing.T) { testAddr := "bcrt1qwrmq9uca0t3dy9t9wtuq5tm4405r7tfzyqn9pp" testUpfront, err := chancloser.ParseUpfrontShutdownAddress( testAddr, &chaincfg.TestNet3Params, ) require.NoError(t, err) var ( chan1 = &lnwire.OpenChannel{ PendingChannelID: [32]byte{1}, } chan2 = &lnwire.OpenChannel{ PendingChannelID: [32]byte{2}, } chan3 = &lnwire.OpenChannel{ PendingChannelID: [32]byte{3}, } customError = errors.New("go away") // Queries is a map of the channel IDs we will query Accept // with, and the set of outcomes we expect. queries = map[*lnwire.OpenChannel]*ChannelAcceptResponse{ chan1: NewChannelAcceptResponse( true, nil, testUpfront, 1, 2, 3, 4, 5, 6, ), chan2: NewChannelAcceptResponse( false, errChannelRejected, nil, 0, 0, 0, 0, 0, 0, ), chan3: NewChannelAcceptResponse( false, customError, nil, 0, 0, 0, 0, 0, 0, ), } // Responses is a mocked set of responses from the remote // channel acceptor. responses = map[[32]byte]*lnrpc.ChannelAcceptResponse{ chan1.PendingChannelID: { PendingChanId: chan1.PendingChannelID[:], Accept: true, UpfrontShutdown: testAddr, CsvDelay: 1, MaxHtlcCount: 2, MinAcceptDepth: 3, ReserveSat: 4, InFlightMaxMsat: 5, MinHtlcIn: 6, }, chan2.PendingChannelID: { PendingChanId: chan2.PendingChannelID[:], Accept: false, }, chan3.PendingChannelID: { PendingChanId: chan3.PendingChannelID[:], Accept: false, Error: customError.Error(), }, } ) // Create and start our channel acceptor. testCtx := newChanAcceptorCtx(t, len(queries), responses) testCtx.start() // Dispatch three queries and assert that we get our expected response. // for each. testCtx.queryAndAssert(queries) // Shutdown our acceptor. testCtx.stop() } // TestInvalidResponse tests the case where our remote channel acceptor sends us // an invalid response, so the channel acceptor stream terminates. func TestInvalidResponse(t *testing.T) { var ( chan1 = [32]byte{1} // We make a single query, and expect it to fail with our // generic error because our response is invalid. queries = map[*lnwire.OpenChannel]*ChannelAcceptResponse{ { PendingChannelID: chan1, }: NewChannelAcceptResponse( false, errChannelRejected, nil, 0, 0, 0, 0, 0, 0, ), } // Create a single response which is invalid because it accepts // the channel but also contains an error message. responses = map[[32]byte]*lnrpc.ChannelAcceptResponse{ chan1: { PendingChanId: chan1[:], Accept: true, Error: "has an error as well", }, } ) // Create and start our channel acceptor. testCtx := newChanAcceptorCtx(t, len(queries), responses) testCtx.start() testCtx.queryAndAssert(queries) // We do not expect our channel acceptor to exit because of one invalid // response, so we shutdown and assert here. testCtx.stop() } // TestInvalidReserve tests validation of the channel reserve proposed by the // acceptor against the dust limit that was proposed by the remote peer. func TestInvalidReserve(t *testing.T) { var ( chan1 = [32]byte{1} dustLimit = btcutil.Amount(1000) reserve = dustLimit / 2 // We make a single query, and expect it to fail with our // generic error because channel reserve is too low. queries = map[*lnwire.OpenChannel]*ChannelAcceptResponse{ { PendingChannelID: chan1, DustLimit: dustLimit, }: NewChannelAcceptResponse( false, errChannelRejected, nil, 0, 0, 0, reserve, 0, 0, ), } // Create a single response which is invalid because the // proposed reserve is below our dust limit. responses = map[[32]byte]*lnrpc.ChannelAcceptResponse{ chan1: { PendingChanId: chan1[:], Accept: true, ReserveSat: uint64(reserve), }, } ) // Create and start our channel acceptor. testCtx := newChanAcceptorCtx(t, len(queries), responses) testCtx.start() testCtx.queryAndAssert(queries) // We do not expect our channel acceptor to exit because of one invalid // response, so we shutdown and assert here. testCtx.stop() }