package channeldb

import (
	"bytes"
	"math/rand"
	"net"
	"reflect"
	"runtime"
	"testing"

	"github.com/stretchr/testify/require"

	"github.com/btcsuite/btcd/btcec"
	"github.com/btcsuite/btcd/chaincfg/chainhash"
	"github.com/btcsuite/btcd/wire"
	"github.com/btcsuite/btcutil"
	_ "github.com/btcsuite/btcwallet/walletdb/bdb"
	"github.com/davecgh/go-spew/spew"
	"github.com/lightningnetwork/lnd/channeldb/kvdb"
	"github.com/lightningnetwork/lnd/clock"
	"github.com/lightningnetwork/lnd/keychain"
	"github.com/lightningnetwork/lnd/lntest/channels"
	"github.com/lightningnetwork/lnd/lnwire"
	"github.com/lightningnetwork/lnd/shachain"
)

var (
	// key is a fixed test value. Within this file it seeds the test
	// private key and the revocation producer, and is also used as the
	// chain hash, funding outpoint hash and HTLC payment hash of the test
	// channels.
	key = [chainhash.HashSize]byte{
		0x81, 0xb6, 0x37, 0xd8, 0xfc, 0xd2, 0xc6, 0xda,
		0x68, 0x59, 0xe6, 0x96, 0x31, 0x13, 0xa1, 0x17,
		0xd, 0xe7, 0x93, 0xe4, 0xb7, 0x25, 0xb8, 0x4d,
		0x1e, 0xb, 0x4c, 0xf9, 0x9e, 0xc5, 0x8c, 0xe9,
	}

	// rev is a second fixed test value, used as the closing transaction ID
	// in TestFetchClosedChannels. Only the first 20 elements are listed,
	// the remaining elements of the array keep their zero value.
	rev = [chainhash.HashSize]byte{
		0x51, 0xb6, 0x37, 0xd8, 0xfc, 0xd2, 0xc6, 0xda,
		0x48, 0x59, 0xe6, 0x96, 0x31, 0x13, 0xa1, 0x17,
		0x2d, 0xe7, 0x93, 0xe4,
	}

	// privKey and pubKey are the key pair derived from key. They are used
	// for every key of the test channels.
	privKey, pubKey = btcec.PrivKeyFromBytes(btcec.S256(), key[:])

	// wireSig is testSig (declared elsewhere in this package) converted
	// with lnwire.NewSigFromSignature. The conversion error is discarded.
	wireSig, _ = lnwire.NewSigFromSignature(testSig)

	// testClock is a test clock constructed from testNow (declared
	// elsewhere in this package).
	testClock = clock.NewTestClock(testNow)

	// defaultPendingHeight is the default height at which we set
	// channels to pending.
	defaultPendingHeight = 100

	// defaultAddr is the default address that we mark test channels pending
	// with.
	defaultAddr = &net.TCPAddr{
		IP:   net.ParseIP("127.0.0.1"),
		Port: 18555,
	}

	// keyLocIndex is the KeyLocator Index we use for TestKeyLocatorEncoding.
	keyLocIndex = uint32(2049)
)

// testChannelParams is a struct which details the specifics of how a channel
// should be created.
type testChannelParams struct {
	// channel is the channel that will be written to disk.
	channel *OpenChannel

	// addr is the address that the channel will be synced pending with.
	addr *net.TCPAddr

	// pendingHeight is the height that the channel should be recorded as
	// pending.
	pendingHeight uint32

	// openChannel is set to true if the channel should be fully marked as
	// open. If this is false, the channel will be left in pending state.
	openChannel bool
}

// testChannelOption is a functional option which can be used to alter the
// default channel that is created for testing.
type testChannelOption func(params *testChannelParams)

// channelCommitmentOption returns an option that overrides the commit height
// and both balances of one of the channel's commitments. When local is true
// the local commitment is modified, otherwise the remote commitment is.
func channelCommitmentOption(height uint64, localBalance,
	remoteBalance lnwire.MilliSatoshi, local bool) testChannelOption {

	return func(p *testChannelParams) {
		// Select the commitment to overwrite first, so the new values
		// only need to be applied in one place.
		target := &p.channel.RemoteCommitment
		if local {
			target = &p.channel.LocalCommitment
		}

		target.CommitHeight = height
		target.LocalBalance = localBalance
		target.RemoteBalance = remoteBalance
	}
}

// pendingHeightOption returns an option that overrides the height at which
// the test channel is recorded as pending.
func pendingHeightOption(height uint32) testChannelOption {
	setPendingHeight := func(p *testChannelParams) {
		p.pendingHeight = height
	}

	return setPendingHeight
}

// openChannelOption returns an option that requests the test channel to be
// marked as open rather than being left pending.
func openChannelOption() testChannelOption {
	markOpen := func(p *testChannelParams) {
		p.openChannel = true
	}

	return markOpen
}

// localHtlcsOption returns an option that installs the given htlcs on the
// local commitment of the test channel.
func localHtlcsOption(htlcs []HTLC) testChannelOption {
	return testChannelOption(func(p *testChannelParams) {
		localCommit := &p.channel.LocalCommitment
		localCommit.Htlcs = htlcs
	})
}

// remoteHtlcsOption returns an option that installs the given htlcs on the
// remote commitment of the test channel.
func remoteHtlcsOption(htlcs []HTLC) testChannelOption {
	return testChannelOption(func(p *testChannelParams) {
		remoteCommit := &p.channel.RemoteCommitment
		remoteCommit.Htlcs = htlcs
	})
}

// localShutdownOption returns an option that sets the local upfront shutdown
// script of the test channel to addr.
func localShutdownOption(addr lnwire.DeliveryAddress) testChannelOption {
	setScript := func(p *testChannelParams) {
		p.channel.LocalShutdownScript = addr
	}

	return setScript
}

// remoteShutdownOption returns an option that sets the remote upfront shutdown
// script of the test channel to addr.
func remoteShutdownOption(addr lnwire.DeliveryAddress) testChannelOption {
	setScript := func(p *testChannelParams) {
		p.channel.RemoteShutdownScript = addr
	}

	return setScript
}

// fundingPointOption returns an option that replaces the funding outpoint of
// the test channel with chanPoint.
func fundingPointOption(chanPoint wire.OutPoint) testChannelOption {
	setFundingPoint := func(p *testChannelParams) {
		p.channel.FundingOutpoint = chanPoint
	}

	return setFundingPoint
}

// channelIDOption builds an option that replaces the short channel ID of the
// test channel with chanID.
var channelIDOption = func(chanID lnwire.ShortChannelID) testChannelOption {
	return testChannelOption(func(p *testChannelParams) {
		p.channel.ShortChannelID = chanID
	})
}

// createTestChannel persists a test channel in the passed database and returns
// it. Without options the channel is left pending, using the default address
// and default pending height. The functional options are applied to the default
// parameters before anything is written, and the test is failed immediately if
// a database operation returns an error.
func createTestChannel(t *testing.T, cdb *DB,
	opts ...testChannelOption) *OpenChannel {

	// Start out with the defaults. The openChannel flag is left at its
	// zero value, meaning the channel stays pending.
	params := &testChannelParams{
		channel:       createTestChannelState(t, cdb),
		addr:          defaultAddr,
		pendingHeight: uint32(defaultPendingHeight),
	}

	// Let each option adjust the parameters, in the order given.
	for i := range opts {
		opts[i](params)
	}

	// Every channel is first synced to disk as pending.
	testChan := params.channel
	if err := testChan.SyncPending(
		params.addr, params.pendingHeight,
	); err != nil {
		t.Fatalf("unable to save and serialize "+
			"channel state: %v", err)
	}

	// Only when explicitly requested do we go on to mark the channel as
	// open, using the short channel ID it already carries.
	if params.openChannel {
		err := testChan.MarkAsOpen(testChan.ShortChannelID)
		if err != nil {
			t.Fatalf("unable to mark channel open: %v", err)
		}
	}

	return testChan
}

