From 701eb9d4f4862b0e1672de4a3bffc217d4bd9de2 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Tue, 16 Jan 2018 19:38:35 -0800 Subject: [PATCH] contractcourt: add new briefcase.go file to house persistent arbitrator state In this commit, we add a new file: briefcase.go. The contents of this file are the ArbitratorLog. This log will be used by the internal state machine of each Channel Arbitrator to ensure that each state transition is fully reflected on-disk, to ensure that the state machine is durable and able to survive restarts. This commit also adds a new implementation of the ArbitratorLog interface backed by boltdb. --- contractcourt/briefcase.go | 1044 +++++++++++++++++++++++++++++++ contractcourt/briefcase_test.go | 802 ++++++++++++++++++++++++ 2 files changed, 1846 insertions(+) create mode 100644 contractcourt/briefcase.go create mode 100644 contractcourt/briefcase_test.go diff --git a/contractcourt/briefcase.go b/contractcourt/briefcase.go new file mode 100644 index 00000000..c4de7cf2 --- /dev/null +++ b/contractcourt/briefcase.go @@ -0,0 +1,1044 @@ +package contractcourt + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + + "github.com/boltdb/bolt" + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/lnwallet" + "github.com/roasbeef/btcd/chaincfg/chainhash" + "github.com/roasbeef/btcd/wire" +) + +// ContractResolutions is a wrapper struct around the two forms of resolutions +// we may need to carry out once a contract is closing: resolving the +// commitment output, and resolving any incoming+outgoing HTLC's still present +// in the commitment. +type ContractResolutions struct { + // CommitHash is the txid of the commitment transaction. + CommitHash chainhash.Hash + + // CommitResolution contains all data required to fully resolve a + // commitment output. + CommitResolution *lnwallet.CommitOutputResolution + + // HtlcResolutions contains all data required to fully resolve any + // incoming+outgoing HTLC's present within the commitment transaction. + HtlcResolutions lnwallet.HtlcResolutions +} + +// IsEmpty returns true if the set of resolutions is "empty". A resolution is +// empty if: our commitment output has been trimmed, and we don't have any +// incoming or outgoing HTLC's active. +func (c *ContractResolutions) IsEmpty() bool { + return c.CommitResolution == nil && + len(c.HtlcResolutions.IncomingHTLCs) == 0 && + len(c.HtlcResolutions.OutgoingHTLCs) == 0 +} + +// ArbitratorLog is the primary source of persistent storage for the +// ChannelArbitrator. The log stores the current state of the +// ChannelArbitrator's internal state machine, any items that are required to +// properly make a state transition, and any unresolved contracts. +type ArbitratorLog interface { + // TODO(roasbeef): document on interface the errors expected to be + // returned + + // CurrentState returns the current state of the ChannelArbitrator. + CurrentState() (ArbitratorState, error) + + // CommitState persists, the current state of the chain attendant. + CommitState(ArbitratorState) error + + // InsertUnresolvedContracts inserts a set of unresolved contracts into + // the log. The log will then persistently store each contract until + // they've been swapped out, or resolved. + InsertUnresolvedContracts(...ContractResolver) error + + // FetchUnresolvedContracts returns all unresolved contracts that have + // been previously written to the log. + FetchUnresolvedContracts() ([]ContractResolver, error) + + // SwapContract performs an atomic swap of the old contract for the new + // contract. This method is used when after a contract has been fully + // resolved, it produces another contract that needs to be resolved. + SwapContract(old ContractResolver, new ContractResolver) error + + // ResolveContract marks a contract as fully resolved. Once a contract + // has been fully resolved, it is deleted from persistent storage. + ResolveContract(ContractResolver) error + + // LogContractResolutions stores a complete contract resolution for the + // contract under watch. This method will be called once the + // ChannelArbitrator either force closes a channel, or detects that the + // remote party has broadcast their commitment on chain. + LogContractResolutions(*ContractResolutions) error + + // FetchContractResolutions fetches the set of previously stored + // contract resolutions from persistent storage. + FetchContractResolutions() (*ContractResolutions, error) + + // LogChainActions stores a set of chain actions which are derived from + // our set of active contracts, and the on-chain state. We'll write + // this et of cations when: we decide to go on-chain to resolve a + // contract, or we detect that the remote party has gone on-chain. + LogChainActions(ChainActionMap) error + + // FetchChainActions attempts to fetch the set of previously stored + // chain actions. We'll use this upon restart to properly advance our + // state machine forward. + FetchChainActions() (ChainActionMap, error) + + // WipeHistory is to be called ONLY once *all* contracts have been + // fully resolved, and the channel closure if finalized. This method + // will delete all on-disk state within the persistent log. + WipeHistory() error +} + +// ArbitratorState is a enum that details the current state of the +// ChannelArbitrator's state machine. +type ArbitratorState uint8 + +const ( + // StateDefault is the default state. In this state, no major actions + // need to be executed. + StateDefault ArbitratorState = 0 + + // StateBroadcastCommit is a state that indicates that the attendant + // has decided to broadcast the commitment transaction, but hasn't done + // so yet. + StateBroadcastCommit ArbitratorState = 1 + + // StateContractClose is a state that indicates the contract has + // already been "closed". At this point, we can now examine our active + // contracts, in order to create the proper resolver for each one. + StateContractClosed ArbitratorState = 2 + + // StateWaitingFullResolution is a state that indicates that the + // commitment transaction has been broadcast, and the attendant is now + // waiting for all unresolved contracts to be fully resolved. + StateWaitingFullResolution ArbitratorState = 3 + + // StateFullyResolved is the final state of the attendant. In this + // state, all related contracts have been resolved, and the attendant + // can now be garbage collected. + StateFullyResolved ArbitratorState = 4 + + // StateError is the only error state of the resolver. If we enter this + // state, then we cannot proceed with manual intervention as a state + // transition failed. + StateError ArbitratorState = 5 +) + +// String returns a human readable string describing the ArbitratorState. +func (a ArbitratorState) String() string { + switch a { + case StateDefault: + return "StateDefault" + + case StateBroadcastCommit: + return "StateBroadcastCommit" + + case StateContractClosed: + return "StateContractClosed" + + case StateWaitingFullResolution: + return "StateWaitingFullResolution" + + case StateFullyResolved: + return "StateFullyResolved" + + case StateError: + return "StateError" + + default: + return "unknown state" + } +} + +// resolverType is an enum that enumerates the various types of resolvers. When +// writing resolvers to disk, we prepend this to the raw bytes stroed. This +// allows us to properly decode the resolver into the proper type. +type resolverType uint8 + +const ( + // resolverTimeout is the type of a resolver that's tasked with + // resolving a outgoing HTLC that is very close to timing out. + resolverTimeout = 0 + + // resolverSuccess is the type of a resolver that's tasked with + // resolving an incoming HTLC that we already know the preimage of. + resolverSuccess = 1 + + // resolverOutgoingContest is the type of a resolver that's tasked with + // resolving an outgoing HTLC that hasn't yet timed out. + resolverOutgoingContest = 2 + + // resolverIncomingContest is the type of a resolver that's tasked with + // resolving an incoming HTLC that we don't yet know the preimage to. + resolverIncomingContest = 3 + + // resolverUnilateralSweep is the type of resolver that's tasked with + // sweeping out direct commitment output form the remote party's + // commitment transaction. + resolverUnilateralSweep = 4 +) + +// resolverIDLen is the size of the resolver ID key. This is 36 bytes as we get +// 32 bytes from the hash of the prev tx, and 4 bytes for the output index. +const resolverIDLen = 36 + +// resolverID is a key that uniquely identifies a resolver within a particular +// chain. For this value we use the full outpoint of the resolver. +type resolverID [resolverIDLen]byte + +// newResolverID returns a resolverID given the outpoint of a contract. +func newResolverID(op wire.OutPoint) resolverID { + var r resolverID + + copy(r[:], op.Hash[:]) + + endian.PutUint32(r[32:], op.Index) + + return r +} + +// logScope is a key that we use to scope the storage of a ChannelArbitrator +// within the global log. We use this key to create a unique bucket within the +// database and ensure that we don't have any key collisions. The log's scope +// is define as: chainHash || chanPoint, where chanPoint is the chan point of +// the original channel. +type logScope [32 + 36]byte + +// newLogScope creates a new logScope key from the passed chainhash and +// chanPoint. +func newLogScope(chain chainhash.Hash, op wire.OutPoint) (*logScope, error) { + var l logScope + b := bytes.NewBuffer(l[0:0]) + + if _, err := b.Write(chain[:]); err != nil { + return nil, err + } + if _, err := b.Write(op.Hash[:]); err != nil { + return nil, err + } + + if err := binary.Write(b, endian, op.Index); err != nil { + return nil, err + } + + return &l, nil +} + +var ( + // stateKey is the key that we use to store the current state of the + // arbitrator. + stateKey = []byte("state") + + // contractsBucketKey is the bucket within the logScope that will store + // all the active unresolved contracts. + contractsBucketKey = []byte("contractkey") + + // resolutionsKey is the key under the logScope that we'll use to store + // the full set of resolutions for a channel. + resolutionsKey = []byte("resolutions") + + // actionsBucketKey is the key under the logScope that we'll use to + // store all chain actions once they're determined. + actionsBucketKey = []byte("chain-actions") +) + +var ( + // errScopeBucketNoExist is returned when we can't find the proper + // bucket for an arbitrator's scope. + errScopeBucketNoExist = fmt.Errorf("scope bucket not found") + + // errNoContracts is returned when no contracts are found within the + // log. + errNoContracts = fmt.Errorf("no stored contracts") + + // errNoResolutions is returned when the log doesn't contain any active + // chain resolutions. + errNoResolutions = fmt.Errorf("no contract resolutions exist") + + // errNoActions is retuned when the log doesn't contain any stored + // chain actions. + errNoActions = fmt.Errorf("no chain actions exist") +) + +// boltArbitratorLog is an implementation of the ArbitratorLog interface backed +// by a bolt DB instance. +type boltArbitratorLog struct { + db *bolt.DB + + cfg ChannelArbitratorConfig + + scopeKey logScope +} + +// newBoltArbitratorLog returns a new instance of the boltArbitratorLog given +// an arbitrator config, and the items needed to create its log scope. +func newBoltArbitratorLog(db *bolt.DB, cfg ChannelArbitratorConfig, + chainHash chainhash.Hash, chanPoint wire.OutPoint) (*boltArbitratorLog, error) { + + scope, err := newLogScope(chainHash, chanPoint) + if err != nil { + return nil, err + } + + return &boltArbitratorLog{ + db: db, + cfg: cfg, + scopeKey: *scope, + }, nil +} + +// A compile time check to ensure boltArbitratorLog meets the ArbitratorLog +// interface. +var _ ArbitratorLog = (*boltArbitratorLog)(nil) + +func fetchContractReadBucket(tx *bolt.Tx, scopeKey []byte) (*bolt.Bucket, error) { + scopeBucket := tx.Bucket(scopeKey) + if scopeBucket == nil { + return nil, errScopeBucketNoExist + } + + contractBucket := scopeBucket.Bucket(contractsBucketKey) + if contractBucket == nil { + return nil, errNoContracts + } + + return contractBucket, nil +} + +func fetchContractWriteBucket(tx *bolt.Tx, scopeKey []byte) (*bolt.Bucket, error) { + scopeBucket, err := tx.CreateBucketIfNotExists(scopeKey) + if err != nil { + return nil, err + } + + contractBucket, err := scopeBucket.CreateBucketIfNotExists( + contractsBucketKey, + ) + if err != nil { + return nil, err + } + + return contractBucket, nil +} + +// writeResolver is a helper method that writes a contract resolver and stores +// it it within the passed contractBucket using its unique resolutionsKey key. +func (b *boltArbitratorLog) writeResolver(contractBucket *bolt.Bucket, + res ContractResolver) error { + + // First, we'll write to the buffer the type of this resolver. Using + // this byte, we can later properly deserialize the resolver properly. + var ( + buf bytes.Buffer + rType uint8 + ) + switch res.(type) { + case *htlcTimeoutResolver: + rType = resolverTimeout + case *htlcSuccessResolver: + rType = resolverSuccess + case *htlcOutgoingContestResolver: + rType = resolverOutgoingContest + case *htlcIncomingContestResolver: + rType = resolverIncomingContest + case *commitSweepResolver: + rType = resolverUnilateralSweep + } + if _, err := buf.Write([]byte{byte(rType)}); err != nil { + return err + } + + // With the type of the resolver written, we can then write out the raw + // bytes of the resolver itself. + if err := res.Encode(&buf); err != nil { + return err + } + + resKey := res.ResolverKey() + + return contractBucket.Put(resKey, buf.Bytes()) +} + +// CurrentState returns the current state of the ChannelArbitrator. +// +// NOTE: Part of the ContractResolver interface. +func (b *boltArbitratorLog) CurrentState() (ArbitratorState, error) { + var s ArbitratorState + err := b.db.View(func(tx *bolt.Tx) error { + scopeBucket := tx.Bucket(b.scopeKey[:]) + if scopeBucket == nil { + return errScopeBucketNoExist + } + + stateBytes := scopeBucket.Get(stateKey) + if stateBytes == nil { + return nil + } + + s = ArbitratorState(stateBytes[0]) + return nil + }) + if err != nil && err != errScopeBucketNoExist { + return s, err + } + + return s, nil +} + +// CommitState persists, the current state of the chain attendant. +// +// NOTE: Part of the ContractResolver interface. +func (b *boltArbitratorLog) CommitState(s ArbitratorState) error { + return b.db.Batch(func(tx *bolt.Tx) error { + scopeBucket, err := tx.CreateBucketIfNotExists(b.scopeKey[:]) + if err != nil { + return err + } + + return scopeBucket.Put(stateKey[:], []byte{uint8(s)}) + }) +} + +// FetchUnresolvedContracts returns all unresolved contracts that have been +// previously written to the log. +// +// NOTE: Part of the ContractResolver interface. +func (b *boltArbitratorLog) FetchUnresolvedContracts() ([]ContractResolver, error) { + resKit := ResolverKit{ + ChannelArbitratorConfig: b.cfg, + Checkpoint: b.checkpointContract, + } + var contracts []ContractResolver + err := b.db.View(func(tx *bolt.Tx) error { + contractBucket, err := fetchContractReadBucket(tx, b.scopeKey[:]) + if err != nil { + return err + } + + return contractBucket.ForEach(func(resKey, resBytes []byte) error { + if len(resKey) != resolverIDLen { + return nil + } + + var res ContractResolver + + // We'll snip off the first byte of the raw resolver + // bytes in order to extract what type of resolver + // we're about to encode. + resType := resBytes[0] + + // Then we'll create a reader using the remaining + // bytes. + resReader := bytes.NewReader(resBytes[1:]) + + switch resType { + case resolverTimeout: + timeoutRes := &htlcTimeoutResolver{} + if err := timeoutRes.Decode(resReader); err != nil { + return err + } + timeoutRes.AttachResolverKit(resKit) + + res = timeoutRes + + case resolverSuccess: + successRes := &htlcSuccessResolver{} + if err := successRes.Decode(resReader); err != nil { + return err + } + + res = successRes + + case resolverOutgoingContest: + outContestRes := &htlcOutgoingContestResolver{ + htlcTimeoutResolver: htlcTimeoutResolver{}, + } + if err := outContestRes.Decode(resReader); err != nil { + return err + } + + res = outContestRes + + case resolverIncomingContest: + inContestRes := &htlcIncomingContestResolver{ + htlcSuccessResolver: htlcSuccessResolver{}, + } + if err := inContestRes.Decode(resReader); err != nil { + return err + } + + res = inContestRes + + case resolverUnilateralSweep: + sweepRes := &commitSweepResolver{} + if err := sweepRes.Decode(resReader); err != nil { + return err + } + + res = sweepRes + + default: + return fmt.Errorf("unknown resolver type: %v", resType) + } + + resKit.Quit = make(chan struct{}) + res.AttachResolverKit(resKit) + contracts = append(contracts, res) + return nil + }) + }) + if err != nil && err != errScopeBucketNoExist && err != errNoContracts { + return nil, err + } + + return contracts, nil +} + +// InsertUnresolvedContracts inserts a set of unresolved contracts into the +// log. The log will then persistently store each contract until they've been +// swapped out, or resolved. +// +// NOTE: Part of the ContractResolver interface. +func (b *boltArbitratorLog) InsertUnresolvedContracts(resolvers ...ContractResolver) error { + return b.db.Batch(func(tx *bolt.Tx) error { + contractBucket, err := fetchContractWriteBucket(tx, b.scopeKey[:]) + if err != nil { + return err + } + + for _, resolver := range resolvers { + err = b.writeResolver(contractBucket, resolver) + if err != nil { + return err + } + } + + return nil + }) +} + +// SwapContract performs an atomic swap of the old contract for the new +// contract. This method is used when after a contract has been fully resolved, +// it produces another contract that needs to be resolved. +// +// NOTE: Part of the ContractResolver interface. +func (b *boltArbitratorLog) SwapContract(oldContract, newContract ContractResolver) error { + return b.db.Batch(func(tx *bolt.Tx) error { + contractBucket, err := fetchContractWriteBucket(tx, b.scopeKey[:]) + if err != nil { + return err + } + + oldContractkey := oldContract.ResolverKey() + if err := contractBucket.Delete(oldContractkey); err != nil { + return err + } + + return b.writeResolver(contractBucket, newContract) + }) +} + +// ResolveContract marks a contract as fully resolved. Once a contract has been +// fully resolved, it is deleted from persistent storage. +// +// NOTE: Part of the ContractResolver interface. +func (b *boltArbitratorLog) ResolveContract(res ContractResolver) error { + return b.db.Batch(func(tx *bolt.Tx) error { + contractBucket, err := fetchContractWriteBucket(tx, b.scopeKey[:]) + if err != nil { + return err + } + + resKey := res.ResolverKey() + return contractBucket.Delete(resKey) + }) +} + +// LogContractResolutions stores a set of chain actions which are derived from +// our set of active contracts, and the on-chain state. We'll write this et of +// cations when: we decide to go on-chain to resolve a contract, or we detect +// that the remote party has gone on-chain. +// +// NOTE: Part of the ContractResolver interface. +func (b *boltArbitratorLog) LogContractResolutions(c *ContractResolutions) error { + return b.db.Batch(func(tx *bolt.Tx) error { + scopeBucket, err := tx.CreateBucketIfNotExists(b.scopeKey[:]) + if err != nil { + return err + } + + var b bytes.Buffer + + if _, err := b.Write(c.CommitHash[:]); err != nil { + return err + } + + // First, we'll write out the commit output's resolution. + if c.CommitResolution == nil { + if err := binary.Write(&b, endian, false); err != nil { + return err + } + } else { + if err := binary.Write(&b, endian, true); err != nil { + return err + } + err = encodeCommitResolution(&b, c.CommitResolution) + if err != nil { + return err + } + } + + // With the output for the commitment transaction written, we + // can now write out the resolutions for the incoming and + // outgoing HTLC's. + numIncoming := uint32(len(c.HtlcResolutions.IncomingHTLCs)) + if err := binary.Write(&b, endian, numIncoming); err != nil { + return err + } + for _, htlc := range c.HtlcResolutions.IncomingHTLCs { + err := encodeIncomingResolution(&b, &htlc) + if err != nil { + return err + } + } + numOutgoing := uint32(len(c.HtlcResolutions.OutgoingHTLCs)) + if err := binary.Write(&b, endian, numOutgoing); err != nil { + return err + } + for _, htlc := range c.HtlcResolutions.OutgoingHTLCs { + err := encodeOutgoingResolution(&b, &htlc) + if err != nil { + return err + } + } + + return scopeBucket.Put(resolutionsKey, b.Bytes()) + }) +} + +// FetchContractResolutions fetches the set of previously stored contract +// resolutions from persistent storage. +// +// NOTE: Part of the ContractResolver interface. +func (b *boltArbitratorLog) FetchContractResolutions() (*ContractResolutions, error) { + c := &ContractResolutions{} + err := b.db.View(func(tx *bolt.Tx) error { + scopeBucket := tx.Bucket(b.scopeKey[:]) + if scopeBucket == nil { + return errScopeBucketNoExist + } + + resolutionBytes := scopeBucket.Get(resolutionsKey) + if resolutionBytes == nil { + return errNoResolutions + } + + resReader := bytes.NewReader(resolutionBytes) + + _, err := io.ReadFull(resReader, c.CommitHash[:]) + if err != nil { + return err + } + + // First, we'll attempt to read out the commit resolution (if + // it exists). + var haveCommitRes bool + err = binary.Read(resReader, endian, &haveCommitRes) + if err != nil { + return err + } + if haveCommitRes { + c.CommitResolution = &lnwallet.CommitOutputResolution{} + err = decodeCommitResolution( + resReader, c.CommitResolution, + ) + if err != nil { + return err + } + } + + var ( + numIncoming uint32 + numOutgoing uint32 + ) + + // Next, we'll read out he incoming and outgoing HTLC + // resolutions. + err = binary.Read(resReader, endian, &numIncoming) + if err != nil { + return err + } + c.HtlcResolutions.IncomingHTLCs = make([]lnwallet.IncomingHtlcResolution, numIncoming) + for i := uint32(0); i < numIncoming; i++ { + err := decodeIncomingResolution( + resReader, &c.HtlcResolutions.IncomingHTLCs[i], + ) + if err != nil { + return err + } + } + + err = binary.Read(resReader, endian, &numOutgoing) + if err != nil { + return err + } + c.HtlcResolutions.OutgoingHTLCs = make([]lnwallet.OutgoingHtlcResolution, numOutgoing) + for i := uint32(0); i < numOutgoing; i++ { + err := decodeOutgoingResolution( + resReader, &c.HtlcResolutions.OutgoingHTLCs[i], + ) + if err != nil { + return err + } + } + + return nil + }) + if err != nil { + return nil, err + } + + return c, err +} + +// LogChainActions stores a set of chain actions which are derived from our set +// of active contracts, and the on-chain state. We'll write this et of cations +// when: we decide to go on-chain to resolve a contract, or we detect that the +// remote party has gone on-chain. +// +// NOTE: Part of the ContractResolver interface. +func (b *boltArbitratorLog) LogChainActions(actions ChainActionMap) error { + return b.db.Batch(func(tx *bolt.Tx) error { + scopeBucket, err := tx.CreateBucketIfNotExists(b.scopeKey[:]) + if err != nil { + return err + } + + actionsBucket, err := scopeBucket.CreateBucketIfNotExists( + actionsBucketKey, + ) + if err != nil { + return err + } + + for chainAction, htlcs := range actions { + var htlcBuf bytes.Buffer + err := channeldb.SerializeHtlcs(&htlcBuf, htlcs...) + if err != nil { + return err + } + + actionKey := []byte{byte(chainAction)} + err = actionsBucket.Put(actionKey, htlcBuf.Bytes()) + if err != nil { + return err + } + } + + return nil + }) +} + +// FetchChainActions attempts to fetch the set of previously stored chain +// actions. We'll use this upon restart to properly advance our state machine +// forward. +// +// NOTE: Part of the ContractResolver interface. +func (b *boltArbitratorLog) FetchChainActions() (ChainActionMap, error) { + actionsMap := make(ChainActionMap) + + err := b.db.View(func(tx *bolt.Tx) error { + scopeBucket := tx.Bucket(b.scopeKey[:]) + if scopeBucket == nil { + return errScopeBucketNoExist + } + + actionsBucket := scopeBucket.Bucket(actionsBucketKey) + if actionsBucket == nil { + return errNoActions + } + + return actionsBucket.ForEach(func(action, htlcBytes []byte) error { + if htlcBytes == nil { + return nil + } + + chainAction := ChainAction(action[0]) + + htlcReader := bytes.NewReader(htlcBytes) + htlcs, err := channeldb.DeserializeHtlcs(htlcReader) + if err != nil { + return err + } + + actionsMap[chainAction] = htlcs + + return nil + }) + }) + if err != nil { + return nil, err + } + + return actionsMap, nil +} + +// WipeHistory is to be called ONLY once *all* contracts have been fully +// resolved, and the channel closure if finalized. This method will delete all +// on-disk state within the persistent log. +// +// NOTE: Part of the ContractResolver interface. +func (b *boltArbitratorLog) WipeHistory() error { + return b.db.Update(func(tx *bolt.Tx) error { + scopeBucket, err := tx.CreateBucketIfNotExists(b.scopeKey[:]) + if err != nil { + return err + } + + // Once we have the main top-level bucket, we'll delete the key + // that stores the state of the arbitrator. + if err := scopeBucket.Delete(stateKey[:]); err != nil { + return err + } + + // Next, we'll delete any lingering contract state within the + // contracts bucket, and the bucket itself once we're done + // clearing it out. + contractBucket, err := scopeBucket.CreateBucketIfNotExists( + contractsBucketKey, + ) + if err != nil { + return err + } + if err := contractBucket.ForEach(func(resKey, _ []byte) error { + return contractBucket.Delete(resKey) + }); err != nil { + return err + } + if err := scopeBucket.DeleteBucket(contractsBucketKey); err != nil { + fmt.Println("nah") + return err + } + + // Next, we'll delete storage of any lingering contract + // resolutions. + if err := scopeBucket.Delete(resolutionsKey); err != nil { + return err + } + + // Before we delta the enclosing bucket itself, we'll delta any + // chain actions that are still stored. + actionsBucket, err := scopeBucket.CreateBucketIfNotExists( + actionsBucketKey, + ) + if err != nil { + return err + } + if err := actionsBucket.ForEach(func(resKey, _ []byte) error { + return actionsBucket.Delete(resKey) + }); err != nil { + return err + } + if err := scopeBucket.DeleteBucket(actionsBucketKey); err != nil { + return err + } + + // Finally, we'll delete the enclosing bucket itself. + return tx.DeleteBucket(b.scopeKey[:]) + }) +} + +// checkpointContract is a private method that will be fed into +// ContractResolver instances to checkpoint their state once they reach +// milestones during contract resolution. +func (b *boltArbitratorLog) checkpointContract(c ContractResolver) error { + return b.db.Batch(func(tx *bolt.Tx) error { + contractBucket, err := fetchContractWriteBucket(tx, b.scopeKey[:]) + if err != nil { + return err + } + + return b.writeResolver(contractBucket, c) + }) +} + +func encodeIncomingResolution(w io.Writer, i *lnwallet.IncomingHtlcResolution) error { + if _, err := w.Write(i.Preimage[:]); err != nil { + return err + } + + if i.SignedSuccessTx == nil { + if err := binary.Write(w, endian, false); err != nil { + return err + } + } else { + if err := binary.Write(w, endian, true); err != nil { + return err + } + + if err := i.SignedSuccessTx.Serialize(w); err != nil { + return err + } + } + + if err := binary.Write(w, endian, i.CsvDelay); err != nil { + return err + } + if _, err := w.Write(i.ClaimOutpoint.Hash[:]); err != nil { + return err + } + if err := binary.Write(w, endian, i.ClaimOutpoint.Index); err != nil { + return err + } + err := lnwallet.WriteSignDescriptor(w, &i.SweepSignDesc) + if err != nil { + return err + } + + return nil +} + +func decodeIncomingResolution(r io.Reader, h *lnwallet.IncomingHtlcResolution) error { + if _, err := io.ReadFull(r, h.Preimage[:]); err != nil { + return err + } + + var txPresent bool + if err := binary.Read(r, endian, &txPresent); err != nil { + return err + } + if txPresent { + h.SignedSuccessTx = &wire.MsgTx{} + if err := h.SignedSuccessTx.Deserialize(r); err != nil { + return err + } + } + + err := binary.Read(r, endian, &h.CsvDelay) + if err != nil { + return err + } + _, err = io.ReadFull(r, h.ClaimOutpoint.Hash[:]) + if err != nil { + return err + } + err = binary.Read(r, endian, &h.ClaimOutpoint.Index) + if err != nil { + return err + } + + return lnwallet.ReadSignDescriptor(r, &h.SweepSignDesc) +} + +func encodeOutgoingResolution(w io.Writer, o *lnwallet.OutgoingHtlcResolution) error { + if err := binary.Write(w, endian, o.Expiry); err != nil { + return nil + } + + if o.SignedTimeoutTx == nil { + if err := binary.Write(w, endian, false); err != nil { + return err + } + } else { + if err := binary.Write(w, endian, true); err != nil { + return err + } + + if err := o.SignedTimeoutTx.Serialize(w); err != nil { + return err + } + } + + if err := binary.Write(w, endian, o.CsvDelay); err != nil { + return nil + } + if _, err := w.Write(o.ClaimOutpoint.Hash[:]); err != nil { + return err + } + if err := binary.Write(w, endian, o.ClaimOutpoint.Index); err != nil { + return err + } + + return lnwallet.WriteSignDescriptor(w, &o.SweepSignDesc) +} + +func decodeOutgoingResolution(r io.Reader, o *lnwallet.OutgoingHtlcResolution) error { + err := binary.Read(r, endian, &o.Expiry) + if err != nil { + return err + } + + var txPresent bool + if err := binary.Read(r, endian, &txPresent); err != nil { + return err + } + if txPresent { + o.SignedTimeoutTx = &wire.MsgTx{} + if err := o.SignedTimeoutTx.Deserialize(r); err != nil { + return err + } + } + + err = binary.Read(r, endian, &o.CsvDelay) + if err != nil { + return err + } + _, err = io.ReadFull(r, o.ClaimOutpoint.Hash[:]) + if err != nil { + return err + } + err = binary.Read(r, endian, &o.ClaimOutpoint.Index) + if err != nil { + return err + } + + return lnwallet.ReadSignDescriptor(r, &o.SweepSignDesc) +} + +func encodeCommitResolution(w io.Writer, + c *lnwallet.CommitOutputResolution) error { + + if _, err := w.Write(c.SelfOutPoint.Hash[:]); err != nil { + return err + } + err := binary.Write(w, endian, c.SelfOutPoint.Index) + if err != nil { + return err + } + + err = lnwallet.WriteSignDescriptor(w, &c.SelfOutputSignDesc) + if err != nil { + return err + } + + return binary.Write(w, endian, c.MaturityDelay) +} + +func decodeCommitResolution(r io.Reader, + c *lnwallet.CommitOutputResolution) error { + + _, err := io.ReadFull(r, c.SelfOutPoint.Hash[:]) + if err != nil { + return err + } + err = binary.Read(r, endian, &c.SelfOutPoint.Index) + if err != nil { + return err + } + + err = lnwallet.ReadSignDescriptor(r, &c.SelfOutputSignDesc) + if err != nil { + return err + } + + return binary.Read(r, endian, &c.MaturityDelay) +} diff --git a/contractcourt/briefcase_test.go b/contractcourt/briefcase_test.go new file mode 100644 index 00000000..c8ae24b8 --- /dev/null +++ b/contractcourt/briefcase_test.go @@ -0,0 +1,802 @@ +package contractcourt + +import ( + "crypto/rand" + "io/ioutil" + "os" + "reflect" + "testing" + "time" + + prand "math/rand" + + "github.com/boltdb/bolt" + "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/lnwallet" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/roasbeef/btcd/btcec" + "github.com/roasbeef/btcd/chaincfg/chainhash" + "github.com/roasbeef/btcd/txscript" + "github.com/roasbeef/btcd/wire" +) + +var ( + testChainHash = [chainhash.HashSize]byte{ + 0x51, 0xb6, 0x37, 0xd8, 0xfc, 0xd2, 0xc6, 0xda, + 0x48, 0x59, 0xe6, 0x96, 0x31, 0x13, 0xa1, 0x17, + 0x2d, 0xe7, 0x93, 0xe4, + } + + testChanPoint1 = wire.OutPoint{ + Hash: chainhash.Hash{ + 0x51, 0xb6, 0x37, 0xd8, 0xfc, 0xd2, 0xc6, 0xda, + 0x48, 0x59, 0xe6, 0x96, 0x31, 0x13, 0xa1, 0x17, + 0x2d, 0xe7, 0x93, 0xe4, + }, + Index: 1, + } + + testChanPoint2 = wire.OutPoint{ + Hash: chainhash.Hash{ + 0x48, 0x59, 0xe6, 0x96, 0x31, 0x13, 0xa1, 0x17, + 0x51, 0xb6, 0x37, 0xd8, 0xfc, 0xd2, 0xc6, 0xda, + 0x2d, 0xe7, 0x93, 0xe4, + }, + Index: 2, + } + + testPreimage = [32]byte{ + 0x52, 0xb6, 0x37, 0xd8, 0xfc, 0xd2, 0xc6, 0xda, + 0x48, 0x59, 0xe6, 0x96, 0x31, 0x13, 0xa1, 0x17, + 0x2d, 0xe7, 0x93, 0xe4, + } + + key1 = []byte{ + 0x04, 0x11, 0xdb, 0x93, 0xe1, 0xdc, 0xdb, 0x8a, + 0x01, 0x6b, 0x49, 0x84, 0x0f, 0x8c, 0x53, 0xbc, 0x1e, + 0xb6, 0x8a, 0x38, 0x2e, 0x97, 0xb1, 0x48, 0x2e, 0xca, + 0xd7, 0xb1, 0x48, 0xa6, 0x90, 0x9a, 0x5c, 0xb2, 0xe0, + 0xea, 0xdd, 0xfb, 0x84, 0xcc, 0xf9, 0x74, 0x44, 0x64, + 0xf8, 0x2e, 0x16, 0x0b, 0xfa, 0x9b, 0x8b, 0x64, 0xf9, + 0xd4, 0xc0, 0x3f, 0x99, 0x9b, 0x86, 0x43, 0xf6, 0x56, + 0xb4, 0x12, 0xa3, + } + + testSignDesc = lnwallet.SignDescriptor{ + SingleTweak: []byte{ + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, + 0x02, 0x02, 0x02, 0x02, 0x02, + }, + WitnessScript: []byte{ + 0x00, 0x14, 0xee, 0x91, 0x41, 0x7e, 0x85, 0x6c, 0xde, + 0x10, 0xa2, 0x91, 0x1e, 0xdc, 0xbd, 0xbd, 0x69, 0xe2, + 0xef, 0xb5, 0x71, 0x48, + }, + Output: &wire.TxOut{ + Value: 5000000000, + PkScript: []byte{ + 0x41, // OP_DATA_65 + 0x04, 0xd6, 0x4b, 0xdf, 0xd0, 0x9e, 0xb1, 0xc5, + 0xfe, 0x29, 0x5a, 0xbd, 0xeb, 0x1d, 0xca, 0x42, + 0x81, 0xbe, 0x98, 0x8e, 0x2d, 0xa0, 0xb6, 0xc1, + 0xc6, 0xa5, 0x9d, 0xc2, 0x26, 0xc2, 0x86, 0x24, + 0xe1, 0x81, 0x75, 0xe8, 0x51, 0xc9, 0x6b, 0x97, + 0x3d, 0x81, 0xb0, 0x1c, 0xc3, 0x1f, 0x04, 0x78, + 0x34, 0xbc, 0x06, 0xd6, 0xd6, 0xed, 0xf6, 0x20, + 0xd1, 0x84, 0x24, 0x1a, 0x6a, 0xed, 0x8b, 0x63, + 0xa6, // 65-byte signature + 0xac, // OP_CHECKSIG + }, + }, + HashType: txscript.SigHashAll, + } +) + +func makeTestDB() (*bolt.DB, func(), error) { + // First, create a temporary directory to be used for the duration of + // this test. + tempDirName, err := ioutil.TempDir("", "arblog") + if err != nil { + return nil, nil, err + } + + db, err := bolt.Open(tempDirName+"/test.db", 0600, nil) + if err != nil { + return nil, nil, err + } + + cleanUp := func() { + db.Close() + os.RemoveAll(tempDirName) + } + + return db, cleanUp, nil +} + +func newTestBoltArbLog(chainhash chainhash.Hash, + op wire.OutPoint) (ArbitratorLog, func(), error) { + + testDB, cleanUp, err := makeTestDB() + if err != nil { + return nil, nil, err + } + + testArbCfg := ChannelArbitratorConfig{} + testLog, err := newBoltArbitratorLog(testDB, testArbCfg, chainhash, op) + if err != nil { + return nil, nil, err + } + + return testLog, cleanUp, err +} + +func randOutPoint() wire.OutPoint { + var op wire.OutPoint + rand.Read(op.Hash[:]) + op.Index = prand.Uint32() + + return op +} + +func assertResolversEqual(t *testing.T, originalResolver ContractResolver, + diskResolver ContractResolver) { + + assertTimeoutResEqual := func(ogRes, diskRes *htlcTimeoutResolver) { + if !reflect.DeepEqual(ogRes.htlcResolution, diskRes.htlcResolution) { + t.Fatalf("resolution mismatch: expected %#v, got %v#", + ogRes.htlcResolution, diskRes.htlcResolution) + } + if ogRes.outputIncubating != diskRes.outputIncubating { + t.Fatalf("expected %v, got %v", + ogRes.outputIncubating, diskRes.outputIncubating) + } + if ogRes.resolved != diskRes.resolved { + t.Fatalf("expected %v, got %v", ogRes.resolved, + diskRes.resolved) + } + if ogRes.broadcastHeight != diskRes.broadcastHeight { + t.Fatalf("expected %v, got %v", + ogRes.broadcastHeight, diskRes.broadcastHeight) + } + if ogRes.htlcIndex != diskRes.htlcIndex { + t.Fatalf("expected %v, got %v", ogRes.htlcIndex, + diskRes.htlcIndex) + } + } + + assertSuccessResEqual := func(ogRes, diskRes *htlcSuccessResolver) { + if !reflect.DeepEqual(ogRes.htlcResolution, diskRes.htlcResolution) { + t.Fatalf("resolution mismatch: expected %#v, got %v#", + ogRes.htlcResolution, diskRes.htlcResolution) + } + if ogRes.outputIncubating != diskRes.outputIncubating { + t.Fatalf("expected %v, got %v", + ogRes.outputIncubating, diskRes.outputIncubating) + } + if ogRes.resolved != diskRes.resolved { + t.Fatalf("expected %v, got %v", ogRes.resolved, + diskRes.resolved) + } + if ogRes.broadcastHeight != diskRes.broadcastHeight { + t.Fatalf("expected %v, got %v", + ogRes.broadcastHeight, diskRes.broadcastHeight) + } + if ogRes.payHash != diskRes.payHash { + t.Fatalf("expected %v, got %v", ogRes.payHash, + diskRes.payHash) + } + } + + switch ogRes := originalResolver.(type) { + case *htlcTimeoutResolver: + diskRes := diskResolver.(*htlcTimeoutResolver) + assertTimeoutResEqual(ogRes, diskRes) + + case *htlcSuccessResolver: + diskRes := diskResolver.(*htlcSuccessResolver) + assertSuccessResEqual(ogRes, diskRes) + + case *htlcOutgoingContestResolver: + diskRes := diskResolver.(*htlcOutgoingContestResolver) + assertTimeoutResEqual( + &ogRes.htlcTimeoutResolver, &diskRes.htlcTimeoutResolver, + ) + + case *htlcIncomingContestResolver: + diskRes := diskResolver.(*htlcIncomingContestResolver) + assertSuccessResEqual( + &ogRes.htlcSuccessResolver, &diskRes.htlcSuccessResolver, + ) + + if ogRes.htlcExpiry != diskRes.htlcExpiry { + t.Fatalf("expected %v, got %v", ogRes.htlcExpiry, + diskRes.htlcExpiry) + } + + case *commitSweepResolver: + diskRes := diskResolver.(*commitSweepResolver) + if !reflect.DeepEqual(ogRes.commitResolution, diskRes.commitResolution) { + t.Fatalf("resolution mismatch: expected %v, got %v", + ogRes.commitResolution, diskRes.commitResolution) + } + if ogRes.resolved != diskRes.resolved { + t.Fatalf("expected %v, got %v", ogRes.resolved, + diskRes.resolved) + } + if ogRes.broadcastHeight != diskRes.broadcastHeight { + t.Fatalf("expected %v, got %v", + ogRes.broadcastHeight, diskRes.broadcastHeight) + } + if ogRes.chanPoint != diskRes.chanPoint { + t.Fatalf("expected %v, got %v", ogRes.chanPoint, + diskRes.chanPoint) + } + } +} + +// TestContractInsertionRetrieval tests that were able to insert a set of +// unresolved contracts into the log, and retrieve the same set properly. +func TestContractInsertionRetrieval(t *testing.T) { + t.Parallel() + + // First, we'll create a test instance of the ArbitratorLog + // implementation backed by boltdb. + testLog, cleanUp, err := newTestBoltArbLog( + testChainHash, testChanPoint1, + ) + if err != nil { + t.Fatalf("unable to create test log: %v", err) + } + defer cleanUp() + + // The log created, we'll create a series of resolvers, each properly + // implementing the ContractResolver interface. + timeoutResolver := htlcTimeoutResolver{ + htlcResolution: lnwallet.OutgoingHtlcResolution{ + Expiry: 99, + SignedTimeoutTx: nil, + CsvDelay: 99, + ClaimOutpoint: randOutPoint(), + SweepSignDesc: testSignDesc, + }, + outputIncubating: true, + resolved: true, + broadcastHeight: 102, + htlcIndex: 12, + } + successResolver := htlcSuccessResolver{ + htlcResolution: lnwallet.IncomingHtlcResolution{ + Preimage: testPreimage, + SignedSuccessTx: nil, + CsvDelay: 900, + ClaimOutpoint: randOutPoint(), + SweepSignDesc: testSignDesc, + }, + outputIncubating: true, + resolved: true, + broadcastHeight: 109, + payHash: testPreimage, + sweepTx: nil, + } + resolvers := []ContractResolver{ + &timeoutResolver, + &successResolver, + &commitSweepResolver{ + commitResolution: lnwallet.CommitOutputResolution{ + SelfOutPoint: testChanPoint2, + SelfOutputSignDesc: testSignDesc, + MaturityDelay: 99, + }, + resolved: false, + broadcastHeight: 109, + chanPoint: testChanPoint1, + sweepTx: nil, + }, + } + + // All resolvers require a unique ResolverKey() output. To achieve this + // for the composite resolvers, we'll mutate the underlying resolver + // with a new outpoint. + contestTimeout := timeoutResolver + contestTimeout.htlcResolution.ClaimOutpoint = randOutPoint() + resolvers = append(resolvers, &htlcOutgoingContestResolver{ + htlcTimeoutResolver: contestTimeout, + }) + contestSuccess := successResolver + contestSuccess.htlcResolution.ClaimOutpoint = randOutPoint() + resolvers = append(resolvers, &htlcIncomingContestResolver{ + htlcExpiry: 100, + htlcSuccessResolver: contestSuccess, + }) + + // For quick lookup during the test, we'll create this map which allow + // us to lookup a resolver according to its unique resolver key. + resolverMap := make(map[string]ContractResolver) + resolverMap[string(timeoutResolver.ResolverKey())] = resolvers[0] + resolverMap[string(successResolver.ResolverKey())] = resolvers[1] + resolverMap[string(resolvers[2].ResolverKey())] = resolvers[2] + resolverMap[string(resolvers[3].ResolverKey())] = resolvers[3] + resolverMap[string(resolvers[4].ResolverKey())] = resolvers[4] + + // Now, we'll insert the resolver into the log. + if err := testLog.InsertUnresolvedContracts(resolvers...); err != nil { + t.Fatalf("unable to insert resolvers: %v", err) + } + + // With the resolvers inserted, we'll now attempt to retrieve them from + // the database, so we can compare them to the versions we created + // above. + diskResolvers, err := testLog.FetchUnresolvedContracts() + if err != nil { + t.Fatalf("unable to retrieve resolvers: %v", err) + } + + if len(diskResolvers) != len(resolvers) { + t.Fatalf("expected %v got resolvers, instead got %v: %#v", + len(resolvers), len(diskResolvers), + diskResolvers) + } + + // Now we'll run through each of the resolvers, and ensure that it maps + // to a resolver perfectly that we inserted previously. + for _, diskResolver := range diskResolvers { + resKey := string(diskResolver.ResolverKey()) + originalResolver, ok := resolverMap[resKey] + if !ok { + t.Fatalf("unable to find resolver match for %T: %v", + diskResolver, resKey) + } + + assertResolversEqual(t, originalResolver, diskResolver) + } + + // We'll now delete the state, then attempt to retrieve the set of + // resolvers, no resolvers should be found. + if err := testLog.WipeHistory(); err != nil { + t.Fatalf("unable to wipe log: %v", err) + } + diskResolvers, err = testLog.FetchUnresolvedContracts() + if len(diskResolvers) != 0 { + t.Fatalf("no resolvers should be found, instead %v were", + len(diskResolvers)) + } +} + +// TestContractResolution tests that once we mark a contract as resolved, it's +// properly removed from the database. +func TestContractResolution(t *testing.T) { + t.Parallel() + + // First, we'll create a test instance of the ArbitratorLog + // implementation backed by boltdb. + testLog, cleanUp, err := newTestBoltArbLog( + testChainHash, testChanPoint1, + ) + if err != nil { + t.Fatalf("unable to create test log: %v", err) + } + defer cleanUp() + + // We'll now create a timeout resolver that we'll be using for the + // duration of this test. + timeoutResolver := &htlcTimeoutResolver{ + htlcResolution: lnwallet.OutgoingHtlcResolution{ + Expiry: 991, + SignedTimeoutTx: nil, + CsvDelay: 992, + ClaimOutpoint: randOutPoint(), + SweepSignDesc: testSignDesc, + }, + outputIncubating: true, + resolved: true, + broadcastHeight: 192, + htlcIndex: 9912, + } + + // First, we'll insert the resolver into the database and ensure that + // we get the same resolver out the other side. + err = testLog.InsertUnresolvedContracts(timeoutResolver) + if err != nil { + t.Fatalf("unable to insert contract into db: %v", err) + } + dbContracts, err := testLog.FetchUnresolvedContracts() + if err != nil { + t.Fatalf("unable to fetch contracts from db: %v", err) + } + assertResolversEqual(t, timeoutResolver, dbContracts[0]) + + // Now, we'll mark the contract as resolved within the database. + if err := testLog.ResolveContract(timeoutResolver); err != nil { + t.Fatalf("unable to resolve contract: %v", err) + } + + // At this point, no contracts should exist within the log. + dbContracts, err = testLog.FetchUnresolvedContracts() + if err != nil { + t.Fatalf("unable to fetch contracts from db: %v", err) + } + if len(dbContracts) != 0 { + t.Fatalf("no contract should be from in the db, instead %v "+ + "were", len(dbContracts)) + } +} + +// TestContractSwapping ensures that callers are able to atomically swap to +// distinct contracts for one another. +func TestContractSwapping(t *testing.T) { + t.Parallel() + + // First, we'll create a test instance of the ArbitratorLog + // implementation backed by boltdb. + testLog, cleanUp, err := newTestBoltArbLog( + testChainHash, testChanPoint1, + ) + if err != nil { + t.Fatalf("unable to create test log: %v", err) + } + defer cleanUp() + + // We'll create two resolvers, a regular timeout resolver, and the + // contest resolver that eventually turns into the timeout resolver. + timeoutResolver := htlcTimeoutResolver{ + htlcResolution: lnwallet.OutgoingHtlcResolution{ + Expiry: 99, + SignedTimeoutTx: nil, + CsvDelay: 99, + ClaimOutpoint: randOutPoint(), + SweepSignDesc: testSignDesc, + }, + outputIncubating: true, + resolved: true, + broadcastHeight: 102, + htlcIndex: 12, + } + contestResolver := &htlcOutgoingContestResolver{ + htlcTimeoutResolver: timeoutResolver, + } + + // We'll first insert the contest resolver into the log. + err = testLog.InsertUnresolvedContracts(contestResolver) + if err != nil { + t.Fatalf("unable to insert contract into db: %v", err) + } + + // With the resolver inserted, we'll now attempt to atomically swap it + // for its underlying timeout resolver. + err = testLog.SwapContract(contestResolver, &timeoutResolver) + if err != nil { + t.Fatalf("unable to swap contracts: %v", err) + } + + // At this point, there should now only be a single contract in the + // database. + dbContracts, err := testLog.FetchUnresolvedContracts() + if err != nil { + t.Fatalf("unable to fetch contracts from db: %v", err) + } + if len(dbContracts) != 1 { + t.Fatalf("one contract should be from in the db, instead %v "+ + "were", len(dbContracts)) + } + + // That single contract should be the underlying timeout resolver. + assertResolversEqual(t, &timeoutResolver, dbContracts[0]) +} + +// TestContractResolutionsStorage tests that we're able to properly store and +// retrieve contract resolutions written to disk. +func TestContractResolutionsStorage(t *testing.T) { + t.Parallel() + + // First, we'll create a test instance of the ArbitratorLog + // implementation backed by boltdb. + testLog, cleanUp, err := newTestBoltArbLog( + testChainHash, testChanPoint1, + ) + if err != nil { + t.Fatalf("unable to create test log: %v", err) + } + defer cleanUp() + + // With the test log created, we'll now craft a contact resolution that + // will be using for the duration of this test. + res := ContractResolutions{ + CommitHash: testChainHash, + CommitResolution: &lnwallet.CommitOutputResolution{ + SelfOutPoint: testChanPoint2, + SelfOutputSignDesc: testSignDesc, + MaturityDelay: 101, + }, + HtlcResolutions: lnwallet.HtlcResolutions{ + IncomingHTLCs: []lnwallet.IncomingHtlcResolution{ + { + Preimage: testPreimage, + SignedSuccessTx: nil, + CsvDelay: 900, + ClaimOutpoint: randOutPoint(), + SweepSignDesc: testSignDesc, + }, + }, + OutgoingHTLCs: []lnwallet.OutgoingHtlcResolution{ + { + Expiry: 103, + SignedTimeoutTx: nil, + CsvDelay: 923923, + ClaimOutpoint: randOutPoint(), + SweepSignDesc: testSignDesc, + }, + }, + }, + } + + // Insert the resolution into the database, then immediately retrieve + // them so we can compare equality against the original version. + if err := testLog.LogContractResolutions(&res); err != nil { + t.Fatalf("unable to insert resolutions into db: %v", err) + } + diskRes, err := testLog.FetchContractResolutions() + if err != nil { + t.Fatalf("unable to read resolution from db: %v", err) + } + + if !reflect.DeepEqual(&res, diskRes) { + t.Fatalf("resolution mismatch: expected %#v\n, got %#v", + &res, diskRes) + } + + // We'll now delete the state, then attempt to retrieve the set of + // resolvers, no resolutions should be found. + if err := testLog.WipeHistory(); err != nil { + t.Fatalf("unable to wipe log: %v", err) + } + _, err = testLog.FetchContractResolutions() + if err != errScopeBucketNoExist { + t.Fatalf("unexpected error: %v", err) + } +} + +// TestChainActionStorage tests that were able to properly store a set of chain +// actions, and then retrieve the same set of chain actions from disk. +func TestChainActionStorage(t *testing.T) { + t.Parallel() + + // First, we'll create a test instance of the ArbitratorLog + // implementation backed by boltdb. + testLog, cleanUp, err := newTestBoltArbLog( + testChainHash, testChanPoint2, + ) + if err != nil { + t.Fatalf("unable to create test log: %v", err) + } + defer cleanUp() + + chainActions := ChainActionMap{ + NoAction: []channeldb.HTLC{ + { + RHash: testPreimage, + Amt: lnwire.MilliSatoshi(prand.Uint64()), + RefundTimeout: prand.Uint32(), + OutputIndex: int32(prand.Uint32()), + Incoming: true, + HtlcIndex: prand.Uint64(), + LogIndex: prand.Uint64(), + OnionBlob: make([]byte, 0), + Signature: make([]byte, 0), + }, + }, + HtlcTimeoutAction: []channeldb.HTLC{ + { + RHash: testPreimage, + Amt: lnwire.MilliSatoshi(prand.Uint64()), + RefundTimeout: prand.Uint32(), + OutputIndex: int32(prand.Uint32()), + Incoming: true, + HtlcIndex: prand.Uint64(), + LogIndex: prand.Uint64(), + OnionBlob: make([]byte, 0), + Signature: make([]byte, 0), + }, + }, + HtlcClaimAction: []channeldb.HTLC{ + { + RHash: testPreimage, + Amt: lnwire.MilliSatoshi(prand.Uint64()), + RefundTimeout: prand.Uint32(), + OutputIndex: int32(prand.Uint32()), + Incoming: true, + HtlcIndex: prand.Uint64(), + LogIndex: prand.Uint64(), + OnionBlob: make([]byte, 0), + Signature: make([]byte, 0), + }, + }, + HtlcFailNowAction: []channeldb.HTLC{ + { + RHash: testPreimage, + Amt: lnwire.MilliSatoshi(prand.Uint64()), + RefundTimeout: prand.Uint32(), + OutputIndex: int32(prand.Uint32()), + Incoming: true, + HtlcIndex: prand.Uint64(), + LogIndex: prand.Uint64(), + OnionBlob: make([]byte, 0), + Signature: make([]byte, 0), + }, + }, + HtlcOutgoingWatchAction: []channeldb.HTLC{ + { + RHash: testPreimage, + Amt: lnwire.MilliSatoshi(prand.Uint64()), + RefundTimeout: prand.Uint32(), + OutputIndex: int32(prand.Uint32()), + Incoming: true, + HtlcIndex: prand.Uint64(), + LogIndex: prand.Uint64(), + OnionBlob: make([]byte, 0), + Signature: make([]byte, 0), + }, + }, + HtlcIncomingWatchAction: []channeldb.HTLC{ + { + RHash: testPreimage, + Amt: lnwire.MilliSatoshi(prand.Uint64()), + RefundTimeout: prand.Uint32(), + OutputIndex: int32(prand.Uint32()), + Incoming: true, + HtlcIndex: prand.Uint64(), + LogIndex: prand.Uint64(), + OnionBlob: make([]byte, 0), + Signature: make([]byte, 0), + }, + }, + } + + // With our set of test chain actions constructed, we'll now insert + // them into the database, retrieve them, then assert equality with the + // set of chain actions create above. + if err := testLog.LogChainActions(chainActions); err != nil { + t.Fatalf("unable to write chain actions: %v", err) + } + diskActions, err := testLog.FetchChainActions() + if err != nil { + t.Fatalf("unable to read chain actions: %v", err) + } + + for k, contracts := range chainActions { + diskContracts := diskActions[k] + if !reflect.DeepEqual(contracts, diskContracts) { + t.Fatalf("chain action mismatch: expected %v, got %v", + spew.Sdump(contracts), spew.Sdump(diskContracts)) + } + } + + // We'll now delete the state, then attempt to retrieve the set of + // chain actions, no resolutions should be found. + if err := testLog.WipeHistory(); err != nil { + t.Fatalf("unable to wipe log: %v", err) + } + actions, err := testLog.FetchChainActions() + if len(actions) != 0 { + t.Fatalf("expected no chain actions, instead found: %v", + len(actions)) + } +} + +// TestStateMutation tests that we're able to properly mutate the state of the +// log, then retrieve that same mutated state from disk. +func TestStateMutation(t *testing.T) { + t.Parallel() + + testLog, cleanUp, err := newTestBoltArbLog( + testChainHash, testChanPoint1, + ) + if err != nil { + t.Fatalf("unable to create test log: %v", err) + } + defer cleanUp() + + // The default state of an arbitrator should be StateDefault. + arbState, err := testLog.CurrentState() + if err != nil { + t.Fatalf("unable to read arb state: %v", err) + } + if arbState != StateDefault { + t.Fatalf("state mismatch: expected %v, got %v", StateDefault, + arbState) + } + + // We should now be able to mutate the state to an arbitrary one of our + // choosing, then read that same state back from disk. + if err := testLog.CommitState(StateFullyResolved); err != nil { + t.Fatalf("unable to write state: %v", err) + } + arbState, err = testLog.CurrentState() + if err != nil { + t.Fatalf("unable to read arb state: %v", err) + } + if arbState != StateFullyResolved { + t.Fatalf("state mismatch: expected %v, got %v", StateFullyResolved, + arbState) + } + + // Next, we'll wipe our state and ensure that if we try to query for + // the current state, we get the proper error. + err = testLog.WipeHistory() + if err != nil { + t.Fatalf("unable to wipe history: %v", err) + } + + // If we try to query for the state again, we should get the default + // state again. + arbState, err = testLog.CurrentState() + if arbState != StateDefault { + t.Fatalf("state mismatch: expected %v, got %v", StateDefault, + arbState) + } +} + +// TestScopeIsolation tests the two distinct ArbitratorLog instances with two +// distinct scopes, don't over write the state of one another. +func TestScopeIsolation(t *testing.T) { + t.Parallel() + + // We'll create two distinct test logs. Each log will have a unique + // scope key, and therefore should be isolated from the other on disk. + testLog1, cleanUp1, err := newTestBoltArbLog( + testChainHash, testChanPoint1, + ) + if err != nil { + t.Fatalf("unable to create test log: %v", err) + } + defer cleanUp1() + + testLog2, cleanUp2, err := newTestBoltArbLog( + testChainHash, testChanPoint2, + ) + if err != nil { + t.Fatalf("unable to create test log: %v", err) + } + defer cleanUp2() + + // We'll now update the current state of both the logs to a unique + // state. + if err := testLog1.CommitState(StateWaitingFullResolution); err != nil { + t.Fatalf("unable to write state: %v", err) + } + if err := testLog2.CommitState(StateContractClosed); err != nil { + t.Fatalf("unable to write state: %v", err) + } + + // Querying each log, the states should be the prior one we set, and be + // disjoint. + log1State, err := testLog1.CurrentState() + if err != nil { + t.Fatalf("unable to read arb state: %v", err) + } + log2State, err := testLog2.CurrentState() + if err != nil { + t.Fatalf("unable to read arb state: %v", err) + } + + if log1State == log2State { + t.Fatalf("log states are the same: %v", log1State) + } + + if log1State != StateWaitingFullResolution { + t.Fatalf("state mismatch: expected %v, got %v", + StateWaitingFullResolution, log1State) + } + if log2State != StateContractClosed { + t.Fatalf("state mismatch: expected %v, got %v", + StateContractClosed, log2State) + } +} + +func init() { + testSignDesc.PubKey, _ = btcec.ParsePubKey(key1, btcec.S256()) + + prand.Seed(time.Now().Unix()) +}