Merge pull request #3665 from joostjager/resolver-constructors

cnct: add resolver constructors
This commit is contained in:
Joost Jager 2019-11-12 14:10:39 +01:00 committed by GitHub
commit 76c2b2cea2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 266 additions and 234 deletions

@ -440,7 +440,7 @@ func (b *boltArbitratorLog) CommitState(s ArbitratorState) error {
// //
// NOTE: Part of the ContractResolver interface. // NOTE: Part of the ContractResolver interface.
func (b *boltArbitratorLog) FetchUnresolvedContracts() ([]ContractResolver, error) { func (b *boltArbitratorLog) FetchUnresolvedContracts() ([]ContractResolver, error) {
resKit := ResolverKit{ resolverCfg := ResolverConfig{
ChannelArbitratorConfig: b.cfg, ChannelArbitratorConfig: b.cfg,
Checkpoint: b.checkpointContract, Checkpoint: b.checkpointContract,
} }
@ -469,56 +469,38 @@ func (b *boltArbitratorLog) FetchUnresolvedContracts() ([]ContractResolver, erro
switch resType { switch resType {
case resolverTimeout: case resolverTimeout:
timeoutRes := &htlcTimeoutResolver{} res, err = newTimeoutResolverFromReader(
if err := timeoutRes.Decode(resReader); err != nil { resReader, resolverCfg,
return err )
}
timeoutRes.AttachResolverKit(resKit)
res = timeoutRes
case resolverSuccess: case resolverSuccess:
successRes := &htlcSuccessResolver{} res, err = newSuccessResolverFromReader(
if err := successRes.Decode(resReader); err != nil { resReader, resolverCfg,
return err )
}
res = successRes
case resolverOutgoingContest: case resolverOutgoingContest:
outContestRes := &htlcOutgoingContestResolver{ res, err = newOutgoingContestResolverFromReader(
htlcTimeoutResolver: htlcTimeoutResolver{}, resReader, resolverCfg,
} )
if err := outContestRes.Decode(resReader); err != nil {
return err
}
res = outContestRes
case resolverIncomingContest: case resolverIncomingContest:
inContestRes := &htlcIncomingContestResolver{ res, err = newIncomingContestResolverFromReader(
htlcSuccessResolver: htlcSuccessResolver{}, resReader, resolverCfg,
} )
if err := inContestRes.Decode(resReader); err != nil {
return err
}
res = inContestRes
case resolverUnilateralSweep: case resolverUnilateralSweep:
sweepRes := &commitSweepResolver{} res, err = newCommitSweepResolverFromReader(
if err := sweepRes.Decode(resReader); err != nil { resReader, resolverCfg,
return err )
}
res = sweepRes
default: default:
return fmt.Errorf("unknown resolver type: %v", resType) return fmt.Errorf("unknown resolver type: %v", resType)
} }
resKit.Quit = make(chan struct{}) if err != nil {
res.AttachResolverKit(resKit) return err
}
contracts = append(contracts, res) contracts = append(contracts, res)
return nil return nil
}) })

@ -1685,7 +1685,7 @@ func (c *ChannelArbitrator) prepContractResolutions(
// We'll create the resolver kit that we'll be cloning for each // We'll create the resolver kit that we'll be cloning for each
// resolver so they each can do their duty. // resolver so they each can do their duty.
resKit := ResolverKit{ resolverCfg := ResolverConfig{
ChannelArbitratorConfig: c.cfg, ChannelArbitratorConfig: c.cfg,
Checkpoint: func(res ContractResolver) error { Checkpoint: func(res ContractResolver) error {
return c.log.InsertUnresolvedContracts(res) return c.log.InsertUnresolvedContracts(res)
@ -1733,14 +1733,10 @@ func (c *ChannelArbitrator) prepContractResolutions(
continue continue
} }
resKit.Quit = make(chan struct{}) resolver := newSuccessResolver(
resolver := &htlcSuccessResolver{ resolution, height,
htlcResolution: resolution, htlc.RHash, htlc.Amt, resolverCfg,
broadcastHeight: height, )
payHash: htlc.RHash,
htlcAmt: htlc.Amt,
ResolverKit: resKit,
}
htlcResolvers = append(htlcResolvers, resolver) htlcResolvers = append(htlcResolvers, resolver)
} }
@ -1761,14 +1757,10 @@ func (c *ChannelArbitrator) prepContractResolutions(
continue continue
} }
resKit.Quit = make(chan struct{}) resolver := newTimeoutResolver(
resolver := &htlcTimeoutResolver{ resolution, height, htlc.HtlcIndex,
htlcResolution: resolution, htlc.Amt, resolverCfg,
broadcastHeight: height, )
htlcIndex: htlc.HtlcIndex,
htlcAmt: htlc.Amt,
ResolverKit: resKit,
}
htlcResolvers = append(htlcResolvers, resolver) htlcResolvers = append(htlcResolvers, resolver)
} }
@ -1798,18 +1790,11 @@ func (c *ChannelArbitrator) prepContractResolutions(
ChanID: c.cfg.ShortChanID, ChanID: c.cfg.ShortChanID,
} }
resKit.Quit = make(chan struct{}) resolver := newIncomingContestResolver(
resolver := &htlcIncomingContestResolver{ htlc.RefundTimeout, circuitKey,
htlcExpiry: htlc.RefundTimeout, resolution, height, htlc.RHash,
circuitKey: circuitKey, htlc.Amt, resolverCfg,
htlcSuccessResolver: htlcSuccessResolver{ )
htlcResolution: resolution,
broadcastHeight: height,
payHash: htlc.RHash,
htlcAmt: htlc.Amt,
ResolverKit: resKit,
},
}
htlcResolvers = append(htlcResolvers, resolver) htlcResolvers = append(htlcResolvers, resolver)
} }
@ -1831,16 +1816,10 @@ func (c *ChannelArbitrator) prepContractResolutions(
continue continue
} }
resKit.Quit = make(chan struct{}) resolver := newOutgoingContestResolver(
resolver := &htlcOutgoingContestResolver{ resolution, height, htlc.HtlcIndex,
htlcTimeoutResolver: htlcTimeoutResolver{ htlc.Amt, resolverCfg,
htlcResolution: resolution, )
broadcastHeight: height,
htlcIndex: htlc.HtlcIndex,
htlcAmt: htlc.Amt,
ResolverKit: resKit,
},
}
htlcResolvers = append(htlcResolvers, resolver) htlcResolvers = append(htlcResolvers, resolver)
} }
} }
@ -1850,14 +1829,10 @@ func (c *ChannelArbitrator) prepContractResolutions(
// a resolver to sweep our commitment output (but only if it wasn't // a resolver to sweep our commitment output (but only if it wasn't
// trimmed). // trimmed).
if contractResolutions.CommitResolution != nil { if contractResolutions.CommitResolution != nil {
resKit.Quit = make(chan struct{}) resolver := newCommitSweepResolver(
resolver := &commitSweepResolver{ *contractResolutions.CommitResolution,
commitResolution: *contractResolutions.CommitResolution, height, c.cfg.ChanPoint, resolverCfg,
broadcastHeight: height, )
chanPoint: c.cfg.ChanPoint,
ResolverKit: resKit,
}
htlcResolvers = append(htlcResolvers, resolver) htlcResolvers = append(htlcResolvers, resolver)
} }

@ -36,7 +36,20 @@ type commitSweepResolver struct {
// chanPoint is the channel point of the original contract. // chanPoint is the channel point of the original contract.
chanPoint wire.OutPoint chanPoint wire.OutPoint
ResolverKit contractResolverKit
}
// newCommitSweepResolver instantiates a new direct commit output resolver.
func newCommitSweepResolver(res lnwallet.CommitOutputResolution,
broadcastHeight uint32,
chanPoint wire.OutPoint, resCfg ResolverConfig) *commitSweepResolver {
return &commitSweepResolver{
contractResolverKit: *newContractResolverKit(resCfg),
commitResolution: res,
broadcastHeight: broadcastHeight,
chanPoint: chanPoint,
}
} }
// ResolverKey returns an identifier which should be globally unique for this // ResolverKey returns an identifier which should be globally unique for this
@ -80,7 +93,7 @@ func (c *commitSweepResolver) Resolve() (ContractResolver, error) {
return nil, errResolverShuttingDown return nil, errResolverShuttingDown
} }
case <-c.Quit: case <-c.quit:
return nil, errResolverShuttingDown return nil, errResolverShuttingDown
} }
@ -138,7 +151,7 @@ func (c *commitSweepResolver) Resolve() (ContractResolver, error) {
log.Infof("ChannelPoint(%v) commit tx is fully resolved by "+ log.Infof("ChannelPoint(%v) commit tx is fully resolved by "+
"sweep tx: %v", c.chanPoint, sweepResult.Tx.TxHash()) "sweep tx: %v", c.chanPoint, sweepResult.Tx.TxHash())
case <-c.Quit: case <-c.quit:
return nil, errResolverShuttingDown return nil, errResolverShuttingDown
} }
@ -180,7 +193,7 @@ func (c *commitSweepResolver) Resolve() (ContractResolver, error) {
log.Errorf("unable to Checkpoint: %v", err) log.Errorf("unable to Checkpoint: %v", err)
return nil, err return nil, err
} }
case <-c.Quit: case <-c.quit:
return nil, errResolverShuttingDown return nil, errResolverShuttingDown
} }
@ -206,7 +219,7 @@ func (c *commitSweepResolver) Resolve() (ContractResolver, error) {
log.Infof("ChannelPoint(%v) commit tx is fully resolved, at height: %v", log.Infof("ChannelPoint(%v) commit tx is fully resolved, at height: %v",
c.chanPoint, confInfo.BlockHeight) c.chanPoint, confInfo.BlockHeight)
case <-c.Quit: case <-c.quit:
return nil, errResolverShuttingDown return nil, errResolverShuttingDown
} }
@ -221,7 +234,7 @@ func (c *commitSweepResolver) Resolve() (ContractResolver, error) {
// //
// NOTE: Part of the ContractResolver interface. // NOTE: Part of the ContractResolver interface.
func (c *commitSweepResolver) Stop() { func (c *commitSweepResolver) Stop() {
close(c.Quit) close(c.quit)
} }
// IsResolved returns true if the stored state in the resolve is fully // IsResolved returns true if the stored state in the resolve is fully
@ -262,44 +275,40 @@ func (c *commitSweepResolver) Encode(w io.Writer) error {
return nil return nil
} }
// Decode attempts to decode an encoded ContractResolver from the passed Reader // newCommitSweepResolverFromReader attempts to decode an encoded
// instance, returning an active ContractResolver instance. // ContractResolver from the passed Reader instance, returning an active
// // ContractResolver instance.
// NOTE: Part of the ContractResolver interface. func newCommitSweepResolverFromReader(r io.Reader, resCfg ResolverConfig) (
func (c *commitSweepResolver) Decode(r io.Reader) error { *commitSweepResolver, error) {
c := &commitSweepResolver{
contractResolverKit: *newContractResolverKit(resCfg),
}
if err := decodeCommitResolution(r, &c.commitResolution); err != nil { if err := decodeCommitResolution(r, &c.commitResolution); err != nil {
return err return nil, err
} }
if err := binary.Read(r, endian, &c.resolved); err != nil { if err := binary.Read(r, endian, &c.resolved); err != nil {
return err return nil, err
} }
if err := binary.Read(r, endian, &c.broadcastHeight); err != nil { if err := binary.Read(r, endian, &c.broadcastHeight); err != nil {
return err return nil, err
} }
_, err := io.ReadFull(r, c.chanPoint.Hash[:]) _, err := io.ReadFull(r, c.chanPoint.Hash[:])
if err != nil { if err != nil {
return err return nil, err
} }
err = binary.Read(r, endian, &c.chanPoint.Index) err = binary.Read(r, endian, &c.chanPoint.Index)
if err != nil { if err != nil {
return err return nil, err
} }
// Previously a sweep tx was deserialized at this point. Refactoring // Previously a sweep tx was deserialized at this point. Refactoring
// removed this, but keep in mind that this data may still be present in // removed this, but keep in mind that this data may still be present in
// the database. // the database.
return nil return c, nil
}
// AttachResolverKit should be called once a resolved is successfully decoded
// from its stored format. This struct delivers a generic tool kit that
// resolvers need to complete their duty.
//
// NOTE: Part of the ContractResolver interface.
func (c *commitSweepResolver) AttachResolverKit(r ResolverKit) {
c.ResolverKit = r
} }
// A compile time assertion to ensure commitSweepResolver meets the // A compile time assertion to ensure commitSweepResolver meets the