// createTestChannelState assembles an OpenChannel filled with test data and
// bound to the passed database handle. The channel constraints, commit fees,
// short channel ID and funding output index are drawn from math/rand, while
// the keys, balances and transactions are fixed. The test is failed if the
// revocation producer or store cannot be set up.
func createTestChannelState(t *testing.T, cdb *DB) *OpenChannel {
	// numPreimages is the number of revocation preimages copied from the
	// producer into the store.
	const numPreimages = 1

	producer, err := shachain.NewRevocationProducerFromBytes(key[:])
	if err != nil {
		t.Fatalf("could not get producer: %v", err)
	}
	store := shachain.NewRevocationStore()
	for idx := uint64(0); idx < numPreimages; idx++ {
		preImage, err := producer.AtIndex(idx)
		if err != nil {
			t.Fatalf("could not get preimage: %v", err)
		}

		err = store.AddNextEntry(preImage)
		if err != nil {
			t.Fatalf("could not add entry: %v", err)
		}
	}

	// randomConstraints draws a fresh set of channel constraints. The
	// fields are evaluated top to bottom, so the sequence of random draws
	// is fixed.
	randomConstraints := func() ChannelConstraints {
		return ChannelConstraints{
			DustLimit:        btcutil.Amount(rand.Int63()),
			MaxPendingAmount: lnwire.MilliSatoshi(rand.Int63()),
			ChanReserve:      btcutil.Amount(rand.Int63()),
			MinHTLC:          lnwire.MilliSatoshi(rand.Int63()),
			MaxAcceptedHtlcs: uint16(rand.Int31()),
			CsvDelay:         uint16(rand.Int31()),
		}
	}

	// descriptor pairs the test public key with the given key locator.
	descriptor := func(loc keychain.KeyLocator) keychain.KeyDescriptor {
		return keychain.KeyDescriptor{
			PubKey:     privKey.PubKey(),
			KeyLocator: loc,
		}
	}

	// The local config carries no key locators, so each descriptor gets
	// the zero locator.
	var noLocator keychain.KeyLocator
	localCfg := ChannelConfig{
		ChannelConstraints:  randomConstraints(),
		MultiSigKey:         descriptor(noLocator),
		RevocationBasePoint: descriptor(noLocator),
		PaymentBasePoint:    descriptor(noLocator),
		DelayBasePoint:      descriptor(noLocator),
		HtlcBasePoint:       descriptor(noLocator),
	}

	// The remote config uses a distinct key family and index for each of
	// its keys.
	remoteCfg := ChannelConfig{
		ChannelConstraints: randomConstraints(),
		MultiSigKey: descriptor(keychain.KeyLocator{
			Family: keychain.KeyFamilyMultiSig,
			Index:  9,
		}),
		RevocationBasePoint: descriptor(keychain.KeyLocator{
			Family: keychain.KeyFamilyRevocationBase,
			Index:  8,
		}),
		PaymentBasePoint: descriptor(keychain.KeyLocator{
			Family: keychain.KeyFamilyPaymentBase,
			Index:  7,
		}),
		DelayBasePoint: descriptor(keychain.KeyLocator{
			Family: keychain.KeyFamilyDelayBase,
			Index:  6,
		}),
		HtlcBasePoint: descriptor(keychain.KeyLocator{
			Family: keychain.KeyFamilyHtlcBase,
			Index:  5,
		}),
	}

	// The remaining random values are drawn in a fixed order: short
	// channel ID, funding output index, local commit fee and finally the
	// remote commit fee.
	chanID := lnwire.NewShortChanIDFromInt(uint64(rand.Int63()))
	fundingPoint := wire.OutPoint{Hash: key, Index: rand.Uint32()}

	localCommit := ChannelCommitment{
		CommitHeight:  0,
		LocalBalance:  lnwire.MilliSatoshi(9000),
		RemoteBalance: lnwire.MilliSatoshi(3000),
		CommitFee:     btcutil.Amount(rand.Int63()),
		FeePerKw:      btcutil.Amount(5000),
		CommitTx:      channels.TestFundingTx,
		CommitSig:     bytes.Repeat([]byte{1}, 71),
	}
	remoteCommit := ChannelCommitment{
		CommitHeight:  0,
		LocalBalance:  lnwire.MilliSatoshi(3000),
		RemoteBalance: lnwire.MilliSatoshi(9000),
		CommitFee:     btcutil.Amount(rand.Int63()),
		FeePerKw:      btcutil.Amount(5000),
		CommitTx:      channels.TestFundingTx,
		CommitSig:     bytes.Repeat([]byte{1}, 71),
	}

	return &OpenChannel{
		ChanType:                SingleFunderBit | FrozenBit,
		ChainHash:               key,
		FundingOutpoint:         fundingPoint,
		ShortChannelID:          chanID,
		IsInitiator:             true,
		IsPending:               true,
		IdentityPub:             pubKey,
		Capacity:                btcutil.Amount(10000),
		LocalChanCfg:            localCfg,
		RemoteChanCfg:           remoteCfg,
		TotalMSatSent:           8,
		TotalMSatReceived:       2,
		LocalCommitment:         localCommit,
		RemoteCommitment:        remoteCommit,
		NumConfsRequired:        4,
		RemoteCurrentRevocation: privKey.PubKey(),
		RemoteNextRevocation:    privKey.PubKey(),
		RevocationProducer:      producer,
		RevocationStore:         store,
		Db:                      cdb,
		Packager:                NewChannelPackager(chanID),
		FundingTxn:              channels.TestFundingTx,
		ThawHeight:              uint32(defaultPendingHeight),
	}
}

