diff --git a/lnwire/features.go b/lnwire/features.go new file mode 100644 index 00000000..e53d50ab --- /dev/null +++ b/lnwire/features.go @@ -0,0 +1,283 @@ +package lnwire + +import ( + "encoding/binary" + "fmt" + "github.com/go-errors/errors" + "io" + "math" +) + +// featureFlag represent the status of the feature optional/required and needed +// to allow future incompatible changes, or backward compatible changes. +type featureFlag uint8 + +func (f featureFlag) String() string { + switch f { + case OptionalFlag: + return "optional" + case RequiredFlag: + return "required" + default: + return "" + } +} + +// featureName represent the name of the feature and needed in order to have the +// compile errors if we specify wrong feature name. +type featureName string + +// FeaturesMap is the map which stores the correspondence between feature name +// and its index within feature vector. +// +// NOTE: Index within feature vector and actual binary position of feature +// are different things) +type FeaturesMap map[featureName]int + +// indexToFlag is the map which stores the correspondence between feature +// position and its flag. +type indexToFlag map[int]featureFlag + +const ( + // OptionalFlag represent the feature which we already have but it isn't + // required yet, and if remote peer doesn't have this feature we may + // turn it off without disconnecting with peer. + OptionalFlag featureFlag = 2 // 0b10 + + // RequiredFlag represent the features which is required for proper + // peer interaction, we disconnect with peer if it doesn't have this + // particular feature. + RequiredFlag featureFlag = 1 // 0b01 + + // flagMask is a mask which is needed to extract feature flag value. + flagMask = 3 // 0b11 + + // flagBitsSize represent the size of the feature flag in bits. For + // more information read the init message specification. + flagBitsSize = 2 + + // maxAllowedSize is a maximum allowed size of feature vector. + // NOTE: Within the protocol, the maximum allowed message size is 65535 + // bytes. Adding the overhead from the crypto protocol (the 2-byte packet + // length and 16-byte MAC), we arrive at 65569 bytes. Accounting for the + // overhead within the feature message to signal the type of the message, + // that leaves 65567 bytes for the init message itself. Next, we reserve + // 4-bytes to encode the lengths of both the local and global feature + // vectors, so 65563 for the global and local features. Knocking off one + // byte for the sake of the calculation, that leads to a max allowed + // size of 32781 bytes for each feature vector, or 131124 different + // features. + maxAllowedSize = 32781 +) + +// FeatureVector represents the global/local feature vector. With this structure +// you may set/get the feature by name and compare feature vector with remote +// one. +type FeatureVector struct { + features FeaturesMap // name -> index + flags indexToFlag // index -> flag +} + +// NewFeatureVector creates new instance of feature vector. +func NewFeatureVector(features FeaturesMap) *FeatureVector { + return &FeatureVector{ + features: features, + flags: make(indexToFlag), + } +} + +// SetFeatureFlag assign flag to the feature. +func (f *FeatureVector) SetFeatureFlag(name featureName, flag featureFlag) error { + position, ok := f.features[name] + if !ok { + return errors.Errorf("can't find feature with name: %v", name) + } + + f.flags[position] = flag + return nil +} + +// SerializedSize returns the number of bytes which is needed to represent feature +// vector in byte format. +func (f *FeatureVector) SerializedSize() uint16 { + return uint16(math.Ceil(float64(flagBitsSize*len(f.flags)) / 8)) +} + +// String returns the feature vector description. +func (f *FeatureVector) String() string { + var description string + for name, index := range f.features { + if flag, ok := f.flags[index]; ok { + description += fmt.Sprintf("%s: %s\n", name, flag) + } + } + + if description == "" { + description = "" + } + + return "\n" + description +} + +// NewFeatureVectorFromReader decodes the feature vector from binary +// representation and creates the instance of it. +// Every feature decoded as 2 bits where odd bit determine whether the feature +// is "optional" and even bit told us whether the feature is "required". The +// even/odd semantic allows future incompatible changes, or backward compatible +// changes. Bits generally assigned in pairs, so that optional features can +// later become compulsory. +func NewFeatureVectorFromReader(r io.Reader) (*FeatureVector, error) { + f := &FeatureVector{ + flags: make(indexToFlag), + } + + getFlag := func(data []byte, position int) featureFlag { + byteNumber := uint(position / 8) + bitNumber := uint(position % 8) + + return featureFlag((data[byteNumber] >> bitNumber) & flagMask) + } + + // Read the length of the feature vector. + var l [2]byte + if _, err := r.Read(l[:]); err != nil { + return nil, err + } + length := binary.BigEndian.Uint16(l[:]) + + // Read the feature vector data. + data := make([]byte, length) + if _, err := r.Read(data); err != nil { + return nil, err + } + + // Initialize feature vector. + bitsNumber := len(data) * 8 + for position := 0; position <= bitsNumber-flagBitsSize; position += flagBitsSize { + flag := getFlag(data, position) + switch flag { + case OptionalFlag, RequiredFlag: + // Every feature/flag takes 2 bits, so in order to get + // the feature/flag index we should divide position + // on 2. + index := position / flagBitsSize + f.flags[index] = flag + default: + continue + } + } + + return f, nil +} + +// Encode encodes the features vector into bytes representation, every +// feature encoded as 2 bits where odd bit determine whether the feature is +// "optional" and even bit told us whether the feature is "required". The +// even/odd semantic allows future incompatible changes, or backward compatible +// changes. Bits generally assigned in pairs, so that optional features can +// later become compulsory. +func (f *FeatureVector) Encode(w io.Writer) error { + setFlag := func(data []byte, position int, flag featureFlag) { + byteNumber := uint(position / 8) + bitNumber := uint(position % 8) + + data[byteNumber] |= (byte(flag) << bitNumber) + } + + // Write length of feature vector. + var l [2]byte + length := f.SerializedSize() + binary.BigEndian.PutUint16(l[:], length) + if _, err := w.Write(l[:]); err != nil { + return err + } + + // Generate the data and write it. + data := make([]byte, length) + for index, flag := range f.flags { + // Every feature takes 2 bits, so in order to get the + // feature bits position we should multiply index by 2. + position := index * flagBitsSize + setFlag(data, position, flag) + } + + if _, err := w.Write(data); err != nil { + return err + } + + return nil +} + +// Compare checks that features are compatible and returns the features which +// were present in both remote and local feature vectors. If remote/local node +// doesn't have the feature and local/remote node require it than such vectors +// are incompatible. +func (local *FeatureVector) Compare(remote *FeatureVector) (*SharedFeatures, + error) { + shared := NewSharedFeatures(local.features) + + for index, flag := range local.flags { + if _, exist := remote.flags[index]; !exist { + switch flag { + case RequiredFlag: + return nil, errors.New("Remote node hasn't " + + "locally required feature") + case OptionalFlag: + // If feature is optional and remote side + // haven't it than it might be safely disabled. + continue + } + } + + // If feature exists on both sides than such feature might be + // considered as active. + shared.flags[index] = flag + } + + for index, flag := range remote.flags { + if _, exist := local.flags[index]; !exist { + switch flag { + case RequiredFlag: + return nil, errors.New("Local node hasn't " + + "locally required feature") + case OptionalFlag: + // If feature is optional and local side + // haven't it than it might be safely disabled. + continue + } + } + + // If feature exists on both sides than such feature might be + // considered as active. + shared.flags[index] = flag + } + + return shared, nil +} + +// SharedFeatures is a product of comparison of two features vector +// which consist of features which are present in both local and remote +// features vectors. +type SharedFeatures struct { + *FeatureVector +} + +// NewSharedFeatures creates new shared features instance. +func NewSharedFeatures(features FeaturesMap) *SharedFeatures { + return &SharedFeatures{NewFeatureVector(features)} +} + +// IsActive checks is feature active or not, it might be disabled during +// comparision with remote feature vector if it was optional and +// remote peer doesn't support it. +func (f *SharedFeatures) IsActive(name featureName) bool { + position, ok := f.features[name] + if !ok { + // If we even have no such feature in feature map, than it + // can't be active in any circumstances. + return false + } + + _, exist := f.flags[position] + return exist +} diff --git a/lnwire/features_test.go b/lnwire/features_test.go new file mode 100644 index 00000000..86a9eb83 --- /dev/null +++ b/lnwire/features_test.go @@ -0,0 +1,119 @@ +package lnwire + +import ( + "bytes" + "reflect" + "testing" +) + +// TestFeaturesRemoteRequireError checks that we throw an error if remote peer +// has required feature which we don't support. +func TestFeaturesRemoteRequireError(t *testing.T) { + var ( + first featureName = "first" + second featureName = "second" + ) + + var localFeaturesMap = FeaturesMap{ + first: 0, + } + + var remoteFeaturesMap = FeaturesMap{ + first: 0, + second: 1, + } + + localFeatures := NewFeatureVector(localFeaturesMap) + localFeatures.SetFeatureFlag(first, OptionalFlag) + + remoteFeatures := NewFeatureVector(remoteFeaturesMap) + remoteFeatures.SetFeatureFlag(first, RequiredFlag) + remoteFeatures.SetFeatureFlag(second, RequiredFlag) + + if _, err := localFeatures.Compare(remoteFeatures); err == nil { + t.Fatal("error wasn't received") + } +} + +// TestFeaturesLocalRequireError checks that we throw an error if local peer has +// required feature which remote peer don't support. +func TestFeaturesLocalRequireError(t *testing.T) { + var ( + first featureName = "first" + second featureName = "second" + ) + + var localFeaturesMap = FeaturesMap{ + first: 0, + second: 1, + } + + var remoteFeaturesMap = FeaturesMap{ + first: 0, + } + + localFeatures := NewFeatureVector(localFeaturesMap) + localFeatures.SetFeatureFlag(first, OptionalFlag) + localFeatures.SetFeatureFlag(second, RequiredFlag) + + remoteFeatures := NewFeatureVector(remoteFeaturesMap) + remoteFeatures.SetFeatureFlag(first, RequiredFlag) + + if _, err := localFeatures.Compare(remoteFeatures); err == nil { + t.Fatal("error wasn't received") + } +} + +// TestOptionalFeature checks that if remote peer don't have the feature but +// on our side this feature is optional than we mark this feature as disabled. +func TestOptionalFeature(t *testing.T) { + var first featureName = "first" + + var localFeaturesMap = FeaturesMap{ + first: 0, + } + + localFeatures := NewFeatureVector(localFeaturesMap) + localFeatures.SetFeatureFlag(first, OptionalFlag) + + remoteFeatures := NewFeatureVector(FeaturesMap{}) + + shared, err := localFeatures.Compare(remoteFeatures) + if err != nil { + t.Fatalf("error while feature vector compare: %v", err) + } + + if shared.IsActive(first) { + t.Fatal("locally feature was set but remote peer notified us" + + " that it don't have it") + } +} + +// TestDecodeEncodeFeaturesVector checks that feature vector might be +// successfully encoded and decoded. +func TestDecodeEncodeFeaturesVector(t *testing.T) { + var first featureName = "first" + + var localFeaturesMap = FeaturesMap{ + first: 0, + } + + f := NewFeatureVector(localFeaturesMap) + f.SetFeatureFlag(first, OptionalFlag) + + var b bytes.Buffer + if err := f.Encode(&b); err != nil { + t.Fatalf("error while encoding feature vector: %v", err) + } + + nf, err := NewFeatureVectorFromReader(&b) + if err != nil { + t.Fatalf("error while decoding feature vector: %v", err) + } + + // Assert equality of the two instances. + if !reflect.DeepEqual(f.flags, nf.flags) { + t.Fatalf("encode/decode feature vector don't match %v vs "+ + "%v", f.String(), nf.String()) + } +} diff --git a/lnwire/init_message.go b/lnwire/init_message.go new file mode 100644 index 00000000..1801229c --- /dev/null +++ b/lnwire/init_message.go @@ -0,0 +1,100 @@ +package lnwire + +import ( + "io" + "github.com/go-errors/errors" +) + +// Init is the first message reveals the features supported or required by this +// node. Nodes wait for receipt of the other's features to simplify error +// diagnosis where features are incompatible. Each node MUST wait to receive +// init before sending any other messages. +type Init struct { + // GlobalFeatures is feature vector which affects HTLCs and thus are + // also advertised to other nodes. + GlobalFeatures *FeatureVector + + // LocalFeatures is feature vector which only affect the protocol + // between two nodes. + LocalFeatures *FeatureVector +} + +// NewInitMessage creates new instance of init message object. +func NewInitMessage(gf, lf *FeatureVector) *Init { + return &Init{ + GlobalFeatures: gf, + LocalFeatures: lf, + } +} + +// Decode deserializes a serialized Init message stored in the passed +// io.Reader observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (msg *Init) Decode(r io.Reader, pver uint32) error { + // LocalFeatures(~) + // GlobalFeatures(~) + err := readElements(r, + &msg.LocalFeatures, + &msg.GlobalFeatures, + ) + if err != nil { + return err + } + + return nil +} + +// A compile time check to ensure Init implements the lnwire.Message +// interface. +var _ Message = (*Init)(nil) + +// Encode serializes the target Init into the passed io.Writer observing +// the protocol version specified. +// +// This is part of the lnwire.Message interface. +func (msg *Init) Encode(w io.Writer, pver uint32) error { + err := writeElements(w, + msg.LocalFeatures, + msg.GlobalFeatures, + ) + if err != nil { + return err + } + + return nil +} + +// Command returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (msg *Init) Command() uint32 { + return CmdInit +} + +// MaxPayloadLength returns the maximum allowed payload size for a Init +// complete message observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (msg *Init) MaxPayloadLength(uint32) uint32 { + return 2 + maxAllowedSize + 2 + maxAllowedSize +} + +// Validate performs any necessary sanity checks to ensure all fields present +// on the Init are valid. +// +// This is part of the lnwire.Message interface. +func (msg *Init) Validate() error { + if msg.GlobalFeatures.SerializedSize() > maxAllowedSize { + return errors.Errorf("global feature vector exceed max allowed "+ + "size %v", maxAllowedSize) + } + + if msg.LocalFeatures.SerializedSize() > maxAllowedSize { + return errors.Errorf("local feature vector exceed max allowed "+ + "size %v", maxAllowedSize) + } + + return nil +} diff --git a/lnwire/init_test.go b/lnwire/init_test.go new file mode 100644 index 00000000..cf2b519f --- /dev/null +++ b/lnwire/init_test.go @@ -0,0 +1,48 @@ +package lnwire + +import ( + "bytes" + "reflect" + "testing" +) + +func TestInitEncodeDecode(t *testing.T) { + fm := FeaturesMap{ + "somefeature": 0, + } + + gf := NewFeatureVector(fm) + gf.SetFeatureFlag("somefeature", OptionalFlag) + + lf := NewFeatureVector(fm) + lf.SetFeatureFlag("somefeature", OptionalFlag) + + init1 := &Init{ + GlobalFeatures: gf, + LocalFeatures: lf, + } + + // Next encode the init message into an empty bytes buffer. + var b bytes.Buffer + if err := init1.Encode(&b, 0); err != nil { + t.Fatalf("unable to encode init: %v", err) + } + + // Deserialize the encoded init message into a new empty struct. + init2 := &Init{} + if err := init2.Decode(&b, 0); err != nil { + t.Fatalf("unable to decode init: %v", err) + } + + // We not encode the feature map in feature vector, for that reason the + // init messages will differ. Initialize decoded feature map in + // order to use deep equal function. + init2.GlobalFeatures.features = fm + init2.LocalFeatures.features = fm + + // Assert equality of the two instances. + if !reflect.DeepEqual(init1, init2) { + t.Fatalf("encode/decode init messages don't match %#v vs %#v", + init1, init2) + } +} diff --git a/lnwire/lnwire.go b/lnwire/lnwire.go index 575e8d99..b0acf17f 100644 --- a/lnwire/lnwire.go +++ b/lnwire/lnwire.go @@ -282,6 +282,10 @@ func writeElement(w io.Writer, element interface{}) error { if _, err := w.Write(idx[:]); err != nil { return err } + case *FeatureVector: + if err := e.Encode(w); err != nil { + return err + } case *wire.OutPoint: // TODO(roasbeef): consolidate with above // First write out the previous txid. @@ -455,6 +459,14 @@ func readElement(r io.Reader, element interface{}) error { return err } *e = pubKey + case **FeatureVector: + f, err := NewFeatureVectorFromReader(r) + if err != nil { + return err + } + + *e = f + case *[]uint64: var numItems uint16 if err := readElement(r, &numItems); err != nil { diff --git a/lnwire/message.go b/lnwire/message.go index 79c71334..4251c75d 100644 --- a/lnwire/message.go +++ b/lnwire/message.go @@ -22,6 +22,8 @@ const MaxMessagePayload = 1024 * 1024 * 32 // 32MB // Commands used in lightning message headers which detail the type of message. const ( + CmdInit = uint32(1) + // Commands for opening a channel funded by one party (single funder). CmdSingleFundingRequest = uint32(100) CmdSingleFundingResponse = uint32(110) @@ -88,6 +90,8 @@ func makeEmptyMessage(command uint32) (Message, error) { var msg Message switch command { + case CmdInit: + msg = &Init{} case CmdSingleFundingRequest: msg = &SingleFundingRequest{} case CmdSingleFundingResponse: