Commit d89f51d1d0: as with kvdb.View, this commit adds a reset closure to the kvdb.Update call so that external state can be reset if the underlying db backend needs to retry the transaction.
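For reference, here is a minimal sketch of the resulting call shape (an illustration only: db, key, and value are placeholder names, and the signature assumed here is the one exercised by TestMessageStoreUnsupportedMessage below, where the final closure resets external state before a retry):

	err := kvdb.Update(db, func(tx kvdb.RwTx) error {
		// All writes happen inside the transaction closure, which the
		// backend may run more than once.
		return tx.ReadWriteBucket(messageStoreBucket).Put(key, value)
	}, func() {
		// Reset any state captured outside the transaction here so a
		// retried attempt starts from a clean slate.
	})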
package discovery

import (
	"bytes"
	"io/ioutil"
	"math/rand"
	"os"
	"reflect"
	"testing"

	"github.com/btcsuite/btcd/btcec"
	"github.com/davecgh/go-spew/spew"
	"github.com/lightningnetwork/lnd/channeldb"
	"github.com/lightningnetwork/lnd/channeldb/kvdb"
	"github.com/lightningnetwork/lnd/lnwire"
)

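// createTestMessageStore creates a MessageStore backed by a temporary
// channeldb instance, along with a cleanup closure that closes the database
// and removes the temporary directory.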
func createTestMessageStore(t *testing.T) (*MessageStore, func()) {
	t.Helper()

	tempDir, err := ioutil.TempDir("", "channeldb")
	if err != nil {
		t.Fatalf("unable to create temp dir: %v", err)
	}
	db, err := channeldb.Open(tempDir)
	if err != nil {
		os.RemoveAll(tempDir)
		t.Fatalf("unable to open db: %v", err)
	}

	cleanUp := func() {
		db.Close()
		os.RemoveAll(tempDir)
	}

	store, err := NewMessageStore(db)
	if err != nil {
		cleanUp()
		t.Fatalf("unable to initialize message store: %v", err)
	}

	return store, cleanUp
}

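// randPubKey generates a fresh random public key for testing.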
func randPubKey(t *testing.T) *btcec.PublicKey {
	t.Helper()

	priv, err := btcec.NewPrivateKey(btcec.S256())
	if err != nil {
		t.Fatalf("unable to create private key: %v", err)
	}

	return priv.PubKey()
}

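// randCompressedPubKey generates a random public key and returns it in its
// 33-byte compressed form.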
func randCompressedPubKey(t *testing.T) [33]byte {
	t.Helper()

	pubKey := randPubKey(t)

	var compressedPubKey [33]byte
	copy(compressedPubKey[:], pubKey.SerializeCompressed())

	return compressedPubKey
}

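// randAnnounceSignatures returns an AnnounceSignatures message with a random
// short channel ID.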
func randAnnounceSignatures() *lnwire.AnnounceSignatures {
	return &lnwire.AnnounceSignatures{
		ShortChannelID: lnwire.NewShortChanIDFromInt(rand.Uint64()),
	}
}

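// randChannelUpdate returns a ChannelUpdate message with a random short
// channel ID.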
func randChannelUpdate() *lnwire.ChannelUpdate {
	return &lnwire.ChannelUpdate{
		ShortChannelID: lnwire.NewShortChanIDFromInt(rand.Uint64()),
	}
}

// TestMessageStoreMessages ensures that messages can be properly queried from
// the store.
func TestMessageStoreMessages(t *testing.T) {
	t.Parallel()

	// We'll start by creating our test message store.
	msgStore, cleanUp := createTestMessageStore(t)
	defer cleanUp()

	// We'll then create some test messages for two test peers, and none for
	// an additional test peer.
	channelUpdate1 := randChannelUpdate()
	announceSignatures1 := randAnnounceSignatures()
	peer1 := randCompressedPubKey(t)
	if err := msgStore.AddMessage(channelUpdate1, peer1); err != nil {
		t.Fatalf("unable to add message: %v", err)
	}
	if err := msgStore.AddMessage(announceSignatures1, peer1); err != nil {
		t.Fatalf("unable to add message: %v", err)
	}
	expectedPeerMsgs1 := map[uint64]lnwire.MessageType{
		channelUpdate1.ShortChannelID.ToUint64():      channelUpdate1.MsgType(),
		announceSignatures1.ShortChannelID.ToUint64(): announceSignatures1.MsgType(),
	}

	channelUpdate2 := randChannelUpdate()
	peer2 := randCompressedPubKey(t)
	if err := msgStore.AddMessage(channelUpdate2, peer2); err != nil {
		t.Fatalf("unable to add message: %v", err)
	}
	expectedPeerMsgs2 := map[uint64]lnwire.MessageType{
		channelUpdate2.ShortChannelID.ToUint64(): channelUpdate2.MsgType(),
	}

	peer3 := randCompressedPubKey(t)
	expectedPeerMsgs3 := map[uint64]lnwire.MessageType{}

	// assertPeerMsgs is a helper closure that we'll use to ensure we
	// retrieve the correct set of messages for a given peer.
	assertPeerMsgs := func(peerMsgs []lnwire.Message,
		expected map[uint64]lnwire.MessageType) {

		t.Helper()

		if len(peerMsgs) != len(expected) {
			t.Fatalf("expected %d pending messages, got %d",
				len(expected), len(peerMsgs))
		}
		for _, msg := range peerMsgs {
			var shortChanID uint64
			switch msg := msg.(type) {
			case *lnwire.AnnounceSignatures:
				shortChanID = msg.ShortChannelID.ToUint64()
			case *lnwire.ChannelUpdate:
				shortChanID = msg.ShortChannelID.ToUint64()
			default:
				t.Fatalf("found unexpected message type %T", msg)
			}

			msgType, ok := expected[shortChanID]
			if !ok {
				t.Fatalf("retrieved message with unexpected ID "+
					"%d from store", shortChanID)
			}
			if msgType != msg.MsgType() {
				t.Fatalf("expected message of type %v, got %v",
					msgType, msg.MsgType())
			}
		}
	}

	// Then, we'll query the store for the set of messages for each peer and
	// ensure it matches what we expect.
	peers := [][33]byte{peer1, peer2, peer3}
	expectedPeerMsgs := []map[uint64]lnwire.MessageType{
		expectedPeerMsgs1, expectedPeerMsgs2, expectedPeerMsgs3,
	}
	for i, peer := range peers {
		peerMsgs, err := msgStore.MessagesForPeer(peer)
		if err != nil {
			t.Fatalf("unable to retrieve messages: %v", err)
		}
		assertPeerMsgs(peerMsgs, expectedPeerMsgs[i])
	}

	// Finally, we'll query the store for all of its messages for every peer.
	// Again, each peer should have a set of messages that match what we
	// expect.
	//
	// We'll construct the expected response. Only the first two peers will
	// have messages.
	totalPeerMsgs := make(map[[33]byte]map[uint64]lnwire.MessageType, 2)
	for i := 0; i < 2; i++ {
		totalPeerMsgs[peers[i]] = expectedPeerMsgs[i]
	}

	msgs, err := msgStore.Messages()
	if err != nil {
		t.Fatalf("unable to retrieve all peers with pending messages: "+
			"%v", err)
	}
	if len(msgs) != len(totalPeerMsgs) {
		t.Fatalf("expected %d peers with messages, got %d",
			len(totalPeerMsgs), len(msgs))
	}
	for peer, peerMsgs := range msgs {
		expected, ok := totalPeerMsgs[peer]
		if !ok {
			t.Fatalf("expected to find pending messages for peer %x",
				peer)
		}

		assertPeerMsgs(peerMsgs, expected)
	}

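	// The set of peers with pending messages should also only include the
	// first two peers.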
	peerPubKeys, err := msgStore.Peers()
	if err != nil {
		t.Fatalf("unable to retrieve all peers with pending messages: "+
			"%v", err)
	}
	if len(peerPubKeys) != len(totalPeerMsgs) {
		t.Fatalf("expected %d peers with messages, got %d",
			len(totalPeerMsgs), len(peerPubKeys))
	}
	for peerPubKey := range peerPubKeys {
		if _, ok := totalPeerMsgs[peerPubKey]; !ok {
			t.Fatalf("expected to find peer %x", peerPubKey)
		}
	}
}

// TestMessageStoreUnsupportedMessage ensures that we are not able to add a
// message which is unsupported, and if a message is found to be unsupported by
// the current version of the store, that it is properly filtered out from the
// response.
func TestMessageStoreUnsupportedMessage(t *testing.T) {
	t.Parallel()

	// We'll start by creating our test message store.
	msgStore, cleanUp := createTestMessageStore(t)
	defer cleanUp()

	// Create a message that is known to not be supported by the store.
	peer := randCompressedPubKey(t)
	unsupportedMsg := &lnwire.Error{}

	// Attempting to add it to the store should result in
	// ErrUnsupportedMessage.
	err := msgStore.AddMessage(unsupportedMsg, peer)
	if err != ErrUnsupportedMessage {
		t.Fatalf("expected ErrUnsupportedMessage, got %v", err)
	}

	// We'll now pretend that the message is actually supported in a future
	// version of the store, so it's able to be added successfully. To
	// replicate this, we'll add the message manually rather than through
	// the existing AddMessage method.
	msgKey := peer[:]
	var rawMsg bytes.Buffer
	if _, err := lnwire.WriteMessage(&rawMsg, unsupportedMsg, 0); err != nil {
		t.Fatalf("unable to serialize message: %v", err)
	}
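	// The final empty closure below is the reset closure: there is no
	// external state to reset here if the underlying backend needs to
	// retry the transaction.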
	err = kvdb.Update(msgStore.db, func(tx kvdb.RwTx) error {
		messageStore := tx.ReadWriteBucket(messageStoreBucket)
		return messageStore.Put(msgKey, rawMsg.Bytes())
	}, func() {})
	if err != nil {
		t.Fatalf("unable to add unsupported message to store: %v", err)
	}

	// Finally, we'll check that the store can properly filter out messages
	// that are currently unknown to it. We'll make sure this is done for
	// both Messages and MessagesForPeer.
	totalMsgs, err := msgStore.Messages()
	if err != nil {
		t.Fatalf("unable to retrieve messages: %v", err)
	}
	if len(totalMsgs) != 0 {
		t.Fatalf("expected to filter out unsupported message")
	}
	peerMsgs, err := msgStore.MessagesForPeer(peer)
	if err != nil {
		t.Fatalf("unable to retrieve peer messages: %v", err)
	}
	if len(peerMsgs) != 0 {
		t.Fatalf("expected to filter out unsupported message")
	}
}

// TestMessageStoreDeleteMessage ensures that we can properly delete messages
// from the store.
func TestMessageStoreDeleteMessage(t *testing.T) {
	t.Parallel()

	msgStore, cleanUp := createTestMessageStore(t)
	defer cleanUp()

	// assertMsg is a helper closure we'll use to ensure a message
	// does/doesn't exist within the store.
	assertMsg := func(msg lnwire.Message, peer [33]byte, exists bool) {
		t.Helper()

		storeMsgs, err := msgStore.MessagesForPeer(peer)
		if err != nil {
			t.Fatalf("unable to retrieve messages: %v", err)
		}

		found := false
		for _, storeMsg := range storeMsgs {
			if reflect.DeepEqual(msg, storeMsg) {
				found = true
			}
		}

		if found != exists {
			str := "find"
			if !exists {
				str = "not find"
			}
			t.Fatalf("expected to %v message %v", str,
				spew.Sdump(msg))
		}
	}

	// An AnnounceSignatures message should exist within the store after
	// adding it, and should no longer exist after deleting it.
	peer := randCompressedPubKey(t)
	annSig := randAnnounceSignatures()
	if err := msgStore.AddMessage(annSig, peer); err != nil {
		t.Fatalf("unable to add message: %v", err)
	}
	assertMsg(annSig, peer, true)
	if err := msgStore.DeleteMessage(annSig, peer); err != nil {
		t.Fatalf("unable to delete message: %v", err)
	}
	assertMsg(annSig, peer, false)

	// The store allows overwriting ChannelUpdates, since there can be
	// multiple versions, so we'll test things slightly differently.
	//
	// The ChannelUpdate message should exist within the store after adding
	// it.
	chanUpdate := randChannelUpdate()
	if err := msgStore.AddMessage(chanUpdate, peer); err != nil {
		t.Fatalf("unable to add message: %v", err)
	}
	assertMsg(chanUpdate, peer, true)

	// Now, we'll create a new version for the same ChannelUpdate message.
	// Adding this one to the store will overwrite the previous one, so only
	// the new one should exist.
	newChanUpdate := randChannelUpdate()
	newChanUpdate.ShortChannelID = chanUpdate.ShortChannelID
	newChanUpdate.Timestamp = chanUpdate.Timestamp + 1
	if err := msgStore.AddMessage(newChanUpdate, peer); err != nil {
		t.Fatalf("unable to add message: %v", err)
	}
	assertMsg(chanUpdate, peer, false)
	assertMsg(newChanUpdate, peer, true)

	// Deleting the older message should act as a NOP and should NOT delete
	// the newer version, as the older one no longer exists.
	if err := msgStore.DeleteMessage(chanUpdate, peer); err != nil {
		t.Fatalf("unable to delete message: %v", err)
	}
	assertMsg(chanUpdate, peer, false)
	assertMsg(newChanUpdate, peer, true)

	// The newer version should no longer exist within the store after
	// deleting it.
	if err := msgStore.DeleteMessage(newChanUpdate, peer); err != nil {
		t.Fatalf("unable to delete message: %v", err)
	}
	assertMsg(newChanUpdate, peer, false)
}