// TestOpenChannelPutGetDelete writes a channel with htlcs on both commitments
// to the database, reads it back and checks that it is unchanged, swaps in a
// new next revocation key, and finally closes the channel and checks that no
// open channels remain.
func TestOpenChannelPutGetDelete(t *testing.T) {
	t.Parallel()

	cdb, cleanUp, err := MakeTestDB()
	if err != nil {
		t.Fatalf("unable to make test database: %v", err)
	}
	defer cleanUp()

	// Create the test channel state, with additional htlcs on the local
	// and remote commitment.
	localHtlcs := []HTLC{
		{
			Signature:     testSig.Serialize(),
			Incoming:      true,
			Amt:           10,
			RHash:         key,
			RefundTimeout: 1,
			OnionBlob:     []byte("onionblob"),
		},
	}

	remoteHtlcs := []HTLC{
		{
			Signature:     testSig.Serialize(),
			Incoming:      false,
			Amt:           10,
			RHash:         key,
			RefundTimeout: 1,
			OnionBlob:     []byte("onionblob"),
		},
	}

	state := createTestChannel(
		t, cdb,
		remoteHtlcsOption(remoteHtlcs),
		localHtlcsOption(localHtlcs),
	)

	openChannels, err := cdb.FetchOpenChannels(state.IdentityPub)
	if err != nil {
		t.Fatalf("unable to fetch open channel: %v", err)
	}

	// Guard the index below so that a missing channel fails the test with
	// a readable message instead of a panic.
	if len(openChannels) != 1 {
		t.Fatalf("expected 1 open channel, got %v", len(openChannels))
	}

	newState := openChannels[0]

	// The decoded channel state should be identical to what we stored
	// above.
	if !reflect.DeepEqual(state, newState) {
		t.Fatalf("channel state doesn't match: %v vs %v",
			spew.Sdump(state), spew.Sdump(newState))
	}

	// We'll also test that the channel is properly able to hot swap the
	// next revocation for the state machine. This tests the initial
	// post-funding revocation exchange.
	nextRevKey, err := btcec.NewPrivateKey(btcec.S256())
	if err != nil {
		t.Fatalf("unable to create new private key: %v", err)
	}
	if err := state.InsertNextRevocation(nextRevKey.PubKey()); err != nil {
		t.Fatalf("unable to update revocation: %v", err)
	}

	openChannels, err = cdb.FetchOpenChannels(state.IdentityPub)
	if err != nil {
		t.Fatalf("unable to fetch open channel: %v", err)
	}
	if len(openChannels) != 1 {
		t.Fatalf("expected 1 open channel, got %v", len(openChannels))
	}
	updatedChan := openChannels[0]

	// Ensure that the revocation was set properly.
	if !nextRevKey.PubKey().IsEqual(updatedChan.RemoteNextRevocation) {
		t.Fatalf("next revocation wasn't updated")
	}

	// Finally to wrap up the test, delete the state of the channel within
	// the database. This involves "closing" the channel which removes all
	// written state, and creates a small "summary" elsewhere within the
	// database.
	closeSummary := &ChannelCloseSummary{
		ChanPoint:         state.FundingOutpoint,
		RemotePub:         state.IdentityPub,
		SettledBalance:    btcutil.Amount(500),
		TimeLockedBalance: btcutil.Amount(10000),
		IsPending:         false,
		CloseType:         CooperativeClose,
	}
	if err := state.CloseChannel(closeSummary); err != nil {
		t.Fatalf("unable to close channel: %v", err)
	}

	// As the channel is now closed, attempting to fetch all open channels
	// for our fake node ID should return an empty slice.
	openChans, err := cdb.FetchOpenChannels(state.IdentityPub)
	if err != nil {
		t.Fatalf("unable to fetch open channels: %v", err)
	}
	if len(openChans) != 0 {
		t.Fatalf("all channels not deleted, found %v", len(openChans))
	}

	// Additionally, attempting to fetch all the open channels globally
	// should yield no results.
	openChans, err = cdb.FetchAllChannels()
	if err != nil {
		t.Fatalf("unable to fetch all open chans: %v", err)
	}
	if len(openChans) != 0 {
		t.Fatalf("all channels not deleted, found %v", len(openChans))
	}
}

// TestOptionalShutdown checks that channels are written to and read back from
// the database correctly for every combination of present and absent upfront
// shutdown scripts.
func TestOptionalShutdown(t *testing.T) {
	localScript := lnwire.DeliveryAddress([]byte("local shutdown script"))
	remoteScript := lnwire.DeliveryAddress([]byte("remote shutdown script"))

	// Overwrite the contents of the remote script with random bytes.
	if _, err := rand.Read(remoteScript); err != nil {
		t.Fatalf("Could not create random script: %v", err)
	}

	// Scripts that are left out of a test case are nil.
	type shutdownCase struct {
		name           string
		localShutdown  lnwire.DeliveryAddress
		remoteShutdown lnwire.DeliveryAddress
	}
	testCases := []shutdownCase{
		{
			name: "no shutdown scripts",
		},
		{
			name:          "local shutdown script",
			localShutdown: localScript,
		},
		{
			name:           "remote shutdown script",
			remoteShutdown: remoteScript,
		},
		{
			name:           "both scripts set",
			localShutdown:  localScript,
			remoteShutdown: remoteScript,
		},
	}

	for _, tc := range testCases {
		tc := tc

		t.Run(tc.name, func(t *testing.T) {
			cdb, cleanUp, err := MakeTestDB()
			if err != nil {
				t.Fatalf("unable to make test database: %v", err)
			}
			defer cleanUp()

			// Persist a channel carrying exactly the scripts this
			// test case asks for.
			state := createTestChannel(
				t, cdb,
				localShutdownOption(tc.localShutdown),
				remoteShutdownOption(tc.remoteShutdown),
			)

			openChannels, err := cdb.FetchOpenChannels(
				state.IdentityPub,
			)
			if err != nil {
				t.Fatalf("unable to fetch open channel: %v",
					err)
			}

			if numChans := len(openChannels); numChans != 1 {
				t.Fatalf("Expected one channel open, got: %v",
					numChans)
			}

			// Both scripts read back from disk must match what
			// was written.
			stored := openChannels[0]

			gotLocal := stored.LocalShutdownScript
			if !bytes.Equal(gotLocal, tc.localShutdown) {
				t.Fatalf("Expected local: %x, got: %x",
					tc.localShutdown, gotLocal)
			}

			gotRemote := stored.RemoteShutdownScript
			if !bytes.Equal(gotRemote, tc.remoteShutdown) {
				t.Fatalf("Expected remote: %x, got: %x",
					tc.remoteShutdown, gotRemote)
			}
		})
	}
}

// assertCommitmentEqual fails the test if the two commitments are not deeply
// equal. The failure message includes the line number of the caller along
// with a dump of both commitments.
func assertCommitmentEqual(t *testing.T, a, b *ChannelCommitment) {
	if reflect.DeepEqual(a, b) {
		return
	}

	// Look one frame up the stack so that the reported line is that of
	// the calling test rather than this helper.
	_, _, callerLine, _ := runtime.Caller(1)
	t.Fatalf("line %v: commitments don't match: %v vs %v",
		callerLine, spew.Sdump(a), spew.Sdump(b))
}