@ -46,16 +46,6 @@ type ContractResolver interface {
// passed Writer. // passed Writer.
Encode(w io.Writer) error Encode(w io.Writer) error
// Decode attempts to decode an encoded ContractResolver from the
// passed Reader instance, returning an active ContractResolver
// instance.
Decode(r io.Reader) error
// AttachResolverKit should be called once a resolved is successfully
// decoded from its stored format. This struct delivers a generic tool
// kit that resolvers need to complete their duty.
AttachResolverKit(ResolverKit)
// Stop signals the resolver to cancel any current resolution // Stop signals the resolver to cancel any current resolution
// processes, and suspend. // processes, and suspend.
Stop() Stop()
@ -69,10 +59,9 @@ type reportingContractResolver interface {
report() *ContractReport report() *ContractReport
} }
// ResolverKit is meant to be used as a mix-in struct to be embedded within a // ResolverConfig contains the externally supplied configuration items that are
// given ContractResolver implementation. It contains all the items that a // required by a ContractResolver implementation.
// resolver requires to carry out its duties. type ResolverConfig struct {
type ResolverKit struct {
// ChannelArbitratorConfig contains all the interfaces and closures // ChannelArbitratorConfig contains all the interfaces and closures
// required for the resolver to interact with outside sub-systems. // required for the resolver to interact with outside sub-systems.
ChannelArbitratorConfig ChannelArbitratorConfig
@ -81,8 +70,23 @@ type ResolverKit struct {
// should write the state of the resolver to persistent storage, and // should write the state of the resolver to persistent storage, and
// return a non-nil error upon success. // return a non-nil error upon success.
Checkpoint func(ContractResolver) error Checkpoint func(ContractResolver) error
}
Quit chan struct{} // contractResolverKit is meant to be used as a mix-in struct to be embedded within a
// given ContractResolver implementation. It contains all the common items that
// a resolver requires to carry out its duties.
type contractResolverKit struct {
ResolverConfig
quit chan struct{}
}
// newContractResolverKit instantiates the mix-in struct.
func newContractResolverKit(cfg ResolverConfig) *contractResolverKit {
return &contractResolverKit{
ResolverConfig: cfg,
quit: make(chan struct{}),
}
} }
var ( var (

@ -5,11 +5,12 @@ import (
"errors" "errors"
"io" "io"
"github.com/btcsuite/btcutil"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/invoices"
"github.com/btcsuite/btcutil"
"github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwire"
) )
// htlcIncomingContestResolver is a ContractResolver that's able to resolve an // htlcIncomingContestResolver is a ContractResolver that's able to resolve an
@ -34,6 +35,24 @@ type htlcIncomingContestResolver struct {
htlcSuccessResolver htlcSuccessResolver
} }
// newIncomingContestResolver instantiates a new incoming htlc contest resolver.
func newIncomingContestResolver(htlcExpiry uint32,
circuitKey channeldb.CircuitKey, res lnwallet.IncomingHtlcResolution,
broadcastHeight uint32, payHash lntypes.Hash,
htlcAmt lnwire.MilliSatoshi,
resCfg ResolverConfig) *htlcIncomingContestResolver {
success := newSuccessResolver(
res, broadcastHeight, payHash, htlcAmt, resCfg,
)
return &htlcIncomingContestResolver{
htlcExpiry: htlcExpiry,
circuitKey: circuitKey,
htlcSuccessResolver: *success,
}
}
// Resolve attempts to resolve this contract. As we don't yet know of the // Resolve attempts to resolve this contract. As we don't yet know of the
// preimage for the contract, we'll wait for one of two things to happen: // preimage for the contract, we'll wait for one of two things to happen:
// //
@ -68,7 +87,7 @@ func (h *htlcIncomingContestResolver) Resolve() (ContractResolver, error) {
return nil, errResolverShuttingDown return nil, errResolverShuttingDown
} }
currentHeight = newBlock.Height currentHeight = newBlock.Height
case <-h.Quit: case <-h.quit:
return nil, errResolverShuttingDown return nil, errResolverShuttingDown
} }
@ -239,7 +258,7 @@ func (h *htlcIncomingContestResolver) Resolve() (ContractResolver, error) {
return nil, h.Checkpoint(h) return nil, h.Checkpoint(h)
} }
case <-h.Quit: case <-h.quit:
return nil, errResolverShuttingDown return nil, errResolverShuttingDown
} }
} }
@ -271,7 +290,7 @@ func (h *htlcIncomingContestResolver) report() *ContractReport {
// //
// NOTE: Part of the ContractResolver interface. // NOTE: Part of the ContractResolver interface.
func (h *htlcIncomingContestResolver) Stop() { func (h *htlcIncomingContestResolver) Stop() {
close(h.Quit) close(h.quit)
} }
// IsResolved returns true if the stored state in the resolve is fully // IsResolved returns true if the stored state in the resolve is fully
@ -296,27 +315,27 @@ func (h *htlcIncomingContestResolver) Encode(w io.Writer) error {
return h.htlcSuccessResolver.Encode(w) return h.htlcSuccessResolver.Encode(w)
} }
// Decode attempts to decode an encoded ContractResolver from the passed Reader // newIncomingContestResolverFromReader attempts to decode an encoded ContractResolver
// instance, returning an active ContractResolver instance. // from the passed Reader instance, returning an active ContractResolver
// // instance.
// NOTE: Part of the ContractResolver interface. func newIncomingContestResolverFromReader(r io.Reader, resCfg ResolverConfig) (
func (h *htlcIncomingContestResolver) Decode(r io.Reader) error { *htlcIncomingContestResolver, error) {
h := &htlcIncomingContestResolver{}
// We'll first read the one field unique to this resolver. // We'll first read the one field unique to this resolver.
if err := binary.Read(r, endian, &h.htlcExpiry); err != nil { if err := binary.Read(r, endian, &h.htlcExpiry); err != nil {
return err return nil, err
} }
// Then we'll decode our internal resolver. // Then we'll decode our internal resolver.
return h.htlcSuccessResolver.Decode(r) successResolver, err := newSuccessResolverFromReader(r, resCfg)
} if err != nil {
return nil, err
}
h.htlcSuccessResolver = *successResolver
// AttachResolverKit should be called once a resolved is successfully decoded return h, nil
// from its stored format. This struct delivers a generic tool kit that
// resolvers need to complete their duty.
//
// NOTE: Part of the ContractResolver interface.
func (h *htlcIncomingContestResolver) AttachResolverKit(r ResolverKit) {
h.ResolverKit = r
} }
// A compile time assertion to ensure htlcIncomingContestResolver meets the // A compile time assertion to ensure htlcIncomingContestResolver meets the

