diff --git a/lnwire/features.go b/lnwire/features.go index f41ae607..f26403ae 100644 --- a/lnwire/features.go +++ b/lnwire/features.go @@ -94,6 +94,11 @@ const ( maxAllowedSize = 32764 ) +// IsRequired returns true if the feature bit is even, and false otherwise. +func (b FeatureBit) IsRequired() bool { + return b&0x01 == 0x00 +} + // Features is a mapping of known feature bits to a descriptive name. All known // feature bits must be assigned a name in this mapping, and feature bit pairs // must be assigned together for correct behavior. @@ -383,6 +388,15 @@ func (fv *FeatureVector) isFeatureBitPair(bit FeatureBit) bool { return known1 && known2 && name1 == name2 } +// Features returns the set of raw features contained in the feature vector. +func (fv *FeatureVector) Features() map[FeatureBit]struct{} { + fs := make(map[FeatureBit]struct{}, len(fv.RawFeatureVector.features)) + for b := range fv.RawFeatureVector.features { + fs[b] = struct{}{} + } + return fs +} + // Clone copies a feature vector, carrying over its feature bits. The feature // names are not copied. func (fv *FeatureVector) Clone() *FeatureVector { diff --git a/lnwire/features_test.go b/lnwire/features_test.go index cff76ec1..1cff8f52 100644 --- a/lnwire/features_test.go +++ b/lnwire/features_test.go @@ -260,3 +260,68 @@ func TestFeatureNames(t *testing.T) { } } } + +// TestIsRequired asserts that feature bits properly return their IsRequired +// status. We require that even features be required and odd features be +// optional. +func TestIsRequired(t *testing.T) { + optional := FeatureBit(1) + if optional.IsRequired() { + t.Fatalf("optional feature should not be required") + } + + required := FeatureBit(0) + if !required.IsRequired() { + t.Fatalf("required feature should be required") + } +} + +// TestFeatures asserts that the Features() method on a FeatureVector properly +// returns the set of feature bits it stores internallly. +func TestFeatures(t *testing.T) { + tests := []struct { + name string + exp map[FeatureBit]struct{} + }{ + { + name: "empty", + exp: map[FeatureBit]struct{}{}, + }, + { + name: "one", + exp: map[FeatureBit]struct{}{ + 5: {}, + }, + }, + { + name: "several", + exp: map[FeatureBit]struct{}{ + 0: {}, + 5: {}, + 23948: {}, + }, + }, + } + + toRawFV := func(set map[FeatureBit]struct{}) *RawFeatureVector { + var bits []FeatureBit + for bit := range set { + bits = append(bits, bit) + } + return NewRawFeatureVector(bits...) + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + fv := NewFeatureVector( + toRawFV(test.exp), Features, + ) + + if !reflect.DeepEqual(fv.Features(), test.exp) { + t.Fatalf("feature mismatch, want: %v, got: %v", + test.exp, fv.Features()) + } + }) + } +} diff --git a/rpcserver.go b/rpcserver.go index cf0e8bd4..33d29454 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -4620,6 +4620,20 @@ func (r *rpcServer) DecodePayReq(ctx context.Context, paymentAddr = payReq.PaymentAddr[:] } + // Convert any features on the payment request into a descriptive format + // for the rpc. + invFeatures := payReq.Features.Features() + features := make([]*lnrpc.Feature, 0, len(invFeatures)) + for bit := range invFeatures { + name := payReq.Features.Name(bit) + features = append(features, &lnrpc.Feature{ + Bit: uint32(bit), + Name: name, + IsRequired: bit.IsRequired(), + IsKnown: name != "unknown", + }) + } + dest := payReq.Destination.SerializeCompressed() return &lnrpc.PayReq{ Destination: hex.EncodeToString(dest), @@ -4634,6 +4648,7 @@ func (r *rpcServer) DecodePayReq(ctx context.Context, CltvExpiry: int64(payReq.MinFinalCLTVExpiry()), RouteHints: routeHints, PaymentAddr: paymentAddr, + Features: features, }, nil }