// TestChannelStateTransition walks a channel through a series of commitment
// updates: it updates the local commitment, extends and then advances the
// remote commit chain twice, and checks after each step that the state read
// back from the database matches what was written. Finally the channel is
// closed and the test checks that neither the channel nor its previous states
// can be found any more.
func TestChannelStateTransition(t *testing.T) {
	t.Parallel()

	cdb, cleanUp, err := MakeTestDB()
	if err != nil {
		t.Fatalf("unable to make test database: %v", err)
	}
	defer cleanUp()

	// First create a minimal channel, then perform a full sync in order to
	// persist the data.
	channel := createTestChannel(t, cdb)

	// Add some HTLCs which were added during this new state transition.
	// The HTLCs with an index of 6 and above are incoming, the others are
	// outgoing.
	var (
		htlcs   []HTLC
		htlcAmt lnwire.MilliSatoshi
	)
	for i := uint32(0); i < 10; i++ {
		htlc := HTLC{
			Signature:     testSig.Serialize(),
			Incoming:      i > 5,
			Amt:           10,
			RHash:         key,
			RefundTimeout: i,
			OutputIndex:   int32(i * 3),
			LogIndex:      uint64(i * 2),
			HtlcIndex:     uint64(i),
		}
		htlc.OnionBlob = make([]byte, 10)
		copy(htlc.OnionBlob[:], bytes.Repeat([]byte{2}, 10))
		htlcs = append(htlcs, htlc)
		htlcAmt += htlc.Amt
	}

	// Create a new channel delta which includes the above HTLCs, some
	// balance updates, and an increment of the current commitment height.
	// Additionally, modify the signature and commitment transaction.
	newSequence := uint32(129498)
	newSig := bytes.Repeat([]byte{3}, 71)
	newTx := channel.LocalCommitment.CommitTx.Copy()
	newTx.TxIn[0].Sequence = newSequence
	commitment := ChannelCommitment{
		CommitHeight:    1,
		LocalLogIndex:   2,
		LocalHtlcIndex:  1,
		RemoteLogIndex:  2,
		RemoteHtlcIndex: 1,
		LocalBalance:    lnwire.MilliSatoshi(1e8),
		RemoteBalance:   lnwire.MilliSatoshi(1e8),
		CommitFee:       55,
		FeePerKw:        99,
		CommitTx:        newTx,
		CommitSig:       newSig,
		Htlcs:           htlcs,
	}

	// First update the local node's broadcastable state, passing along a
	// single unsigned acked update that should be stored with it.
	unsignedAckedUpdates := []LogUpdate{
		{
			LogIndex: 2,
			UpdateMsg: &lnwire.UpdateAddHTLC{
				ChanID:    lnwire.ChannelID{1, 2, 3},
				ExtraData: make([]byte, 0),
			},
		},
	}

	err = channel.UpdateCommitment(&commitment, unsignedAckedUpdates)
	if err != nil {
		t.Fatalf("unable to update commitment: %v", err)
	}

	// Assert that update is correctly written to the database.
	dbUnsignedAckedUpdates, err := channel.UnsignedAckedUpdates()
	if err != nil {
		t.Fatalf("unable to fetch dangling remote updates: %v", err)
	}
	if len(dbUnsignedAckedUpdates) != 1 {
		t.Fatalf("unexpected number of dangling remote updates")
	}
	if !reflect.DeepEqual(
		dbUnsignedAckedUpdates[0], unsignedAckedUpdates[0],
	) {
		// Dump the element that was compared, not the whole slice.
		t.Fatalf("unexpected update: expected %v, got %v",
			spew.Sdump(unsignedAckedUpdates[0]),
			spew.Sdump(dbUnsignedAckedUpdates[0]))
	}

	// The balances, new update, the HTLCs and the changes to the fake
	// commitment transaction along with the modified signature should all
	// have been updated.
	updatedChannel, err := cdb.FetchOpenChannels(channel.IdentityPub)
	if err != nil {
		t.Fatalf("unable to fetch updated channel: %v", err)
	}

	// Guard the index below so that a missing channel fails the test with
	// a readable message instead of a panic.
	if len(updatedChannel) != 1 {
		t.Fatalf("expected 1 open channel, got %v",
			len(updatedChannel))
	}
	assertCommitmentEqual(t, &commitment, &updatedChannel[0].LocalCommitment)
	numDiskUpdates, err := updatedChannel[0].CommitmentHeight()
	if err != nil {
		t.Fatalf("unable to read commitment height from disk: %v", err)
	}
	if numDiskUpdates != uint64(commitment.CommitHeight) {
		t.Fatalf("num disk updates doesn't match: %v vs %v",
			numDiskUpdates, commitment.CommitHeight)
	}

	// Attempting to query for a commitment diff should return
	// ErrNoPendingCommit as we haven't yet created a new state for them.
	_, err = channel.RemoteCommitChainTip()
	if err != ErrNoPendingCommit {
		t.Fatalf("expected ErrNoPendingCommit, instead got %v", err)
	}

	// To simulate us extending a new state to the remote party, we'll also
	// create a new commit diff for them.
	remoteCommit := commitment
	remoteCommit.LocalBalance = lnwire.MilliSatoshi(2e8)
	remoteCommit.RemoteBalance = lnwire.MilliSatoshi(3e8)
	remoteCommit.CommitHeight = 1
	commitDiff := &CommitDiff{
		Commitment: remoteCommit,
		CommitSig: &lnwire.CommitSig{
			ChanID:    lnwire.ChannelID(key),
			CommitSig: wireSig,
			HtlcSigs: []lnwire.Sig{
				wireSig,
				wireSig,
			},
			ExtraData: make([]byte, 0),
		},
		LogUpdates: []LogUpdate{
			{
				LogIndex: 1,
				UpdateMsg: &lnwire.UpdateAddHTLC{
					ID:        1,
					Amount:    lnwire.NewMSatFromSatoshis(100),
					Expiry:    25,
					ExtraData: make([]byte, 0),
				},
			},
			{
				LogIndex: 2,
				UpdateMsg: &lnwire.UpdateAddHTLC{
					ID:        2,
					Amount:    lnwire.NewMSatFromSatoshis(200),
					Expiry:    50,
					ExtraData: make([]byte, 0),
				},
			},
		},
		OpenedCircuitKeys: []CircuitKey{},
		ClosedCircuitKeys: []CircuitKey{},
	}
	copy(commitDiff.LogUpdates[0].UpdateMsg.(*lnwire.UpdateAddHTLC).PaymentHash[:],
		bytes.Repeat([]byte{1}, 32))
	copy(commitDiff.LogUpdates[1].UpdateMsg.(*lnwire.UpdateAddHTLC).PaymentHash[:],
		bytes.Repeat([]byte{2}, 32))
	if err := channel.AppendRemoteCommitChain(commitDiff); err != nil {
		t.Fatalf("unable to add to commit chain: %v", err)
	}

	// The commitment tip should now match the commitment that we just
	// inserted.
	diskCommitDiff, err := channel.RemoteCommitChainTip()
	if err != nil {
		t.Fatalf("unable to fetch commit diff: %v", err)
	}
	if !reflect.DeepEqual(commitDiff, diskCommitDiff) {
		// Dump the two values that were actually compared.
		t.Fatalf("commit diffs don't match: %v vs %v",
			spew.Sdump(commitDiff), spew.Sdump(diskCommitDiff))
	}

	// We'll save the old remote commitment as this will be added to the
	// revocation log shortly.
	oldRemoteCommit := channel.RemoteCommitment

	// Next, write to the log which tracks the necessary revocation state
	// needed to rectify any fishy behavior by the remote party. Modify the
	// current uncollapsed revocation state to simulate a state transition
	// by the remote party.
	channel.RemoteCurrentRevocation = channel.RemoteNextRevocation
	newPriv, err := btcec.NewPrivateKey(btcec.S256())
	if err != nil {
		t.Fatalf("unable to generate key: %v", err)
	}
	channel.RemoteNextRevocation = newPriv.PubKey()

	fwdPkg := NewFwdPkg(channel.ShortChanID(), oldRemoteCommit.CommitHeight,
		diskCommitDiff.LogUpdates, nil)

	err = channel.AdvanceCommitChainTail(fwdPkg, nil)
	if err != nil {
		t.Fatalf("unable to append to revocation log: %v", err)
	}

	// At this point, the remote commit chain should be nil, and the posted
	// remote commitment should match the one we added as a diff above.
	if _, err := channel.RemoteCommitChainTip(); err != ErrNoPendingCommit {
		t.Fatalf("expected ErrNoPendingCommit, instead got %v", err)
	}

	// We should be able to fetch the channel delta created above by its
	// update number with all the state properly reconstructed.
	diskPrevCommit, err := channel.FindPreviousState(
		oldRemoteCommit.CommitHeight,
	)
	if err != nil {
		t.Fatalf("unable to fetch past delta: %v", err)
	}

	// The two deltas (the original vs the on-disk version) should be
	// identical, and all HTLC data should properly be retained.
	assertCommitmentEqual(t, &oldRemoteCommit, diskPrevCommit)

	// The state number recovered from the tail of the revocation log
	// should be identical to this current state.
	logTail, err := channel.RevocationLogTail()
	if err != nil {
		t.Fatalf("unable to retrieve log: %v", err)
	}
	if logTail.CommitHeight != oldRemoteCommit.CommitHeight {
		t.Fatal("update number doesn't match")
	}

	oldRemoteCommit = channel.RemoteCommitment

	// Next modify the posted diff commitment slightly, then create a new
	// commitment diff and advance the tail.
	commitDiff.Commitment.CommitHeight = 2
	commitDiff.Commitment.LocalBalance -= htlcAmt
	commitDiff.Commitment.RemoteBalance += htlcAmt
	commitDiff.LogUpdates = []LogUpdate{}
	if err := channel.AppendRemoteCommitChain(commitDiff); err != nil {
		t.Fatalf("unable to add to commit chain: %v", err)
	}

	fwdPkg = NewFwdPkg(channel.ShortChanID(), oldRemoteCommit.CommitHeight, nil, nil)

	err = channel.AdvanceCommitChainTail(fwdPkg, nil)
	if err != nil {
		t.Fatalf("unable to append to revocation log: %v", err)
	}

	// Once again, fetch the state and ensure it has been properly updated.
	prevCommit, err := channel.FindPreviousState(oldRemoteCommit.CommitHeight)
	if err != nil {
		t.Fatalf("unable to fetch past delta: %v", err)
	}
	assertCommitmentEqual(t, &oldRemoteCommit, prevCommit)

	// Once again, state number recovered from the tail of the revocation
	// log should be identical to this current state.
	logTail, err = channel.RevocationLogTail()
	if err != nil {
		t.Fatalf("unable to retrieve log: %v", err)
	}
	if logTail.CommitHeight != oldRemoteCommit.CommitHeight {
		t.Fatal("update number doesn't match")
	}

	// The revocation state stored on-disk should now also be identical.
	updatedChannel, err = cdb.FetchOpenChannels(channel.IdentityPub)
	if err != nil {
		t.Fatalf("unable to fetch updated channel: %v", err)
	}
	if len(updatedChannel) != 1 {
		t.Fatalf("expected 1 open channel, got %v",
			len(updatedChannel))
	}
	if !channel.RemoteCurrentRevocation.IsEqual(updatedChannel[0].RemoteCurrentRevocation) {
		t.Fatalf("revocation state was not synced")
	}
	if !channel.RemoteNextRevocation.IsEqual(updatedChannel[0].RemoteNextRevocation) {
		t.Fatalf("revocation state was not synced")
	}

	// Now attempt to delete the channel from the database.
	closeSummary := &ChannelCloseSummary{
		ChanPoint:         channel.FundingOutpoint,
		RemotePub:         channel.IdentityPub,
		SettledBalance:    btcutil.Amount(500),
		TimeLockedBalance: btcutil.Amount(10000),
		IsPending:         false,
		CloseType:         RemoteForceClose,
	}
	if err := updatedChannel[0].CloseChannel(closeSummary); err != nil {
		t.Fatalf("unable to delete updated channel: %v", err)
	}

	// If we attempt to fetch the target channel again, it shouldn't be
	// found. The result is deliberately not named "channels", which would
	// shadow the imported channels package.
	remainingChans, err := cdb.FetchOpenChannels(channel.IdentityPub)
	if err != nil {
		t.Fatalf("unable to fetch updated channels: %v", err)
	}
	if len(remainingChans) != 0 {
		t.Fatalf("%v channels, found, but none should be",
			len(remainingChans))
	}

	// Attempting to find previous states on the channel should fail as the
	// revocation log has been deleted.
	_, err = updatedChannel[0].FindPreviousState(oldRemoteCommit.CommitHeight)
	if err == nil {
		t.Fatal("revocation log search should have failed")
	}
}

