diff --git a/lnwire/features.go b/lnwire/features.go index 5584191b..250c041e 100644 --- a/lnwire/features.go +++ b/lnwire/features.go @@ -398,6 +398,20 @@ func (fv *FeatureVector) HasFeature(feature FeatureBit) bool { (fv.isFeatureBitPair(feature) && fv.IsSet(feature^1)) } +// RequiresFeature returns true if the referenced feature vector *requires* +// that the given required bit be set. This method can be used with both +// optional and required feature bits as a parameter. +func (fv *FeatureVector) RequiresFeature(feature FeatureBit) bool { + // If we weren't passed a required feature bit, then we'll flip the + // lowest bit to query for the required version of the feature. This + // lets callers pass in both the optional and required bits. + if !feature.IsRequired() { + feature ^= 1 + } + + return fv.IsSet(feature) +} + // UnknownRequiredFeatures returns a list of feature bits set in the vector // that are unknown and in an even bit position. Feature bits with an even // index must be known to a node receiving the feature vector in a message. diff --git a/lnwire/features_test.go b/lnwire/features_test.go index 1cff8f52..3eed2b1b 100644 --- a/lnwire/features_test.go +++ b/lnwire/features_test.go @@ -5,6 +5,8 @@ import ( "reflect" "sort" "testing" + + "github.com/stretchr/testify/require" ) var testFeatureNames = map[FeatureBit]string{ @@ -87,6 +89,7 @@ func TestFeatureVectorSetUnset(t *testing.T) { t.Errorf("Expectation failed in case %d, bit %d", i, j) break } + } for _, bit := range test.bits { @@ -95,6 +98,31 @@ func TestFeatureVectorSetUnset(t *testing.T) { } } +// TestFeatureVectorRequiresFeature tests that if a feature vector only +// includes a required feature bit (it's even), then the RequiresFeature method +// will return true for both that bit as well as it's optional counter party. +func TestFeatureVectorRequiresFeature(t *testing.T) { + t.Parallel() + + // Create a new feature vector with the features above, and set only + // the set of required bits. These will be all the even features + // referenced above. + fv := NewFeatureVector(nil, testFeatureNames) + fv.Set(0) + fv.Set(4) + + // Next we'll query for those exact bits, these should show up as being + // required. + require.True(t, fv.RequiresFeature(0)) + require.True(t, fv.RequiresFeature(4)) + + // If we query for the odd (optional) counter party to each of the + // features, the method should still return that the backing feature + // vector requires the feature to be set. + require.True(t, fv.RequiresFeature(1)) + require.True(t, fv.RequiresFeature(5)) +} + func TestFeatureVectorEncodeDecode(t *testing.T) { t.Parallel() @@ -277,7 +305,7 @@ func TestIsRequired(t *testing.T) { } // TestFeatures asserts that the Features() method on a FeatureVector properly -// returns the set of feature bits it stores internallly. +// returns the set of feature bits it stores internally. func TestFeatures(t *testing.T) { tests := []struct { name string