diff --git a/lnwire/onion_error.go b/lnwire/onion_error.go index 1763d0ba..ffc7e1af 100644 --- a/lnwire/onion_error.go +++ b/lnwire/onion_error.go @@ -665,12 +665,7 @@ func (f *FailAmountBelowMinimum) Encode(w io.Writer, pver uint32) error { return err } - err := WriteElement(w, uint16(f.Update.MaxPayloadLength(pver))) - if err != nil { - return err - } - - return f.Update.Encode(w, pver) + return writeOnionErrorChanUpdate(w, &f.Update, pver) } // FailFeeInsufficient is returned if the HTLC does not pay sufficient fee, we @@ -738,12 +733,7 @@ func (f *FailFeeInsufficient) Encode(w io.Writer, pver uint32) error { return err } - err := WriteElement(w, uint16(f.Update.MaxPayloadLength(pver))) - if err != nil { - return err - } - - return f.Update.Encode(w, pver) + return writeOnionErrorChanUpdate(w, &f.Update, pver) } // FailIncorrectCltvExpiry is returned if outgoing cltv value does not match @@ -811,12 +801,7 @@ func (f *FailIncorrectCltvExpiry) Encode(w io.Writer, pver uint32) error { return err } - err := WriteElement(w, uint16(f.Update.MaxPayloadLength(pver))) - if err != nil { - return err - } - - return f.Update.Encode(w, pver) + return writeOnionErrorChanUpdate(w, &f.Update, pver) } // FailExpiryTooSoon is returned if the ctlv-expiry is too near, we tell them @@ -869,12 +854,7 @@ func (f *FailExpiryTooSoon) Decode(r io.Reader, pver uint32) error { // // NOTE: Part of the Serializable interface. func (f *FailExpiryTooSoon) Encode(w io.Writer, pver uint32) error { - err := WriteElement(w, uint16(f.Update.MaxPayloadLength(pver))) - if err != nil { - return err - } - - return f.Update.Encode(w, pver) + return writeOnionErrorChanUpdate(w, &f.Update, pver) } // FailChannelDisabled is returned if the channel is disabled, we tell them the @@ -942,12 +922,7 @@ func (f *FailChannelDisabled) Encode(w io.Writer, pver uint32) error { return err } - err := WriteElement(w, uint16(f.Update.MaxPayloadLength(pver))) - if err != nil { - return err - } - - return f.Update.Encode(w, pver) + return writeOnionErrorChanUpdate(w, &f.Update, pver) } // FailFinalIncorrectCltvExpiry is returned if the outgoing_cltv_value does not @@ -1231,3 +1206,32 @@ func makeEmptyOnionError(code FailCode) (FailureMessage, error) { return nil, errors.Errorf("unknown error code: %v", code) } } + +// writeOnionErrorChanUpdate writes out a ChannelUpdate using the onion error +// format. The format is that we first write out the true serialized length of +// the channel update, followed by the serialized channel update itself. +func writeOnionErrorChanUpdate(w io.Writer, chanUpdate *ChannelUpdate, + pver uint32) error { + + // First, we encode the channel update in a temporary buffer in order + // to get the exact serialized size. + var b bytes.Buffer + if err := chanUpdate.Encode(&b, pver); err != nil { + return err + } + + // Now that we know the size, we can write the length out in the main + // writer. + updateLen := b.Len() + if err := WriteElement(w, uint16(updateLen)); err != nil { + return err + } + + // With the length written, we'll then write out the serialized channel + // update. + if _, err := w.Write(b.Bytes()); err != nil { + return err + } + + return nil +} diff --git a/lnwire/onion_error_test.go b/lnwire/onion_error_test.go index d59281a2..62b92766 100644 --- a/lnwire/onion_error_test.go +++ b/lnwire/onion_error_test.go @@ -131,3 +131,38 @@ func TestChannelUpdateCompatabilityParsing(t *testing.T) { t.Fatalf("mismatched channel updates: %v", err) } } + +// TestWriteOnionErrorChanUpdate tests that we write an exact size for the +// channel update in order to be more compliant with the parsers of other +// implementations. +func TestWriteOnionErrorChanUpdate(t *testing.T) { + t.Parallel() + + // First, we'll write out the raw channel update so we can obtain the + // raw serialized length. + var b bytes.Buffer + update := testChannelUpdate + if err := update.Encode(&b, 0); err != nil { + t.Fatalf("unable to write update: %v", err) + } + trueUpdateLength := b.Len() + + // Next, we'll use the function to encode the update as we would in a + // onion error message. + var errorBuf bytes.Buffer + err := writeOnionErrorChanUpdate(&errorBuf, &update, 0) + if err != nil { + t.Fatalf("unable to encode onion error: %v", err) + } + + // Finally, read the length encoded and ensure that it matches the raw + // length. + var encodedLen uint16 + if err := ReadElement(&errorBuf, &encodedLen); err != nil { + t.Fatalf("unable to read len: %v", err) + } + if uint16(trueUpdateLength) != encodedLen { + t.Fatalf("wrong length written: expected %v, got %v", + trueUpdateLength, encodedLen) + } +}