// TestFetchPendingChannels creates a pending channel at a known broadcast
// height and checks that it is returned by FetchPendingChannels. It then marks
// the channel as open and checks that the short channel ID and broadcast
// height were stored, and that the channel is no longer listed as pending.
func TestFetchPendingChannels(t *testing.T) {
	t.Parallel()

	cdb, cleanUp, err := MakeTestDB()
	if err != nil {
		t.Fatalf("unable to make test database: %v", err)
	}
	defer cleanUp()

	// Create a pending channel that was broadcast at height 99.
	const broadcastHeight = 99
	createTestChannel(t, cdb, pendingHeightOption(broadcastHeight))

	pendingChannels, err := cdb.FetchPendingChannels()
	if err != nil {
		t.Fatalf("unable to list pending channels: %v", err)
	}

	if len(pendingChannels) != 1 {
		t.Fatalf("incorrect number of pending channels: expecting %v, "+
			"got %v", 1, len(pendingChannels))
	}

	// The broadcast height of the pending channel should have been set
	// properly.
	if pendingChannels[0].FundingBroadcastHeight != broadcastHeight {
		t.Fatalf("broadcast height mismatch: expected %v, got %v",
			broadcastHeight,
			pendingChannels[0].FundingBroadcastHeight)
	}

	chanOpenLoc := lnwire.ShortChannelID{
		BlockHeight: 5,
		TxIndex:     10,
		TxPosition:  15,
	}
	err = pendingChannels[0].MarkAsOpen(chanOpenLoc)
	if err != nil {
		t.Fatalf("unable to mark channel as open: %v", err)
	}

	if pendingChannels[0].IsPending {
		t.Fatalf("channel marked open should no longer be pending")
	}

	if pendingChannels[0].ShortChanID() != chanOpenLoc {
		t.Fatalf("channel opening height not updated: expected %v, "+
			"got %v", chanOpenLoc,
			spew.Sdump(pendingChannels[0].ShortChanID()))
	}

	// Next, we'll re-fetch the channel to ensure that the open height was
	// properly set.
	openChans, err := cdb.FetchAllChannels()
	if err != nil {
		t.Fatalf("unable to fetch channels: %v", err)
	}

	// Guard the index below so that a missing channel fails the test with
	// a readable message instead of a panic.
	if len(openChans) != 1 {
		t.Fatalf("incorrect number of channels: expecting %v, "+
			"got %v", 1, len(openChans))
	}
	if openChans[0].ShortChanID() != chanOpenLoc {
		t.Fatalf("channel opening heights don't match: expected %v, "+
			"got %v", chanOpenLoc,
			spew.Sdump(openChans[0].ShortChanID()))
	}
	if openChans[0].FundingBroadcastHeight != broadcastHeight {
		t.Fatalf("broadcast height mismatch: expected %v, got %v",
			broadcastHeight,
			openChans[0].FundingBroadcastHeight)
	}

	pendingChannels, err = cdb.FetchPendingChannels()
	if err != nil {
		t.Fatalf("unable to list pending channels: %v", err)
	}

	if len(pendingChannels) != 0 {
		t.Fatalf("incorrect number of pending channels: expecting %v, "+
			"got %v", 0, len(pendingChannels))
	}
}

// TestFetchClosedChannels closes an open channel with a pending close summary
// and checks that the summary is returned both when fetching only pending
// closed channels and when fetching all closed channels. After the channel is
// marked fully closed, it must still be returned among all closed channels,
// but no longer among the pending ones.
func TestFetchClosedChannels(t *testing.T) {
	t.Parallel()

	cdb, cleanUp, err := MakeTestDB()
	if err != nil {
		t.Fatalf("unable to make test database: %v", err)
	}
	defer cleanUp()

	// Create an open channel in the database.
	state := createTestChannel(t, cdb, openChannelOption())

	// Next, close the channel by including a close channel summary in the
	// database.
	summary := &ChannelCloseSummary{
		ChanPoint:         state.FundingOutpoint,
		ClosingTXID:       rev,
		RemotePub:         state.IdentityPub,
		Capacity:          state.Capacity,
		SettledBalance:    state.LocalCommitment.LocalBalance.ToSatoshis(),
		TimeLockedBalance: state.RemoteCommitment.LocalBalance.ToSatoshis() + 10000,
		CloseType:         RemoteForceClose,
		IsPending:         true,
		LocalChanConfig:   state.LocalChanCfg,
	}
	if err := state.CloseChannel(summary); err != nil {
		t.Fatalf("unable to close channel: %v", err)
	}

	// Query the database to ensure that the channel has now been properly
	// closed. We should get the same result whether querying for pending
	// channels only, or not.
	pendingClosed, err := cdb.FetchClosedChannels(true)
	if err != nil {
		t.Fatalf("failed fetching closed channels: %v", err)
	}
	if len(pendingClosed) != 1 {
		t.Fatalf("incorrect number of pending closed channels: "+
			"expecting %v, got %v", 1, len(pendingClosed))
	}
	if !reflect.DeepEqual(summary, pendingClosed[0]) {
		t.Fatalf("database summaries don't match: expected %v got %v",
			spew.Sdump(summary), spew.Sdump(pendingClosed[0]))
	}
	closed, err := cdb.FetchClosedChannels(false)
	if err != nil {
		t.Fatalf("failed fetching all closed channels: %v", err)
	}
	if len(closed) != 1 {
		t.Fatalf("incorrect number of closed channels: expecting %v, "+
			"got %v", 1, len(closed))
	}
	if !reflect.DeepEqual(summary, closed[0]) {
		t.Fatalf("database summaries don't match: expected %v got %v",
			spew.Sdump(summary), spew.Sdump(closed[0]))
	}

	// Mark the channel as fully closed.
	err = cdb.MarkChanFullyClosed(&state.FundingOutpoint)
	if err != nil {
		t.Fatalf("failed fully closing channel: %v", err)
	}

	// The channel should no longer be considered pending, but should still
	// be retrieved when fetching all the closed channels.
	closed, err = cdb.FetchClosedChannels(false)
	if err != nil {
		t.Fatalf("failed fetching closed channels: %v", err)
	}
	if len(closed) != 1 {
		t.Fatalf("incorrect number of closed channels: expecting %v, "+
			"got %v", 1, len(closed))
	}
	pendingClose, err := cdb.FetchClosedChannels(true)
	if err != nil {
		t.Fatalf("failed fetching channels pending close: %v", err)
	}
	if len(pendingClose) != 0 {
		// Report the count of the slice that was actually checked.
		t.Fatalf("incorrect number of pending closed channels: "+
			"expecting %v, got %v", 0, len(pendingClose))
	}
}

// TestFetchWaitingCloseChannels ensures that the correct channels that are
// waiting to be closed are returned. It creates numChannels pending channels,
// confirms only the first one, marks each of them as having both a force close
// and a cooperative close transaction broadcast, and then asserts that
// FetchWaitingCloseChannels returns all of them together with the distinct
// close transactions stored for each.
func TestFetchWaitingCloseChannels(t *testing.T) {
	t.Parallel()

	const numChannels = 2
	const broadcastHeight = 99

	// We'll start by creating two channels within our test database. One of
	// them will have their funding transaction confirmed on-chain, while
	// the other one will remain unconfirmed.
	db, cleanUp, err := MakeTestDB()
	if err != nil {
		t.Fatalf("unable to make test database: %v", err)
	}
	defer cleanUp()

	// The slice is deliberately not named "channels" so that it does not
	// shadow the imported lntest/channels package.
	testChans := make([]*OpenChannel, numChannels)
	for i := 0; i < numChannels; i++ {
		// Create a pending channel in the database at the broadcast
		// height.
		testChans[i] = createTestChannel(
			t, db, pendingHeightOption(broadcastHeight),
		)
	}

	// We'll only confirm the first one.
	channelConf := lnwire.ShortChannelID{
		BlockHeight: broadcastHeight + 1,
		TxIndex:     10,
		TxPosition:  15,
	}
	if err := testChans[0].MarkAsOpen(channelConf); err != nil {
		t.Fatalf("unable to mark channel as open: %v", err)
	}

	// Then, we'll mark the channels as if their commitments were broadcast.
	// This would happen in the event of a force close and should make the
	// channels enter a state of waiting close.
	for _, channel := range testChans {
		closeTx := wire.NewMsgTx(2)
		closeTx.AddTxIn(
			&wire.TxIn{
				PreviousOutPoint: channel.FundingOutpoint,
			},
		)

		// A single loop-scoped err is used for every step below so
		// that the function-level err is never written from inside
		// the loop.
		err := channel.MarkCommitmentBroadcasted(closeTx, true)
		if err != nil {
			t.Fatalf("unable to mark commitment broadcast: %v", err)
		}

		// Now try marking a coop close with a nil tx. This should
		// succeed, but the tx shouldn't exist when queried.
		err = channel.MarkCoopBroadcasted(nil, true)
		if err != nil {
			t.Fatalf("unable to mark nil coop broadcast: %v", err)
		}
		_, err = channel.BroadcastedCooperative()
		if err != ErrNoCloseTx {
			t.Fatalf("expected no closing tx error, got: %v", err)
		}

		// Finally, modify the close tx deterministically and also mark
		// it as coop closed. Later we will test that distinct
		// transactions are returned for both coop and force closes.
		closeTx.TxIn[0].PreviousOutPoint.Index ^= 1
		err = channel.MarkCoopBroadcasted(closeTx, true)
		if err != nil {
			t.Fatalf("unable to mark coop broadcast: %v", err)
		}
	}

	// Now, we'll fetch all the channels waiting to be closed from the
	// database. We should expect to see both channels above, even if any of
	// them haven't had their funding transaction confirm on-chain.
	waitingCloseChannels, err := db.FetchWaitingCloseChannels()
	if err != nil {
		t.Fatalf("unable to fetch all waiting close channels: %v", err)
	}
	if len(waitingCloseChannels) != numChannels {
		t.Fatalf("expected %d channels waiting to be closed, got %d",
			numChannels, len(waitingCloseChannels))
	}

	// Index the channels we created by funding outpoint so that each
	// returned channel can be matched against them.
	expectedChannels := make(map[wire.OutPoint]struct{}, numChannels)
	for _, channel := range testChans {
		expectedChannels[channel.FundingOutpoint] = struct{}{}
	}
	for _, channel := range waitingCloseChannels {
		if _, ok := expectedChannels[channel.FundingOutpoint]; !ok {
			t.Fatalf("expected channel %v to be waiting close",
				channel.FundingOutpoint)
		}

		// Work on a copy of the outpoint, since it is mutated below to
		// derive the input expected on the coop close transaction.
		chanPoint := channel.FundingOutpoint

		// Assert that the force close transaction is retrievable.
		forceCloseTx, err := channel.BroadcastedCommitment()
		if err != nil {
			t.Fatalf("Unable to retrieve commitment: %v", err)
		}

		if forceCloseTx.TxIn[0].PreviousOutPoint != chanPoint {
			t.Fatalf("expected outpoint %v, got %v",
				chanPoint,
				forceCloseTx.TxIn[0].PreviousOutPoint)
		}

		// Assert that the coop close transaction is retrievable.
		coopCloseTx, err := channel.BroadcastedCooperative()
		if err != nil {
			t.Fatalf("unable to retrieve coop close: %v", err)
		}

		// The coop close tx was stored with the outpoint index
		// flipped, so apply the same flip before comparing.
		chanPoint.Index ^= 1
		if coopCloseTx.TxIn[0].PreviousOutPoint != chanPoint {
			t.Fatalf("expected outpoint %v, got %v",
				chanPoint,
				coopCloseTx.TxIn[0].PreviousOutPoint)
		}
	}
}

// TestRefreshShortChanID asserts that RefreshShortChanID updates the in-memory
// state of another OpenChannel to reflect a preceding call to MarkAsOpen on a
// different OpenChannel that refers to the same funding outpoint.
func TestRefreshShortChanID(t *testing.T) {
	t.Parallel()

	cdb, cleanUp, err := MakeTestDB()
	if err != nil {
		t.Fatalf("unable to make test database: %v", err)
	}
	defer cleanUp()

	// First create a test channel.
	state := createTestChannel(t, cdb)

	// Next, locate the pending channel with the database.
	pendingChannels, err := cdb.FetchPendingChannels()
	if err != nil {
		t.Fatalf("unable to load pending channels; %v", err)
	}

	var pendingChannel *OpenChannel
	for _, channel := range pendingChannels {
		if channel.FundingOutpoint == state.FundingOutpoint {
			pendingChannel = channel
			break
		}
	}
	if pendingChannel == nil {
		// err is known to be nil here, so it is intentionally left out
		// of the failure message.
		t.Fatalf("unable to find pending channel with funding "+
			"outpoint=%v", state.FundingOutpoint)
	}

	// Next, simulate the confirmation of the channel by marking it as
	// open within the database.
	chanOpenLoc := lnwire.ShortChannelID{
		BlockHeight: 105,
		TxIndex:     10,
		TxPosition:  15,
	}

	err = state.MarkAsOpen(chanOpenLoc)
	if err != nil {
		t.Fatalf("unable to mark channel open: %v", err)
	}

	// The short_chan_id of the receiver to MarkAsOpen should reflect the
	// open location, but the other pending channel should remain unchanged.
	if state.ShortChanID() == pendingChannel.ShortChanID() {
		t.Fatalf("pending channel short_chan_ID should not have been " +
			"updated before refreshing short_chan_id")
	}

	// Now that the receiver's short channel id has been updated, check to
	// ensure that the channel packager's source has been updated as well.
	// This ensures that the packager will read and write to buckets
	// corresponding to the new short chan id, instead of the prior.
	if state.Packager.(*ChannelPackager).source != chanOpenLoc {
		t.Fatalf("channel packager source was not updated: want %v, "+
			"got %v", chanOpenLoc,
			state.Packager.(*ChannelPackager).source)
	}

	// Now, refresh the short channel ID of the pending channel.
	err = pendingChannel.RefreshShortChanID()
	if err != nil {
		t.Fatalf("unable to refresh short_chan_id: %v", err)
	}

	// This should result in both OpenChannel's now having the same
	// ShortChanID.
	if state.ShortChanID() != pendingChannel.ShortChanID() {
		t.Fatalf("expected pending channel short_chan_id to be "+
			"refreshed: want %v, got %v", state.ShortChanID(),
			pendingChannel.ShortChanID())
	}

	// Check to ensure that the _other_ OpenChannel channel packager's
	// source has also been updated after the refresh. This ensures that the
	// other packagers will read and write to buckets corresponding to the
	// updated short chan id.
	if pendingChannel.Packager.(*ChannelPackager).source != chanOpenLoc {
		t.Fatalf("channel packager source was not updated: want %v, "+
			"got %v", chanOpenLoc,
			pendingChannel.Packager.(*ChannelPackager).source)
	}

	// Check to ensure that this channel is no longer pending and this field
	// is up to date.
	if pendingChannel.IsPending {
		t.Fatalf("channel pending state wasn't updated: want false got true")
	}
}

// TestCloseInitiator verifies that marking a channel as cooperatively closed
// or as force closed persists both the close-type status and the status that
// records which party initiated the close.
func TestCloseInitiator(t *testing.T) {
	// closeInitiatorTest describes a single scenario: how the channel is
	// marked as broadcast and which statuses must be found on disk
	// afterwards.
	type closeInitiatorTest struct {
		name             string
		updateChannel    func(c *OpenChannel) error
		expectedStatuses []ChannelStatus
	}

	// coopClose builds an updater that marks the channel as cooperatively
	// closed with an empty transaction. The local flag is forwarded as the
	// initiator argument of MarkCoopBroadcasted.
	coopClose := func(local bool) func(*OpenChannel) error {
		return func(c *OpenChannel) error {
			return c.MarkCoopBroadcasted(&wire.MsgTx{}, local)
		}
	}

	// localForceClose marks the channel's commitment as broadcast with the
	// local party as initiator.
	localForceClose := func(c *OpenChannel) error {
		return c.MarkCommitmentBroadcasted(&wire.MsgTx{}, true)
	}

	testCases := []closeInitiatorTest{
		{
			name:          "local coop close",
			updateChannel: coopClose(true),
			expectedStatuses: []ChannelStatus{
				ChanStatusLocalCloseInitiator,
				ChanStatusCoopBroadcasted,
			},
		},
		{
			name:          "remote coop close",
			updateChannel: coopClose(false),
			expectedStatuses: []ChannelStatus{
				ChanStatusRemoteCloseInitiator,
				ChanStatusCoopBroadcasted,
			},
		},
		{
			name:          "local force close",
			updateChannel: localForceClose,
			expectedStatuses: []ChannelStatus{
				ChanStatusLocalCloseInitiator,
				ChanStatusCommitBroadcasted,
			},
		},
	}

	for _, tc := range testCases {
		// Capture the range variable for the parallel subtest.
		tc := tc

		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			// Every subtest works against its own database.
			cdb, cleanUp, err := MakeTestDB()
			if err != nil {
				t.Fatalf("unable to make test database: %v",
					err)
			}
			defer cleanUp()

			// Start from a channel that is already open, then
			// apply the scenario's update to it.
			channel := createTestChannel(
				t, cdb, openChannelOption(),
			)
			if err := tc.updateChannel(channel); err != nil {
				t.Fatalf("unexpected error: %v", err)
			}

			// Read the non-pending channels back from the
			// database; exactly one is expected.
			dbChans, err := fetchChannels(
				cdb, pendingChannelFilter(false),
			)
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}
			if numChans := len(dbChans); numChans != 1 {
				t.Fatalf("expected 1 channel, got: %v",
					numChans)
			}

			// Each expected status must be present on the channel
			// that was read from disk.
			dbChan := dbChans[0]
			for _, want := range tc.expectedStatuses {
				if dbChan.HasChanStatus(want) {
					continue
				}

				t.Fatalf("expected channel to have "+
					"status: %v, has status: %v",
					want, dbChan.chanStatus)
			}
		})
	}
}

// TestCloseChannelStatus checks that a status passed to CloseChannel is
// present on the historical channel fetched after the close.
func TestCloseChannelStatus(t *testing.T) {
	cdb, cleanUp, err := MakeTestDB()
	if err != nil {
		t.Fatalf("unable to make test database: %v",
			err)
	}
	defer cleanUp()

	// Start from a channel that is already open.
	channel := createTestChannel(t, cdb, openChannelOption())

	// Close it, passing the remote close initiator status along with the
	// close summary.
	summary := &ChannelCloseSummary{
		ChanPoint: channel.FundingOutpoint,
		RemotePub: channel.IdentityPub,
	}
	err = channel.CloseChannel(summary, ChanStatusRemoteCloseInitiator)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// The historical record of the channel must carry the status that was
	// supplied at close time.
	chanPoint := &channel.FundingOutpoint
	histChan, err := channel.Db.FetchHistoricalChannel(chanPoint)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	hasStatus := histChan.HasChanStatus(ChanStatusRemoteCloseInitiator)
	if !hasStatus {
		t.Fatalf("channel should have status")
	}
}

