diff --git a/lnwire/lnwire.go b/lnwire/lnwire.go index 427a9c2a..73480780 100644 --- a/lnwire/lnwire.go +++ b/lnwire/lnwire.go @@ -180,7 +180,10 @@ func writeElement(w io.Writer, element interface{}) error { if _, err := w.Write(e[:]); err != nil { return err } - + case FailCode: + if err := writeElement(w, uint16(e)); err != nil { + return err + } case ShortChannelID: // Check that field fit in 3 bytes and write the blockHeight if e.BlockHeight > ((1 << 24) - 1) { @@ -440,6 +443,10 @@ func readElement(r io.Reader, element interface{}) error { Hash: *hash, Index: uint32(index), } + case *FailCode: + if err := readElement(r, (*uint16)(e)); err != nil { + return err + } case *ChannelID: if _, err := io.ReadFull(r, e[:]); err != nil { diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index ebfefcfc..20f9e054 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -529,6 +529,13 @@ func TestLightningWireProtocol(t *testing.T) { return mainScenario(&m) }, }, + { + + msgType: MsgUpdateFailMalformedHTLC, + scenario: func(m UpdateFailMalformedHTLC) bool { + return mainScenario(&m) + }, + }, { msgType: MsgChannelAnnouncement, scenario: func(m ChannelAnnouncement) bool { diff --git a/lnwire/message.go b/lnwire/message.go index f12fa2f7..1cb2ce04 100644 --- a/lnwire/message.go +++ b/lnwire/message.go @@ -39,6 +39,7 @@ const ( MsgUpdateFailHTLC = 131 MsgCommitSig = 132 MsgRevokeAndAck = 133 + MsgUpdateFailMalformedHTLC = 135 MsgUpdateFee = 137 MsgChannelAnnouncement = 256 MsgNodeAnnouncement = 257 @@ -75,6 +76,8 @@ func (t MessageType) String() string { return "CommitSig" case MsgRevokeAndAck: return "RevokeAndAck" + case MsgUpdateFailMalformedHTLC: + return "UpdateFailMalformedHTLC" case MsgError: return "Error" case MsgChannelAnnouncement: @@ -161,6 +164,8 @@ func makeEmptyMessage(msgType MessageType) (Message, error) { msg = &RevokeAndAck{} case MsgUpdateFee: msg = &UpdateFee{} + case MsgUpdateFailMalformedHTLC: + msg = &UpdateFailMalformedHTLC{} case MsgError: msg = &Error{} case MsgChannelAnnouncement: diff --git a/lnwire/update_fail_malformed_htlc.go b/lnwire/update_fail_malformed_htlc.go new file mode 100644 index 00000000..ef5ff3e1 --- /dev/null +++ b/lnwire/update_fail_malformed_htlc.go @@ -0,0 +1,74 @@ +package lnwire + +import ( + "crypto/sha256" + "io" +) + +// UpdateFailMalformedHTLC is sent by either the payment forwarder or by payment +// receiver to the payment sender in order to notify it that the onion blob +// can't be parsed. For that reason we send this message instead of obfuscate +// the onion failure. +type UpdateFailMalformedHTLC struct { + // ChanID is the particular active channel that this + // UpdateFailMalformedHTLC is bound to. + ChanID ChannelID + + // ID references which HTLC on the remote node's commitment transaction + // has timed out. + ID uint64 + + // ShaOnionBlob hash of the onion blob on which can't be parsed by the + // node in the payment path. + ShaOnionBlob [sha256.Size]byte + + // FailureCode the exact reason why onion blob haven't been parsed. + FailureCode FailCode +} + +// A compile time check to ensure UpdateFailMalformedHTLC implements the lnwire.Message +// interface. +var _ Message = (*UpdateFailMalformedHTLC)(nil) + +// Decode deserializes a serialized UpdateFailMalformedHTLC message stored in the passed +// io.Reader observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (c *UpdateFailMalformedHTLC) Decode(r io.Reader, pver uint32) error { + return readElements(r, + &c.ChanID, + &c.ID, + c.ShaOnionBlob[:], + &c.FailureCode, + ) +} + +// Encode serializes the target UpdateFailMalformedHTLC into the passed io.Writer observing +// the protocol version specified. +// +// This is part of the lnwire.Message interface. +func (c *UpdateFailMalformedHTLC) Encode(w io.Writer, pver uint32) error { + return writeElements(w, + c.ChanID, + c.ID, + c.ShaOnionBlob[:], + c.FailureCode, + ) +} + +// MsgType returns the integer uniquely identifying this message type on the +// wire. +// +// This is part of the lnwire.Message interface. +func (c *UpdateFailMalformedHTLC) MsgType() MessageType { + return MsgUpdateFailMalformedHTLC +} + +// MaxPayloadLength returns the maximum allowed payload size for a UpdateFailMalformedHTLC +// complete message observing the specified protocol version. +// +// This is part of the lnwire.Message interface. +func (c *UpdateFailMalformedHTLC) MaxPayloadLength(uint32) uint32 { + // 32 + 8 + 32 + 2 + return 74 +}