@ -199,17 +199,18 @@ func newIncomingResolverTestContext(t *testing.T) *incomingResolverTestContext {
}, },
} }
cfg := ResolverConfig{
ChannelArbitratorConfig: chainCfg,
Checkpoint: func(_ ContractResolver) error {
checkPointChan <- struct{}{}
return nil
},
}
resolver := &htlcIncomingContestResolver{ resolver := &htlcIncomingContestResolver{
htlcSuccessResolver: htlcSuccessResolver{ htlcSuccessResolver: htlcSuccessResolver{
ResolverKit: ResolverKit{ contractResolverKit: *newContractResolverKit(cfg),
ChannelArbitratorConfig: chainCfg, htlcResolution: lnwallet.IncomingHtlcResolution{},
Checkpoint: func(_ ContractResolver) error { payHash: testResHash,
checkPointChan <- struct{}{}
return nil
},
},
htlcResolution: lnwallet.IncomingHtlcResolution{},
payHash: testResHash,
}, },
htlcExpiry: testHtlcExpiry, htlcExpiry: testHtlcExpiry,
} }

@ -5,6 +5,8 @@ import (
"io" "io"
"github.com/btcsuite/btcutil" "github.com/btcsuite/btcutil"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwire"
) )
// htlcOutgoingContestResolver is a ContractResolver that's able to resolve an // htlcOutgoingContestResolver is a ContractResolver that's able to resolve an
@ -18,6 +20,21 @@ type htlcOutgoingContestResolver struct {
htlcTimeoutResolver htlcTimeoutResolver
} }
// newOutgoingContestResolver instantiates a new outgoing contested htlc
// resolver.
func newOutgoingContestResolver(res lnwallet.OutgoingHtlcResolution,
broadcastHeight uint32, htlcIndex uint64, htlcAmt lnwire.MilliSatoshi,
resCfg ResolverConfig) *htlcOutgoingContestResolver {
timeout := newTimeoutResolver(
res, broadcastHeight, htlcIndex, htlcAmt, resCfg,
)
return &htlcOutgoingContestResolver{
htlcTimeoutResolver: *timeout,
}
}
// Resolve commences the resolution of this contract. As this contract hasn't // Resolve commences the resolution of this contract. As this contract hasn't
// yet timed out, we'll wait for one of two things to happen // yet timed out, we'll wait for one of two things to happen
// //
@ -130,7 +147,7 @@ func (h *htlcOutgoingContestResolver) Resolve() (ContractResolver, error) {
// claimed. // claimed.
return h.claimCleanUp(commitSpend) return h.claimCleanUp(commitSpend)
case <-h.Quit: case <-h.quit:
return nil, fmt.Errorf("resolver canceled") return nil, fmt.Errorf("resolver canceled")
} }
} }
@ -162,7 +179,7 @@ func (h *htlcOutgoingContestResolver) report() *ContractReport {
// //
// NOTE: Part of the ContractResolver interface. // NOTE: Part of the ContractResolver interface.
func (h *htlcOutgoingContestResolver) Stop() { func (h *htlcOutgoingContestResolver) Stop() {
close(h.Quit) close(h.quit)
} }
// IsResolved returns true if the stored state in the resolve is fully // IsResolved returns true if the stored state in the resolve is fully
@ -181,21 +198,19 @@ func (h *htlcOutgoingContestResolver) Encode(w io.Writer) error {
return h.htlcTimeoutResolver.Encode(w) return h.htlcTimeoutResolver.Encode(w)
} }
// Decode attempts to decode an encoded ContractResolver from the passed Reader // newOutgoingContestResolverFromReader attempts to decode an encoded ContractResolver
// instance, returning an active ContractResolver instance. // from the passed Reader instance, returning an active ContractResolver
// // instance.
// NOTE: Part of the ContractResolver interface. func newOutgoingContestResolverFromReader(r io.Reader, resCfg ResolverConfig) (
func (h *htlcOutgoingContestResolver) Decode(r io.Reader) error { *htlcOutgoingContestResolver, error) {
return h.htlcTimeoutResolver.Decode(r)
}
// AttachResolverKit should be called once a resolved is successfully decoded h := &htlcOutgoingContestResolver{}
// from its stored format. This struct delivers a generic tool kit that timeoutResolver, err := newTimeoutResolverFromReader(r, resCfg)
// resolvers need to complete their duty. if err != nil {
// return nil, err
// NOTE: Part of the ContractResolver interface. }
func (h *htlcOutgoingContestResolver) AttachResolverKit(r ResolverKit) { h.htlcTimeoutResolver = *timeoutResolver
h.ResolverKit = r return h, nil
} }
// A compile time assertion to ensure htlcOutgoingContestResolver meets the // A compile time assertion to ensure htlcOutgoingContestResolver meets the

