diff --git a/lnwire/lnwire.go b/lnwire/lnwire.go index 2a4352e7..de0edaf2 100644 --- a/lnwire/lnwire.go +++ b/lnwire/lnwire.go @@ -76,6 +76,11 @@ func (a addressType) AddrLen() uint16 { // serialization. func WriteElement(w io.Writer, element interface{}) error { switch e := element.(type) { + case NodeAlias: + if _, err := w.Write(e[:]); err != nil { + return err + } + case ShortChanIDEncoding: var b [1]byte b[0] = uint8(e) @@ -429,6 +434,18 @@ func WriteElements(w io.Writer, elements ...interface{}) error { func ReadElement(r io.Reader, element interface{}) error { var err error switch e := element.(type) { + case *NodeAlias: + var a [32]byte + if _, err := io.ReadFull(r, a[:]); err != nil { + return err + } + + alias, err := NewNodeAlias(string(a[:])) + if err != nil { + return err + } + + *e = alias case *ShortChanIDEncoding: var b [1]uint8 if _, err := r.Read(b[:]); err != nil { diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index 3c910e47..013295e3 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -41,6 +41,17 @@ var ( _, _ = testSig.S.SetString("18801056069249825825291287104931333862866033135609736119018462340006816851118", 10) ) +const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + +func randAlias(r *rand.Rand) NodeAlias { + var a NodeAlias + for i := range a { + a[i] = letterBytes[r.Intn(len(letterBytes))] + } + + return a +} + func randPubKey() (*btcec.PublicKey, error) { priv, err := btcec.NewPrivateKey(btcec.S256()) if err != nil { @@ -551,17 +562,11 @@ func TestLightningWireProtocol(t *testing.T) { v[0] = reflect.ValueOf(req) }, MsgNodeAnnouncement: func(v []reflect.Value, r *rand.Rand) { - var a [32]byte - if _, err := r.Read(a[:]); err != nil { - t.Fatalf("unable to generate alias: %v", err) - return - } - var err error req := NodeAnnouncement{ Features: randRawFeatureVector(r), Timestamp: uint32(r.Int31()), - Alias: a, + Alias: randAlias(r), RGBColor: color.RGBA{ R: uint8(r.Int31()), G: uint8(r.Int31()), diff --git a/lnwire/node_announcement.go b/lnwire/node_announcement.go index 9272bcef..42c1d5e1 100644 --- a/lnwire/node_announcement.go +++ b/lnwire/node_announcement.go @@ -28,6 +28,17 @@ func (e ErrUnknownAddrType) Error() string { return fmt.Sprintf("unknown address type: %v", e.addrType) } +// ErrInvalidNodeAlias is an error returned if a node alias we parse on the +// wire is invalid, as in it has non UTF-8 characters. +type ErrInvalidNodeAlias struct{} + +// Error returns a human readable string describing the error. +// +// NOTE: implements the error interface. +func (e ErrInvalidNodeAlias) Error() string { + return "node alias has non-utf8 characters" +} + // NodeAlias a hex encoded UTF-8 string that may be displayed as an alternative // to the node's ID. Notice that aliases are not unique and may be freely // chosen by the node operators. @@ -39,11 +50,12 @@ func NewNodeAlias(s string) (NodeAlias, error) { var n NodeAlias if len(s) > 32 { - return n, fmt.Errorf("alias too large: max is %v, got %v", 32, len(s)) + return n, fmt.Errorf("alias too large: max is %v, got %v", 32, + len(s)) } if !utf8.ValidString(s) { - return n, fmt.Errorf("invalid utf8 string") + return n, &ErrInvalidNodeAlias{} } copy(n[:], []byte(s)) @@ -117,7 +129,7 @@ func (a *NodeAnnouncement) Decode(r io.Reader, pver uint32) error { &a.Timestamp, &a.NodeID, &a.RGBColor, - a.Alias[:], + &a.Alias, &a.Addresses, ) if err != nil { @@ -149,7 +161,7 @@ func (a *NodeAnnouncement) Encode(w io.Writer, pver uint32) error { a.Timestamp, a.NodeID, a.RGBColor, - a.Alias[:], + a.Alias, a.Addresses, a.ExtraOpaqueData, )