watchtower/multi: switch over to wtpolicy

migrate to using wtpolicy.Policy in wtwire messages and wtserver
This commit is contained in:
Conner Fromknecht 2019-01-10 15:35:11 -08:00
parent c315d74347
commit b746bf86c2
No known key found for this signature in database
GPG Key ID: E7D737B67FA592C7
10 changed files with 113 additions and 92 deletions

@ -17,12 +17,6 @@ import (
) )
const ( const (
// MinVersion is the minimum blob version supported by this package.
MinVersion = 0
// MaxVersion is the maximumm blob version supported by this package.
MaxVersion = 0
// NonceSize is the length of a chacha20poly1305 nonce, 24 bytes. // NonceSize is the length of a chacha20poly1305 nonce, 24 bytes.
NonceSize = chacha20poly1305.NonceSizeX NonceSize = chacha20poly1305.NonceSizeX
@ -53,14 +47,14 @@ const (
// nonce: 24 bytes // nonce: 24 bytes
// enciphered plaintext: n bytes // enciphered plaintext: n bytes
// MAC: 16 bytes // MAC: 16 bytes
func Size(ver uint16) int { func Size(blobType Type) int {
return NonceSize + PlaintextSize(ver) + CiphertextExpansion return NonceSize + PlaintextSize(blobType) + CiphertextExpansion
} }
// PlaintextSize returns the size of the encoded-but-unencrypted blob in bytes. // PlaintextSize returns the size of the encoded-but-unencrypted blob in bytes.
func PlaintextSize(ver uint16) int { func PlaintextSize(blobType Type) int {
switch ver { switch {
case 0: case blobType.Has(FlagCommitOutputs):
return V0PlaintextSize return V0PlaintextSize
default: default:
return 0 return 0
@ -71,9 +65,9 @@ var (
// byteOrder specifies a big-endian encoding of all integer values. // byteOrder specifies a big-endian encoding of all integer values.
byteOrder = binary.BigEndian byteOrder = binary.BigEndian
// ErrUnknownBlobVersion signals that we don't understand the requested // ErrUnknownBlobType signals that we don't understand the requested
// blob encoding scheme. // blob encoding scheme.
ErrUnknownBlobVersion = errors.New("unknown blob version") ErrUnknownBlobType = errors.New("unknown blob type")
// ErrCiphertextTooSmall is a decryption error signaling that the // ErrCiphertextTooSmall is a decryption error signaling that the
// ciphertext is smaller than the ciphertext expansion factor. // ciphertext is smaller than the ciphertext expansion factor.
@ -229,7 +223,7 @@ func (b *JusticeKit) CommitToRemoteWitnessStack() ([][]byte, error) {
// //
// NOTE: It is the caller's responsibility to ensure that this method is only // NOTE: It is the caller's responsibility to ensure that this method is only
// called once for a given (nonce, key) pair. // called once for a given (nonce, key) pair.
func (b *JusticeKit) Encrypt(key []byte, version uint16) ([]byte, error) { func (b *JusticeKit) Encrypt(key []byte, blobType Type) ([]byte, error) {
// Fail if the nonce is not 32-bytes. // Fail if the nonce is not 32-bytes.
if len(key) != KeySize { if len(key) != KeySize {
return nil, ErrKeySize return nil, ErrKeySize
@ -238,7 +232,7 @@ func (b *JusticeKit) Encrypt(key []byte, version uint16) ([]byte, error) {
// Encode the plaintext using the provided version, to obtain the // Encode the plaintext using the provided version, to obtain the
// plaintext bytes. // plaintext bytes.
var ptxtBuf bytes.Buffer var ptxtBuf bytes.Buffer
err := b.encode(&ptxtBuf, version) err := b.encode(&ptxtBuf, blobType)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -252,7 +246,7 @@ func (b *JusticeKit) Encrypt(key []byte, version uint16) ([]byte, error) {
// Allocate the ciphertext, which will contain the nonce, encrypted // Allocate the ciphertext, which will contain the nonce, encrypted
// plaintext and MAC. // plaintext and MAC.
plaintext := ptxtBuf.Bytes() plaintext := ptxtBuf.Bytes()
ciphertext := make([]byte, Size(version)) ciphertext := make([]byte, Size(blobType))
// Generate a random 24-byte nonce in the ciphertext's prefix. // Generate a random 24-byte nonce in the ciphertext's prefix.
nonce := ciphertext[:NonceSize] nonce := ciphertext[:NonceSize]
@ -270,7 +264,7 @@ func (b *JusticeKit) Encrypt(key []byte, version uint16) ([]byte, error) {
// Decrypt unenciphers a blob of justice by decrypting the ciphertext using // Decrypt unenciphers a blob of justice by decrypting the ciphertext using
// chacha20poly1305 with the chosen (nonce, key) pair. The internal plaintext is // chacha20poly1305 with the chosen (nonce, key) pair. The internal plaintext is
// then deserialized using the given encoding version. // then deserialized using the given encoding version.
func Decrypt(key, ciphertext []byte, version uint16) (*JusticeKit, error) { func Decrypt(key, ciphertext []byte, blobType Type) (*JusticeKit, error) {
switch { switch {
// Fail if the blob's overall length is less than required for the nonce // Fail if the blob's overall length is less than required for the nonce
@ -305,7 +299,7 @@ func Decrypt(key, ciphertext []byte, version uint16) (*JusticeKit, error) {
// If decryption succeeded, we will then decode the plaintext bytes // If decryption succeeded, we will then decode the plaintext bytes
// using the specified blob version. // using the specified blob version.
boj := &JusticeKit{} boj := &JusticeKit{}
err = boj.decode(bytes.NewReader(plaintext), version) err = boj.decode(bytes.NewReader(plaintext), blobType)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -315,23 +309,23 @@ func Decrypt(key, ciphertext []byte, version uint16) (*JusticeKit, error) {
// encode serializes the JusticeKit according to the version, returning an // encode serializes the JusticeKit according to the version, returning an
// error if the version is unknown. // error if the version is unknown.
func (b *JusticeKit) encode(w io.Writer, ver uint16) error { func (b *JusticeKit) encode(w io.Writer, blobType Type) error {
switch ver { switch {
case 0: case blobType.Has(FlagCommitOutputs):
return b.encodeV0(w) return b.encodeV0(w)
default: default:
return ErrUnknownBlobVersion return ErrUnknownBlobType
} }
} }
// decode deserializes the JusticeKit according to the version, returning an // decode deserializes the JusticeKit according to the version, returning an
// error if the version is unknown. // error if the version is unknown.
func (b *JusticeKit) decode(r io.Reader, ver uint16) error { func (b *JusticeKit) decode(r io.Reader, blobType Type) error {
switch ver { switch {
case 0: case blobType.Has(FlagCommitOutputs):
return b.decodeV0(r) return b.decodeV0(r)
default: default:
return ErrUnknownBlobVersion return ErrUnknownBlobType
} }
} }

@ -38,8 +38,8 @@ func makeAddr(size int) []byte {
type descriptorTest struct { type descriptorTest struct {
name string name string
encVersion uint16 encVersion blob.Type
decVersion uint16 decVersion blob.Type
sweepAddr []byte sweepAddr []byte
revPubKey blob.PubKey revPubKey blob.PubKey
delayPubKey blob.PubKey delayPubKey blob.PubKey
@ -52,11 +52,15 @@ type descriptorTest struct {
decErr error decErr error
} }
var rewardAndCommitType = blob.TypeFromFlags(
blob.FlagReward, blob.FlagCommitOutputs,
)
var descriptorTests = []descriptorTest{ var descriptorTests = []descriptorTest{
{ {
name: "to-local only", name: "to-local only",
encVersion: 0, encVersion: blob.TypeDefault,
decVersion: 0, decVersion: blob.TypeDefault,
sweepAddr: makeAddr(22), sweepAddr: makeAddr(22),
revPubKey: makePubKey(0), revPubKey: makePubKey(0),
delayPubKey: makePubKey(1), delayPubKey: makePubKey(1),
@ -65,8 +69,8 @@ var descriptorTests = []descriptorTest{
}, },
{ {
name: "to-local and p2wkh", name: "to-local and p2wkh",
encVersion: 0, encVersion: rewardAndCommitType,
decVersion: 0, decVersion: rewardAndCommitType,
sweepAddr: makeAddr(22), sweepAddr: makeAddr(22),
revPubKey: makePubKey(0), revPubKey: makePubKey(0),
delayPubKey: makePubKey(1), delayPubKey: makePubKey(1),
@ -78,30 +82,30 @@ var descriptorTests = []descriptorTest{
}, },
{ {
name: "unknown encrypt version", name: "unknown encrypt version",
encVersion: 1, encVersion: 0,
decVersion: 0, decVersion: blob.TypeDefault,
sweepAddr: makeAddr(34), sweepAddr: makeAddr(34),
revPubKey: makePubKey(0), revPubKey: makePubKey(0),
delayPubKey: makePubKey(1), delayPubKey: makePubKey(1),
csvDelay: 144, csvDelay: 144,
commitToLocalSig: makeSig(1), commitToLocalSig: makeSig(1),
encErr: blob.ErrUnknownBlobVersion, encErr: blob.ErrUnknownBlobType,
}, },
{ {
name: "unknown decrypt version", name: "unknown decrypt version",
encVersion: 0, encVersion: blob.TypeDefault,
decVersion: 1, decVersion: 0,
sweepAddr: makeAddr(34), sweepAddr: makeAddr(34),
revPubKey: makePubKey(0), revPubKey: makePubKey(0),
delayPubKey: makePubKey(1), delayPubKey: makePubKey(1),
csvDelay: 144, csvDelay: 144,
commitToLocalSig: makeSig(1), commitToLocalSig: makeSig(1),
decErr: blob.ErrUnknownBlobVersion, decErr: blob.ErrUnknownBlobType,
}, },
{ {
name: "sweep addr length zero", name: "sweep addr length zero",
encVersion: 0, encVersion: blob.TypeDefault,
decVersion: 0, decVersion: blob.TypeDefault,
sweepAddr: makeAddr(0), sweepAddr: makeAddr(0),
revPubKey: makePubKey(0), revPubKey: makePubKey(0),
delayPubKey: makePubKey(1), delayPubKey: makePubKey(1),
@ -110,8 +114,8 @@ var descriptorTests = []descriptorTest{
}, },
{ {
name: "sweep addr max size", name: "sweep addr max size",
encVersion: 0, encVersion: blob.TypeDefault,
decVersion: 0, decVersion: blob.TypeDefault,
sweepAddr: makeAddr(blob.MaxSweepAddrSize), sweepAddr: makeAddr(blob.MaxSweepAddrSize),
revPubKey: makePubKey(0), revPubKey: makePubKey(0),
delayPubKey: makePubKey(1), delayPubKey: makePubKey(1),
@ -120,8 +124,8 @@ var descriptorTests = []descriptorTest{
}, },
{ {
name: "sweep addr too long", name: "sweep addr too long",
encVersion: 0, encVersion: blob.TypeDefault,
decVersion: 0, decVersion: blob.TypeDefault,
sweepAddr: makeAddr(blob.MaxSweepAddrSize + 1), sweepAddr: makeAddr(blob.MaxSweepAddrSize + 1),
revPubKey: makePubKey(0), revPubKey: makePubKey(0),
delayPubKey: makePubKey(1), delayPubKey: makePubKey(1),

@ -19,6 +19,7 @@ import (
"github.com/lightningnetwork/lnd/watchtower/blob" "github.com/lightningnetwork/lnd/watchtower/blob"
"github.com/lightningnetwork/lnd/watchtower/lookout" "github.com/lightningnetwork/lnd/watchtower/lookout"
"github.com/lightningnetwork/lnd/watchtower/wtdb" "github.com/lightningnetwork/lnd/watchtower/wtdb"
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
) )
const csvDelay uint32 = 144 const csvDelay uint32 = 144
@ -170,8 +171,10 @@ func TestJusticeDescriptor(t *testing.T) {
// parameters that should be used in constructing the justice // parameters that should be used in constructing the justice
// transaction. // transaction.
sessionInfo := &wtdb.SessionInfo{ sessionInfo := &wtdb.SessionInfo{
Policy: wtpolicy.Policy{
SweepFeeRate: 2000, SweepFeeRate: 2000,
RewardRate: 900000, RewardRate: 900000,
},
RewardAddress: makeAddrSlice(22), RewardAddress: makeAddrSlice(22),
} }

@ -210,7 +210,7 @@ func (l *Lookout) processEpoch(epoch *chainntnfs.BlockEpoch,
// sweep the breached commitment outputs. // sweep the breached commitment outputs.
justiceKit, err := blob.Decrypt( justiceKit, err := blob.Decrypt(
commitTxID[:], match.EncryptedBlob, commitTxID[:], match.EncryptedBlob,
match.SessionInfo.Version, match.SessionInfo.Policy.BlobType,
) )
if err != nil { if err != nil {
// If the decryption fails, this implies either that the // If the decryption fails, this implies either that the

@ -15,6 +15,7 @@ import (
"github.com/lightningnetwork/lnd/watchtower/blob" "github.com/lightningnetwork/lnd/watchtower/blob"
"github.com/lightningnetwork/lnd/watchtower/lookout" "github.com/lightningnetwork/lnd/watchtower/lookout"
"github.com/lightningnetwork/lnd/watchtower/wtdb" "github.com/lightningnetwork/lnd/watchtower/wtdb"
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
) )
type mockPunisher struct { type mockPunisher struct {
@ -86,15 +87,25 @@ func TestLookoutBreachMatching(t *testing.T) {
t.Fatalf("unable to start watcher: %v", err) t.Fatalf("unable to start watcher: %v", err)
} }
rewardAndCommitType := blob.TypeFromFlags(
blob.FlagReward, blob.FlagCommitOutputs,
)
// Create two sessions, representing two distinct clients. // Create two sessions, representing two distinct clients.
sessionInfo1 := &wtdb.SessionInfo{ sessionInfo1 := &wtdb.SessionInfo{
ID: makeArray33(1), ID: makeArray33(1),
Policy: wtpolicy.Policy{
BlobType: rewardAndCommitType,
MaxUpdates: 10, MaxUpdates: 10,
},
RewardAddress: makeAddrSlice(22), RewardAddress: makeAddrSlice(22),
} }
sessionInfo2 := &wtdb.SessionInfo{ sessionInfo2 := &wtdb.SessionInfo{
ID: makeArray33(2), ID: makeArray33(2),
Policy: wtpolicy.Policy{
BlobType: rewardAndCommitType,
MaxUpdates: 10, MaxUpdates: 10,
},
RewardAddress: makeAddrSlice(22), RewardAddress: makeAddrSlice(22),
} }
@ -137,13 +148,13 @@ func TestLookoutBreachMatching(t *testing.T) {
} }
// Encrypt the first justice kit under the txid of the first txn. // Encrypt the first justice kit under the txid of the first txn.
encBlob1, err := blob1.Encrypt(hash1[:], 0) encBlob1, err := blob1.Encrypt(hash1[:], blob.FlagCommitOutputs.Type())
if err != nil { if err != nil {
t.Fatalf("unable to encrypt sweep detail 1: %v", err) t.Fatalf("unable to encrypt sweep detail 1: %v", err)
} }
// Encrypt the second justice kit under the txid of the second txn. // Encrypt the second justice kit under the txid of the second txn.
encBlob2, err := blob2.Encrypt(hash2[:], 0) encBlob2, err := blob2.Encrypt(hash2[:], blob.FlagCommitOutputs.Type())
if err != nil { if err != nil {
t.Fatalf("unable to encrypt sweep detail 2: %v", err) t.Fatalf("unable to encrypt sweep detail 2: %v", err)
} }

@ -4,7 +4,7 @@ import (
"errors" "errors"
"github.com/btcsuite/btcutil" "github.com/btcsuite/btcutil"
"github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/watchtower/wtpolicy"
) )
var ( var (
@ -49,12 +49,8 @@ type SessionInfo struct {
// ID is the remote public key of the watchtower client. // ID is the remote public key of the watchtower client.
ID SessionID ID SessionID
// Version specifies the plaintext blob encoding of all state updates. // Policy holds the negotiated session parameters.
Version uint16 Policy wtpolicy.Policy
// MaxUpdates is the total number of updates the client can send for
// this session.
MaxUpdates uint16
// LastApplied the sequence number of the last successful state update. // LastApplied the sequence number of the last successful state update.
LastApplied uint16 LastApplied uint16
@ -62,14 +58,6 @@ type SessionInfo struct {
// ClientLastApplied the last last-applied the client has echoed back. // ClientLastApplied the last last-applied the client has echoed back.
ClientLastApplied uint16 ClientLastApplied uint16
// RewardRate the fraction of the swept amount that goes to the tower,
// expressed in millionths of the swept balance.
RewardRate uint32
// SweepFeeRate is the agreed upon fee rate used to sign any sweep
// transactions.
SweepFeeRate lnwallet.SatPerKWeight
// RewardAddress the address that the tower's reward will be deposited // RewardAddress the address that the tower's reward will be deposited
// to if a sweep transaction confirms. // to if a sweep transaction confirms.
RewardAddress []byte RewardAddress []byte
@ -96,7 +84,7 @@ func (s *SessionInfo) AcceptUpdateSequence(seqNum, lastApplied uint16) error {
return ErrLastAppliedReversion return ErrLastAppliedReversion
// Client update exceeds capacity of session. // Client update exceeds capacity of session.
case seqNum > s.MaxUpdates: case seqNum > s.Policy.MaxUpdates:
return ErrSessionConsumed return ErrSessionConsumed
// Client update does not match our expected next seqnum. // Client update does not match our expected next seqnum.
@ -117,7 +105,7 @@ func (s *SessionInfo) AcceptUpdateSequence(seqNum, lastApplied uint16) error {
func (s *SessionInfo) ComputeSweepOutputs(totalAmt btcutil.Amount, func (s *SessionInfo) ComputeSweepOutputs(totalAmt btcutil.Amount,
txVSize int64) (btcutil.Amount, btcutil.Amount, error) { txVSize int64) (btcutil.Amount, btcutil.Amount, error) {
txFee := s.SweepFeeRate.FeeForWeight(txVSize) txFee := s.Policy.SweepFeeRate.FeeForWeight(txVSize)
if txFee > totalAmt { if txFee > totalAmt {
return 0, 0, ErrFeeExceedsInputs return 0, 0, ErrFeeExceedsInputs
} }
@ -126,7 +114,8 @@ func (s *SessionInfo) ComputeSweepOutputs(totalAmt btcutil.Amount,
// Apply the reward rate to the remaining total, specified in millionths // Apply the reward rate to the remaining total, specified in millionths
// of the available balance. // of the available balance.
rewardAmt := (totalAmt*btcutil.Amount(s.RewardRate) + 999999) / 1000000 rewardRate := btcutil.Amount(s.Policy.RewardRate)
rewardAmt := (totalAmt*rewardRate + 999999) / 1000000
sweepAmt := totalAmt - rewardAmt sweepAmt := totalAmt - rewardAmt
// TODO(conner): check dustiness // TODO(conner): check dustiness

@ -14,6 +14,7 @@ import (
"github.com/btcsuite/btcutil" "github.com/btcsuite/btcutil"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/watchtower/wtdb" "github.com/lightningnetwork/lnd/watchtower/wtdb"
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
"github.com/lightningnetwork/lnd/watchtower/wtwire" "github.com/lightningnetwork/lnd/watchtower/wtwire"
) )
@ -246,14 +247,14 @@ func (s *Server) handleClient(peer Peer) {
log.Infof("Received CreateSession from %s, "+ log.Infof("Received CreateSession from %s, "+
"version=%d nupdates=%d rewardrate=%d "+ "version=%d nupdates=%d rewardrate=%d "+
"sweepfeerate=%d", id, msg.BlobVersion, "sweepfeerate=%d", id, msg.BlobType,
msg.MaxUpdates, msg.RewardRate, msg.MaxUpdates, msg.RewardRate,
msg.SweepFeeRate) msg.SweepFeeRate)
// Attempt to open a new session for this client. // Attempt to open a new session for this client.
err := s.handleCreateSession(peer, &id, msg) err := s.handleCreateSession(peer, &id, msg)
if err != nil { if err != nil {
log.Errorf("unable to handle CreateSession "+ log.Errorf("Unable to handle CreateSession "+
"from %s: %v", id, err) "from %s: %v", id, err)
} }
@ -327,7 +328,7 @@ func (s *Server) handleInit(localInit, remoteInit *wtwire.Init) error {
// session info is known about the session id. If an existing session is found, // session info is known about the session id. If an existing session is found,
// the reward address is returned in case the client lost our reply. // the reward address is returned in case the client lost our reply.
func (s *Server) handleCreateSession(peer Peer, id *wtdb.SessionID, func (s *Server) handleCreateSession(peer Peer, id *wtdb.SessionID,
init *wtwire.CreateSession) error { req *wtwire.CreateSession) error {
// TODO(conner): validate accept against policy // TODO(conner): validate accept against policy
@ -375,11 +376,13 @@ func (s *Server) handleCreateSession(peer Peer, id *wtdb.SessionID,
// address, and session id. // address, and session id.
info := wtdb.SessionInfo{ info := wtdb.SessionInfo{
ID: *id, ID: *id,
Version: init.BlobVersion,
MaxUpdates: init.MaxUpdates,
RewardRate: init.RewardRate,
SweepFeeRate: init.SweepFeeRate,
RewardAddress: rewardAddrBytes, RewardAddress: rewardAddrBytes,
Policy: wtpolicy.Policy{
BlobType: req.BlobType,
MaxUpdates: req.MaxUpdates,
RewardRate: req.RewardRate,
SweepFeeRate: req.SweepFeeRate,
},
} }
// Insert the session info into the watchtower's database. If // Insert the session info into the watchtower's database. If

@ -12,6 +12,7 @@ import (
"github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcutil" "github.com/btcsuite/btcutil"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/watchtower/blob"
"github.com/lightningnetwork/lnd/watchtower/wtdb" "github.com/lightningnetwork/lnd/watchtower/wtdb"
"github.com/lightningnetwork/lnd/watchtower/wtserver" "github.com/lightningnetwork/lnd/watchtower/wtserver"
"github.com/lightningnetwork/lnd/watchtower/wtwire" "github.com/lightningnetwork/lnd/watchtower/wtwire"
@ -155,7 +156,7 @@ var createSessionTests = []createSessionTestCase{
lnwire.NewRawFeatureVector(), lnwire.NewRawFeatureVector(),
), ),
createMsg: &wtwire.CreateSession{ createMsg: &wtwire.CreateSession{
BlobVersion: 0, BlobType: blob.TypeDefault,
MaxUpdates: 1000, MaxUpdates: 1000,
RewardRate: 0, RewardRate: 0,
SweepFeeRate: 1, SweepFeeRate: 1,
@ -258,7 +259,7 @@ var stateUpdateTests = []stateUpdateTestCase{
GlobalFeatures: lnwire.NewRawFeatureVector(), GlobalFeatures: lnwire.NewRawFeatureVector(),
}}, }},
createMsg: &wtwire.CreateSession{ createMsg: &wtwire.CreateSession{
BlobVersion: 0, BlobType: blob.TypeDefault,
MaxUpdates: 3, MaxUpdates: 3,
RewardRate: 0, RewardRate: 0,
SweepFeeRate: 1, SweepFeeRate: 1,
@ -287,7 +288,7 @@ var stateUpdateTests = []stateUpdateTestCase{
GlobalFeatures: lnwire.NewRawFeatureVector(), GlobalFeatures: lnwire.NewRawFeatureVector(),
}}, }},
createMsg: &wtwire.CreateSession{ createMsg: &wtwire.CreateSession{
BlobVersion: 0, BlobType: blob.TypeDefault,
MaxUpdates: 4, MaxUpdates: 4,
RewardRate: 0, RewardRate: 0,
SweepFeeRate: 1, SweepFeeRate: 1,
@ -310,7 +311,7 @@ var stateUpdateTests = []stateUpdateTestCase{
GlobalFeatures: lnwire.NewRawFeatureVector(), GlobalFeatures: lnwire.NewRawFeatureVector(),
}}, }},
createMsg: &wtwire.CreateSession{ createMsg: &wtwire.CreateSession{
BlobVersion: 0, BlobType: blob.TypeDefault,
MaxUpdates: 4, MaxUpdates: 4,
RewardRate: 0, RewardRate: 0,
SweepFeeRate: 1, SweepFeeRate: 1,
@ -337,7 +338,7 @@ var stateUpdateTests = []stateUpdateTestCase{
GlobalFeatures: lnwire.NewRawFeatureVector(), GlobalFeatures: lnwire.NewRawFeatureVector(),
}}, }},
createMsg: &wtwire.CreateSession{ createMsg: &wtwire.CreateSession{
BlobVersion: 0, BlobType: blob.TypeDefault,
MaxUpdates: 4, MaxUpdates: 4,
RewardRate: 0, RewardRate: 0,
SweepFeeRate: 1, SweepFeeRate: 1,
@ -364,7 +365,7 @@ var stateUpdateTests = []stateUpdateTestCase{
GlobalFeatures: lnwire.NewRawFeatureVector(), GlobalFeatures: lnwire.NewRawFeatureVector(),
}}, }},
createMsg: &wtwire.CreateSession{ createMsg: &wtwire.CreateSession{
BlobVersion: 0, BlobType: blob.TypeDefault,
MaxUpdates: 4, MaxUpdates: 4,
RewardRate: 0, RewardRate: 0,
SweepFeeRate: 1, SweepFeeRate: 1,
@ -393,7 +394,7 @@ var stateUpdateTests = []stateUpdateTestCase{
GlobalFeatures: lnwire.NewRawFeatureVector(), GlobalFeatures: lnwire.NewRawFeatureVector(),
}}, }},
createMsg: &wtwire.CreateSession{ createMsg: &wtwire.CreateSession{
BlobVersion: 0, BlobType: blob.TypeDefault,
MaxUpdates: 4, MaxUpdates: 4,
RewardRate: 0, RewardRate: 0,
SweepFeeRate: 1, SweepFeeRate: 1,
@ -421,7 +422,7 @@ var stateUpdateTests = []stateUpdateTestCase{
GlobalFeatures: lnwire.NewRawFeatureVector(), GlobalFeatures: lnwire.NewRawFeatureVector(),
}}, }},
createMsg: &wtwire.CreateSession{ createMsg: &wtwire.CreateSession{
BlobVersion: 0, BlobType: blob.TypeDefault,
MaxUpdates: 3, MaxUpdates: 3,
RewardRate: 0, RewardRate: 0,
SweepFeeRate: 1, SweepFeeRate: 1,
@ -450,7 +451,7 @@ var stateUpdateTests = []stateUpdateTestCase{
GlobalFeatures: lnwire.NewRawFeatureVector(), GlobalFeatures: lnwire.NewRawFeatureVector(),
}}, }},
createMsg: &wtwire.CreateSession{ createMsg: &wtwire.CreateSession{
BlobVersion: 0, BlobType: blob.TypeDefault,
MaxUpdates: 3, MaxUpdates: 3,
RewardRate: 0, RewardRate: 0,
SweepFeeRate: 1, SweepFeeRate: 1,

@ -4,6 +4,7 @@ import (
"io" "io"
"github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/watchtower/blob"
) )
// CreateSession is sent from a client to tower when to negotiate a session, which // CreateSession is sent from a client to tower when to negotiate a session, which
@ -11,9 +12,9 @@ import (
// An update is consumed by uploading an encrypted blob that contains // An update is consumed by uploading an encrypted blob that contains
// information required to sweep a revoked commitment transaction. // information required to sweep a revoked commitment transaction.
type CreateSession struct { type CreateSession struct {
// BlobVersion specifies the blob format that must be used by all // BlobType specifies the blob format that must be used by all updates sent
// updates sent under the session key used to negotiate this session. // under the session key used to negotiate this session.
BlobVersion uint16 BlobType blob.Type
// MaxUpdates is the maximum number of updates the watchtower will honor // MaxUpdates is the maximum number of updates the watchtower will honor
// for this session. // for this session.
@ -41,7 +42,7 @@ var _ Message = (*CreateSession)(nil)
// This is part of the wtwire.Message interface. // This is part of the wtwire.Message interface.
func (m *CreateSession) Decode(r io.Reader, pver uint32) error { func (m *CreateSession) Decode(r io.Reader, pver uint32) error {
return ReadElements(r, return ReadElements(r,
&m.BlobVersion, &m.BlobType,
&m.MaxUpdates, &m.MaxUpdates,
&m.RewardRate, &m.RewardRate,
&m.SweepFeeRate, &m.SweepFeeRate,
@ -54,7 +55,7 @@ func (m *CreateSession) Decode(r io.Reader, pver uint32) error {
// This is part of the wtwire.Message interface. // This is part of the wtwire.Message interface.
func (m *CreateSession) Encode(w io.Writer, pver uint32) error { func (m *CreateSession) Encode(w io.Writer, pver uint32) error {
return WriteElements(w, return WriteElements(w,
m.BlobVersion, m.BlobType,
m.MaxUpdates, m.MaxUpdates,
m.RewardRate, m.RewardRate,
m.SweepFeeRate, m.SweepFeeRate,

@ -8,6 +8,7 @@ import (
"github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/btcec"
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/watchtower/blob"
) )
// WriteElement is a one-stop shop to write the big endian representation of // WriteElement is a one-stop shop to write the big endian representation of
@ -30,6 +31,13 @@ func WriteElement(w io.Writer, element interface{}) error {
return err return err
} }
case blob.Type:
var b [2]byte
binary.BigEndian.PutUint16(b[:], uint16(e))
if _, err := w.Write(b[:]); err != nil {
return err
}
case uint32: case uint32:
var b [4]byte var b [4]byte
binary.BigEndian.PutUint32(b[:], e) binary.BigEndian.PutUint32(b[:], e)
@ -127,6 +135,13 @@ func ReadElement(r io.Reader, element interface{}) error {
} }
*e = binary.BigEndian.Uint16(b[:]) *e = binary.BigEndian.Uint16(b[:])
case *blob.Type:
var b [2]byte
if _, err := io.ReadFull(r, b[:]); err != nil {
return err
}
*e = blob.Type(binary.BigEndian.Uint16(b[:]))
case *uint32: case *uint32:
var b [4]byte var b [4]byte
if _, err := io.ReadFull(r, b[:]); err != nil { if _, err := io.ReadFull(r, b[:]); err != nil {