@ -122,16 +122,18 @@ func newOutgoingResolverTestContext(t *testing.T) *outgoingResolverTestContext {
}, },
} }
cfg := ResolverConfig{
ChannelArbitratorConfig: chainCfg,
Checkpoint: func(_ ContractResolver) error {
checkPointChan <- struct{}{}
return nil
},
}
resolver := &htlcOutgoingContestResolver{ resolver := &htlcOutgoingContestResolver{
htlcTimeoutResolver: htlcTimeoutResolver{ htlcTimeoutResolver: htlcTimeoutResolver{
ResolverKit: ResolverKit{ contractResolverKit: *newContractResolverKit(cfg),
ChannelArbitratorConfig: chainCfg, htlcResolution: outgoingRes,
Checkpoint: func(_ ContractResolver) error {
checkPointChan <- struct{}{}
return nil
},
},
htlcResolution: outgoingRes,
}, },
} }

@ -52,7 +52,22 @@ type htlcSuccessResolver struct {
// account any fees that may have to be paid if it goes on chain. // account any fees that may have to be paid if it goes on chain.
htlcAmt lnwire.MilliSatoshi htlcAmt lnwire.MilliSatoshi
ResolverKit contractResolverKit
}
// newSuccessResolver instanties a new htlc success resolver.
func newSuccessResolver(res lnwallet.IncomingHtlcResolution,
broadcastHeight uint32, payHash lntypes.Hash,
htlcAmt lnwire.MilliSatoshi,
resCfg ResolverConfig) *htlcSuccessResolver {
return &htlcSuccessResolver{
contractResolverKit: *newContractResolverKit(resCfg),
htlcResolution: res,
broadcastHeight: broadcastHeight,
payHash: payHash,
htlcAmt: htlcAmt,
}
} }
// ResolverKey returns an identifier which should be globally unique for this // ResolverKey returns an identifier which should be globally unique for this
@ -173,7 +188,7 @@ func (h *htlcSuccessResolver) Resolve() (ContractResolver, error) {
return nil, errResolverShuttingDown return nil, errResolverShuttingDown
} }
case <-h.Quit: case <-h.quit:
return nil, errResolverShuttingDown return nil, errResolverShuttingDown
} }
@ -238,7 +253,7 @@ func (h *htlcSuccessResolver) Resolve() (ContractResolver, error) {
return nil, errResolverShuttingDown return nil, errResolverShuttingDown
} }
case <-h.Quit: case <-h.quit:
return nil, errResolverShuttingDown return nil, errResolverShuttingDown
} }
@ -251,7 +266,7 @@ func (h *htlcSuccessResolver) Resolve() (ContractResolver, error) {
// //
// NOTE: Part of the ContractResolver interface. // NOTE: Part of the ContractResolver interface.
func (h *htlcSuccessResolver) Stop() { func (h *htlcSuccessResolver) Stop() {
close(h.Quit) close(h.quit)
} }
// IsResolved returns true if the stored state in the resolve is fully // IsResolved returns true if the stored state in the resolve is fully
@ -290,41 +305,37 @@ func (h *htlcSuccessResolver) Encode(w io.Writer) error {
return nil return nil
} }
// Decode attempts to decode an encoded ContractResolver from the passed Reader // newSuccessResolverFromReader attempts to decode an encoded ContractResolver
// instance, returning an active ContractResolver instance. // from the passed Reader instance, returning an active ContractResolver
// // instance.
// NOTE: Part of the ContractResolver interface. func newSuccessResolverFromReader(r io.Reader, resCfg ResolverConfig) (
func (h *htlcSuccessResolver) Decode(r io.Reader) error { *htlcSuccessResolver, error) {
h := &htlcSuccessResolver{
contractResolverKit: *newContractResolverKit(resCfg),
}
// First we'll decode our inner HTLC resolution. // First we'll decode our inner HTLC resolution.
if err := decodeIncomingResolution(r, &h.htlcResolution); err != nil { if err := decodeIncomingResolution(r, &h.htlcResolution); err != nil {
return err return nil, err
} }
// Next, we'll read all the fields that are specified to the contract // Next, we'll read all the fields that are specified to the contract
// resolver. // resolver.
if err := binary.Read(r, endian, &h.outputIncubating); err != nil { if err := binary.Read(r, endian, &h.outputIncubating); err != nil {
return err return nil, err
} }
if err := binary.Read(r, endian, &h.resolved); err != nil { if err := binary.Read(r, endian, &h.resolved); err != nil {
return err return nil, err
} }
if err := binary.Read(r, endian, &h.broadcastHeight); err != nil { if err := binary.Read(r, endian, &h.broadcastHeight); err != nil {
return err return nil, err
} }
if _, err := io.ReadFull(r, h.payHash[:]); err != nil { if _, err := io.ReadFull(r, h.payHash[:]); err != nil {
return err return nil, err
} }
return nil return h, nil
}
// AttachResolverKit should be called once a resolved is successfully decoded
// from its stored format. This struct delivers a generic tool kit that
// resolvers need to complete their duty.
//
// NOTE: Part of the ContractResolver interface.
func (h *htlcSuccessResolver) AttachResolverKit(r ResolverKit) {
h.ResolverKit = r
} }
// A compile time assertion to ensure htlcSuccessResolver meets the // A compile time assertion to ensure htlcSuccessResolver meets the

@ -48,7 +48,22 @@ type htlcTimeoutResolver struct {
// account any fees that may have to be paid if it goes on chain. // account any fees that may have to be paid if it goes on chain.
htlcAmt lnwire.MilliSatoshi htlcAmt lnwire.MilliSatoshi
ResolverKit contractResolverKit
}
// newTimeoutResolver instantiates a new timeout htlc resolver.
func newTimeoutResolver(res lnwallet.OutgoingHtlcResolution,
broadcastHeight uint32, htlcIndex uint64,
htlcAmt lnwire.MilliSatoshi,
resCfg ResolverConfig) *htlcTimeoutResolver {
return &htlcTimeoutResolver{
contractResolverKit: *newContractResolverKit(resCfg),
htlcResolution: res,
broadcastHeight: broadcastHeight,
htlcIndex: htlcIndex,
htlcAmt: htlcAmt,
}
} }
// ResolverKey returns an identifier which should be globally unique for this // ResolverKey returns an identifier which should be globally unique for this
@ -274,7 +289,7 @@ func (h *htlcTimeoutResolver) Resolve() (ContractResolver, error) {
return errResolverShuttingDown return errResolverShuttingDown
} }
case <-h.Quit: case <-h.quit:
return errResolverShuttingDown return errResolverShuttingDown
} }
@ -312,7 +327,7 @@ func (h *htlcTimeoutResolver) Resolve() (ContractResolver, error) {
return nil, errResolverShuttingDown return nil, errResolverShuttingDown
} }
case <-h.Quit: case <-h.quit:
return nil, errResolverShuttingDown return nil, errResolverShuttingDown
} }
@ -365,7 +380,7 @@ func (h *htlcTimeoutResolver) Resolve() (ContractResolver, error) {
// //
// NOTE: Part of the ContractResolver interface. // NOTE: Part of the ContractResolver interface.
func (h *htlcTimeoutResolver) Stop() { func (h *htlcTimeoutResolver) Stop() {
close(h.Quit) close(h.quit)
} }
// IsResolved returns true if the stored state in the resolve is fully // IsResolved returns true if the stored state in the resolve is fully
@ -406,43 +421,39 @@ func (h *htlcTimeoutResolver) Encode(w io.Writer) error {
return nil return nil
} }
// Decode attempts to decode an encoded ContractResolver from the passed Reader // newTimeoutResolverFromReader attempts to decode an encoded ContractResolver
// instance, returning an active ContractResolver instance. // from the passed Reader instance, returning an active ContractResolver
// // instance.
// NOTE: Part of the ContractResolver interface. func newTimeoutResolverFromReader(r io.Reader, resCfg ResolverConfig) (
func (h *htlcTimeoutResolver) Decode(r io.Reader) error { *htlcTimeoutResolver, error) {
h := &htlcTimeoutResolver{
contractResolverKit: *newContractResolverKit(resCfg),
}
// First, we'll read out all the mandatory fields of the // First, we'll read out all the mandatory fields of the
// OutgoingHtlcResolution that we store. // OutgoingHtlcResolution that we store.
if err := decodeOutgoingResolution(r, &h.htlcResolution); err != nil { if err := decodeOutgoingResolution(r, &h.htlcResolution); err != nil {
return err return nil, err
} }
// With those fields read, we can now read back the fields that are // With those fields read, we can now read back the fields that are
// specific to the resolver itself. // specific to the resolver itself.
if err := binary.Read(r, endian, &h.outputIncubating); err != nil { if err := binary.Read(r, endian, &h.outputIncubating); err != nil {
return err return nil, err
} }
if err := binary.Read(r, endian, &h.resolved); err != nil { if err := binary.Read(r, endian, &h.resolved); err != nil {
return err return nil, err
} }
if err := binary.Read(r, endian, &h.broadcastHeight); err != nil { if err := binary.Read(r, endian, &h.broadcastHeight); err != nil {
return err return nil, err
} }
if err := binary.Read(r, endian, &h.htlcIndex); err != nil { if err := binary.Read(r, endian, &h.htlcIndex); err != nil {
return err return nil, err
} }
return nil return h, nil
}
// AttachResolverKit should be called once a resolved is successfully decoded
// from its stored format. This struct delivers a generic tool kit that
// resolvers need to complete their duty.
//
// NOTE: Part of the ContractResolver interface.
func (h *htlcTimeoutResolver) AttachResolverKit(r ResolverKit) {
h.ResolverKit = r
} }
// A compile time assertion to ensure htlcTimeoutResolver meets the // A compile time assertion to ensure htlcTimeoutResolver meets the

@ -237,15 +237,18 @@ func TestHtlcTimeoutResolver(t *testing.T) {
}, },
} }
resolver := &htlcTimeoutResolver{ cfg := ResolverConfig{
ResolverKit: ResolverKit{ ChannelArbitratorConfig: chainCfg,
ChannelArbitratorConfig: chainCfg, Checkpoint: func(_ ContractResolver) error {
Checkpoint: func(_ ContractResolver) error { checkPointChan <- struct{}{}
checkPointChan <- struct{}{} return nil
return nil
},
}, },
} }
resolver := &htlcTimeoutResolver{
contractResolverKit: *newContractResolverKit(
cfg,
),
}
resolver.htlcResolution.SweepSignDesc = *fakeSignDesc resolver.htlcResolution.SweepSignDesc = *fakeSignDesc
// If the test case needs the remote commitment to be // If the test case needs the remote commitment to be