Browse Source
This commit moves and partially refactors the channel acceptor logic
added in c2a6c86e
into the channel acceptor package. This allows us to
use the same logic in our unit tests as the rpcserver, rather than
needing to replicate it in unit tests.
Two changes are made to the existing implementation:
- Rather than having the Accept function run a closure, the closure
originally used in the rpcserver is moved directly into Accept
- The done channel used to signal client exit is moved into the acceptor
because the rpc server does not need knowledge of this detail (in
addition to other fields required for mocking the actual rpc).
Crediting orginal committer as co-author:
Co-authored-by: Crypt-iQ
master
carla
4 years ago
5 changed files with 447 additions and 270 deletions
@ -1,156 +1,207 @@
|
||||
package chanacceptor |
||||
|
||||
import ( |
||||
"bytes" |
||||
"sync/atomic" |
||||
"errors" |
||||
"math/big" |
||||
"testing" |
||||
"time" |
||||
|
||||
"github.com/lightningnetwork/lnd/lnrpc" |
||||
|
||||
"github.com/btcsuite/btcd/btcec" |
||||
"github.com/lightningnetwork/lnd/lnrpc" |
||||
"github.com/lightningnetwork/lnd/lnwire" |
||||
"github.com/stretchr/testify/assert" |
||||
) |
||||
|
||||
func randKey(t *testing.T) *btcec.PublicKey { |
||||
t.Helper() |
||||
const testTimeout = time.Second |
||||
|
||||
priv, err := btcec.NewPrivateKey(btcec.S256()) |
||||
if err != nil { |
||||
t.Fatalf("unable to generate new public key") |
||||
} |
||||
type channelAcceptorCtx struct { |
||||
t *testing.T |
||||
|
||||
return priv.PubKey() |
||||
} |
||||
// extRequests is the channel that we send our channel accept requests
|
||||
// into, this channel mocks sending of a request to the rpc acceptor.
|
||||
// This channel should be buffered with the number of requests we want
|
||||
// to send so that it does not block (like a rpc stream).
|
||||
extRequests chan []byte |
||||
|
||||
// requestInfo encapsulates the information sent from the RPCAcceptor to the
|
||||
// receiver on the other end of the stream.
|
||||
type requestInfo struct { |
||||
chanReq *ChannelAcceptRequest |
||||
responseChan chan lnrpc.ChannelAcceptResponse |
||||
} |
||||
// responses is a map of pending channel IDs to the response which we
|
||||
// wish to mock the remote channel acceptor sending.
|
||||
responses map[[32]byte]*lnrpc.ChannelAcceptResponse |
||||
|
||||
var defaultAcceptTimeout = 5 * time.Second |
||||
// acceptor is the channel acceptor we create for the test.
|
||||
acceptor *RPCAcceptor |
||||
|
||||
func acceptAndIncrementCtr(rpc ChannelAcceptor, req *ChannelAcceptRequest, |
||||
ctr *uint32, success chan struct{}) { |
||||
// errChan is a channel that the error the channel acceptor exits with
|
||||
// is sent into.
|
||||
errChan chan error |
||||
|
||||
// quit is a channel that can be used to shutdown the channel acceptor
|
||||
// and return errShuttingDown.
|
||||
quit chan struct{} |
||||
} |
||||
|
||||
result := rpc.Accept(req) |
||||
if !result { |
||||
return |
||||
func newChanAcceptorCtx(t *testing.T, acceptCallCount int, |
||||
responses map[[32]byte]*lnrpc.ChannelAcceptResponse) *channelAcceptorCtx { |
||||
|
||||
testCtx := &channelAcceptorCtx{ |
||||
t: t, |
||||
extRequests: make(chan []byte, acceptCallCount), |
||||
responses: responses, |
||||
errChan: make(chan error), |
||||
quit: make(chan struct{}), |
||||
} |
||||
|
||||
val := atomic.AddUint32(ctr, 1) |
||||
if val == 3 { |
||||
success <- struct{}{} |
||||
testCtx.acceptor = NewRPCAcceptor( |
||||
testCtx.receiveResponse, testCtx.sendRequest, testTimeout*5, |
||||
testCtx.quit, |
||||
) |
||||
|
||||
return testCtx |
||||
} |
||||
|
||||
// sendRequest mocks sending a request to the channel acceptor.
|
||||
func (c *channelAcceptorCtx) sendRequest(request *lnrpc.ChannelAcceptRequest) error { |
||||
select { |
||||
case c.extRequests <- request.PendingChanId: |
||||
|
||||
case <-time.After(testTimeout): |
||||
c.t.Fatalf("timeout sending request: %v", request.PendingChanId) |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
// TestMultipleRPCClients tests that the RPCAcceptor is able to handle multiple
|
||||
// callers to its Accept method and respond to them correctly.
|
||||
func TestRPCMultipleAcceptClients(t *testing.T) { |
||||
// receiveResponse mocks sending of a response from the channel acceptor.
|
||||
func (c *channelAcceptorCtx) receiveResponse() (*lnrpc.ChannelAcceptResponse, |
||||
error) { |
||||
|
||||
var ( |
||||
node = randKey(t) |
||||
select { |
||||
case id := <-c.extRequests: |
||||
scratch := [32]byte{} |
||||
copy(scratch[:], id) |
||||
|
||||
firstOpenReq = &ChannelAcceptRequest{ |
||||
Node: node, |
||||
OpenChanMsg: &lnwire.OpenChannel{ |
||||
PendingChannelID: [32]byte{0}, |
||||
}, |
||||
} |
||||
resp, ok := c.responses[scratch] |
||||
assert.True(c.t, ok) |
||||
|
||||
secondOpenReq = &ChannelAcceptRequest{ |
||||
Node: node, |
||||
OpenChanMsg: &lnwire.OpenChannel{ |
||||
PendingChannelID: [32]byte{1}, |
||||
}, |
||||
} |
||||
return resp, nil |
||||
|
||||
thirdOpenReq = &ChannelAcceptRequest{ |
||||
Node: node, |
||||
OpenChanMsg: &lnwire.OpenChannel{ |
||||
PendingChannelID: [32]byte{2}, |
||||
}, |
||||
} |
||||
case <-time.After(testTimeout): |
||||
c.t.Fatalf("timeout receiving request") |
||||
return nil, errors.New("receiveResponse timeout") |
||||
|
||||
counter uint32 |
||||
) |
||||
// Exit if our test acceptor closes the done channel, which indicates
|
||||
// that the acceptor is shutting down.
|
||||
case <-c.acceptor.done: |
||||
return nil, errors.New("acceptor shutting down") |
||||
} |
||||
} |
||||
|
||||
quit := make(chan struct{}) |
||||
defer close(quit) |
||||
// start runs our channel acceptor in a goroutine which sends its exit error
|
||||
// into our test error channel.
|
||||
func (c *channelAcceptorCtx) start() { |
||||
go func() { |
||||
c.errChan <- c.acceptor.Run() |
||||
}() |
||||
} |
||||
|
||||
// Create channels to handle requests and successes.
|
||||
requests := make(chan *requestInfo) |
||||
successChan := make(chan struct{}) |
||||
errChan := make(chan struct{}, 4) |
||||
// stop shuts down the test's channel acceptor and asserts that it exits with
|
||||
// our expected error.
|
||||
func (c *channelAcceptorCtx) stop() { |
||||
close(c.quit) |
||||
|
||||
// demultiplexReq is a closure used to abstract the RPCAcceptor's request
|
||||
// and response logic.
|
||||
demultiplexReq := func(req *ChannelAcceptRequest) bool { |
||||
respChan := make(chan lnrpc.ChannelAcceptResponse, 1) |
||||
select { |
||||
case actual := <-c.errChan: |
||||
assert.Equal(c.t, errShuttingDown, actual) |
||||
|
||||
newRequest := &requestInfo{ |
||||
chanReq: req, |
||||
responseChan: respChan, |
||||
} |
||||
case <-time.After(testTimeout): |
||||
c.t.Fatal("timeout waiting for acceptor to exit") |
||||
} |
||||
} |
||||
|
||||
// Send the newRequest to the requests channel.
|
||||
select { |
||||
case requests <- newRequest: |
||||
case <-quit: |
||||
return false |
||||
// queryAndAssert takes a map of open channel requests which we want to call
|
||||
// Accept for to the outcome we expect from the acceptor, dispatches each
|
||||
// request in a goroutine and then asserts that we get the outcome we expect.
|
||||
func (c *channelAcceptorCtx) queryAndAssert(queries map[*lnwire.OpenChannel]bool) { |
||||
var ( |
||||
node = &btcec.PublicKey{ |
||||
X: big.NewInt(1), |
||||
Y: big.NewInt(1), |
||||
} |
||||
|
||||
// Receive the response and verify that the PendingChanId matches
|
||||
// the ID found in the ChannelAcceptRequest. If no response has been
|
||||
// received in defaultAcceptTimeout, then return false.
|
||||
responses = make(chan struct{}) |
||||
) |
||||
|
||||
for request, expected := range queries { |
||||
request := request |
||||
expected := expected |
||||
|
||||
go func() { |
||||
resp := c.acceptor.Accept(&ChannelAcceptRequest{ |
||||
Node: node, |
||||
OpenChanMsg: request, |
||||
}) |
||||
assert.Equal(c.t, expected, resp) |
||||
responses <- struct{}{} |
||||
}() |
||||
} |
||||
|
||||
// Wait for each of our requests to return a response before we exit.
|
||||
for i := 0; i < len(queries); i++ { |
||||
select { |
||||
case resp := <-respChan: |
||||
pendingID := req.OpenChanMsg.PendingChannelID |
||||
if !bytes.Equal(pendingID[:], resp.PendingChanId) { |
||||
errChan <- struct{}{} |
||||
return false |
||||
} |
||||
|
||||
return resp.Accept |
||||
case <-time.After(defaultAcceptTimeout): |
||||
errChan <- struct{}{} |
||||
return false |
||||
case <-quit: |
||||
return false |
||||
case <-responses: |
||||
case <-time.After(testTimeout): |
||||
c.t.Fatalf("did not receive response") |
||||
} |
||||
} |
||||
} |
||||
|
||||
rpcAcceptor := NewRPCAcceptor(demultiplexReq) |
||||
|
||||
// Now we call the Accept method for each request.
|
||||
go func() { |
||||
acceptAndIncrementCtr(rpcAcceptor, firstOpenReq, &counter, successChan) |
||||
}() |
||||
|
||||
go func() { |
||||
acceptAndIncrementCtr(rpcAcceptor, secondOpenReq, &counter, successChan) |
||||
}() |
||||
// TestMultipleAcceptClients tests that the RPC acceptor is capable of handling
|
||||
// multiple requests to its Accept function and responding to them correctly.
|
||||
func TestMultipleAcceptClients(t *testing.T) { |
||||
var ( |
||||
chan1 = &lnwire.OpenChannel{ |
||||
PendingChannelID: [32]byte{1}, |
||||
} |
||||
chan2 = &lnwire.OpenChannel{ |
||||
PendingChannelID: [32]byte{2}, |
||||
} |
||||
chan3 = &lnwire.OpenChannel{ |
||||
PendingChannelID: [32]byte{3}, |
||||
} |
||||
|
||||
go func() { |
||||
acceptAndIncrementCtr(rpcAcceptor, thirdOpenReq, &counter, successChan) |
||||
}() |
||||
// Queries is a map of the channel IDs we will query Accept
|
||||
// with, and the set of outcomes we expect.
|
||||
queries = map[*lnwire.OpenChannel]bool{ |
||||
chan1: true, |
||||
chan2: false, |
||||
chan3: false, |
||||
} |
||||
|
||||
for { |
||||
select { |
||||
case newRequest := <-requests: |
||||
newResponse := lnrpc.ChannelAcceptResponse{ |
||||
// Responses is a mocked set of responses from the remote
|
||||
// channel acceptor.
|
||||
responses = map[[32]byte]*lnrpc.ChannelAcceptResponse{ |
||||
chan1.PendingChannelID: { |
||||
PendingChanId: chan1.PendingChannelID[:], |
||||
Accept: true, |
||||
PendingChanId: newRequest.chanReq.OpenChanMsg.PendingChannelID[:], |
||||
} |
||||
|
||||
newRequest.responseChan <- newResponse |
||||
case <-errChan: |
||||
t.Fatalf("unable to accept ChannelAcceptRequest") |
||||
case <-successChan: |
||||
return |
||||
case <-quit: |
||||
}, |
||||
chan2.PendingChannelID: { |
||||
PendingChanId: chan2.PendingChannelID[:], |
||||
Accept: false, |
||||
}, |
||||
chan3.PendingChannelID: { |
||||
PendingChanId: chan3.PendingChannelID[:], |
||||
Accept: false, |
||||
}, |
||||
} |
||||
} |
||||
) |
||||
|
||||
// Create and start our channel acceptor.
|
||||
testCtx := newChanAcceptorCtx(t, len(queries), responses) |
||||
testCtx.start() |
||||
|
||||
// Dispatch three queries and assert that we get our expected response.
|
||||
// for each.
|
||||
testCtx.queryAndAssert(queries) |
||||
|
||||
// Shutdown our acceptor.
|
||||
testCtx.stop() |
||||
} |
||||
|
@ -0,0 +1,32 @@
|
||||
package chanacceptor |
||||
|
||||
import ( |
||||
"github.com/btcsuite/btclog" |
||||
"github.com/lightningnetwork/lnd/build" |
||||
) |
||||
|
||||
// Subsystem defines the logging code for this subsystem.
|
||||
const Subsystem = "CHAC" |
||||
|
||||
// log is a logger that is initialized with no output filters. This
|
||||
// means the package will not perform any logging by default until the caller
|
||||
// requests it.
|
||||
var log btclog.Logger |
||||
|
||||
// The default amount of logging is none.
|
||||
func init() { |
||||
UseLogger(build.NewSubLogger(Subsystem, nil)) |
||||
} |
||||
|
||||
// DisableLog disables all library log output. Logging output is disabled
|
||||
// by default until UseLogger is called.
|
||||
func DisableLog() { |
||||
UseLogger(btclog.Disabled) |
||||
} |
||||
|
||||
// UseLogger uses a specified Logger to output package logging info.
|
||||
// This should be used in preference to SetLogWriter if the caller is also
|
||||
// using btclog.
|
||||
func UseLogger(logger btclog.Logger) { |
||||
log = logger |
||||
} |
Loading…
Reference in new issue