// TestBalanceAtHeight exercises BalancesAtHeight for the current local commit
// height, the current remote commit height, a height stored in the revocation
// log, and two error cases: a height missing from the log and a height that
// has not been reached yet.
func TestBalanceAtHeight(t *testing.T) {
	const (
		// State of the in-memory local commitment.
		localHeight        = 2
		localLocalBalance  = 1000
		localRemoteBalance = 1500

		// State of the in-memory remote commitment.
		remoteHeight        = 3
		remoteLocalBalance  = 2000
		remoteRemoteBalance = 2500

		// State written to the revocation log on disk.
		oldHeight        = 0
		oldLocalBalance  = 200
		oldRemoteBalance = 300

		// Heights used for the error cases.
		unknownHeight   = 1
		unreachedHeight = 4
	)

	// putRevokedState writes a commitment with the given height and
	// balances into the channel's revocation log bucket, so that balances
	// can be looked up at a height other than the current ones.
	putRevokedState := func(c *OpenChannel, height uint64, local,
		remote lnwire.MilliSatoshi) error {

		return kvdb.Update(c.Db, func(tx kvdb.RwTx) error {
			chanBucket, err := fetchChanBucketRw(
				tx, c.IdentityPub, &c.FundingOutpoint,
				c.ChainHash,
			)
			if err != nil {
				return err
			}

			logBucket, err := chanBucket.CreateBucketIfNotExists(
				revocationLogBucket,
			)
			if err != nil {
				return err
			}

			// Start from a copy of the current local commitment so
			// only the fields under test need to be overridden.
			revoked := c.LocalCommitment
			revoked.CommitHeight = height
			revoked.LocalBalance = local
			revoked.RemoteBalance = remote

			return appendChannelLogEntry(logBucket, &revoked)
		}, func() {})
	}

	// balanceTest pairs a target height with the balances and error that
	// the lookup is expected to produce. Fields left unset default to zero
	// balances and a nil error.
	type balanceTest struct {
		name                  string
		targetHeight          uint64
		expectedLocalBalance  lnwire.MilliSatoshi
		expectedRemoteBalance lnwire.MilliSatoshi
		expectedError         error
	}

	testCases := []balanceTest{
		{
			name:                  "target is current local height",
			targetHeight:          localHeight,
			expectedLocalBalance:  localLocalBalance,
			expectedRemoteBalance: localRemoteBalance,
		},
		{
			name:                  "target is current remote height",
			targetHeight:          remoteHeight,
			expectedLocalBalance:  remoteLocalBalance,
			expectedRemoteBalance: remoteRemoteBalance,
		},
		{
			name:                  "need to lookup commit",
			targetHeight:          oldHeight,
			expectedLocalBalance:  oldLocalBalance,
			expectedRemoteBalance: oldRemoteBalance,
		},
		{
			name:          "height not found",
			targetHeight:  unknownHeight,
			expectedError: ErrLogEntryNotFound,
		},
		{
			name:          "height not reached",
			targetHeight:  unreachedHeight,
			expectedError: errHeightNotReached,
		},
	}

	for _, tc := range testCases {
		// Capture the range variable for the parallel subtest.
		tc := tc

		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			// Every subtest works against its own database.
			cdb, cleanUp, err := MakeTestDB()
			if err != nil {
				t.Fatalf("unable to make test database: %v",
					err)
			}
			defer cleanUp()

			// Build an open channel whose local and remote
			// commitments carry the heights and balances above.
			channel := createTestChannel(
				t, cdb, openChannelOption(),
				channelCommitmentOption(
					localHeight, localLocalBalance,
					localRemoteBalance, true,
				),
				channelCommitmentOption(
					remoteHeight, remoteLocalBalance,
					remoteRemoteBalance, false,
				),
			)

			// Persist the older commitment in the revocation log.
			err = putRevokedState(
				channel, oldHeight, oldLocalBalance,
				oldRemoteBalance,
			)
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}

			// Look up the balances and compare the error first,
			// then each balance in turn.
			gotLocal, gotRemote, err := channel.BalancesAtHeight(
				tc.targetHeight,
			)
			if err != tc.expectedError {
				t.Fatalf("expected: %v, got: %v",
					tc.expectedError, err)
			}

			if gotLocal != tc.expectedLocalBalance {
				t.Fatalf("expected local: %v, got: %v",
					tc.expectedLocalBalance, gotLocal)
			}

			if gotRemote != tc.expectedRemoteBalance {
				t.Fatalf("expected remote: %v, got: %v",
					tc.expectedRemoteBalance, gotRemote)
			}
		})
	}
}

// TestHasChanStatus checks what HasChanStatus reports for a channel holding
// the default status, a single flag, and a combination of flags. The default
// status is covered explicitly because it is treated like a flag in the code
// base even though it isn't one.
func TestHasChanStatus(t *testing.T) {
	// statusTest pairs the status set on the channel with the answer
	// HasChanStatus is expected to give for each queried status.
	type statusTest struct {
		name   string
		status ChannelStatus
		expHas map[ChannelStatus]bool
	}

	testCases := []statusTest{
		{
			name:   "default",
			status: ChanStatusDefault,
			expHas: map[ChannelStatus]bool{
				ChanStatusDefault: true,
				ChanStatusBorked:  false,
			},
		},
		{
			name:   "single flag",
			status: ChanStatusBorked,
			expHas: map[ChannelStatus]bool{
				ChanStatusDefault: false,
				ChanStatusBorked:  true,
			},
		},
		{
			name:   "multiple flags",
			status: ChanStatusBorked | ChanStatusLocalDataLoss,
			expHas: map[ChannelStatus]bool{
				ChanStatusDefault:       false,
				ChanStatusBorked:        true,
				ChanStatusLocalDataLoss: true,
			},
		},
	}

	for _, tc := range testCases {
		tc := tc

		t.Run(tc.name, func(t *testing.T) {
			// Only the status field matters for this check, so a
			// bare in-memory channel is sufficient.
			c := &OpenChannel{chanStatus: tc.status}

			for status, want := range tc.expHas {
				got := c.HasChanStatus(status)
				if got != want {
					t.Fatalf("expected chan status to "+
						"have %s? %t, got: %t",
						status, want, got)
				}
			}
		})
	}
}

// TestKeyLocatorEncoding round-trips a keychain.KeyLocator through
// EKeyLocator and DKeyLocator and asserts that the decoded value equals the
// one that was encoded.
func TestKeyLocatorEncoding(t *testing.T) {
	var (
		// encoded receives the serialized key locator.
		encoded bytes.Buffer

		// scratch is the buffer handed to both the encoder and the
		// decoder.
		scratch [8]byte
	)

	want := keychain.KeyLocator{
		Family: keychain.KeyFamilyRevocationRoot,
		Index:  keyLocIndex,
	}

	// Serialize the locator into the buffer.
	err := EKeyLocator(&encoded, &want, &scratch)
	require.NoError(t, err, "unable to encode key locator")

	// Deserialize the written bytes into a fresh locator.
	// NOTE(review): the trailing 8 is presumably the expected encoded
	// length; confirm against DKeyLocator.
	var got keychain.KeyLocator
	reader := bytes.NewReader(encoded.Bytes())

	err = DKeyLocator(reader, &got, &scratch, 8)
	require.NoError(t, err, "unable to decode key locator")

	// The round trip must reproduce the original locator exactly.
	require.Equal(t, want, got)
}