Merge pull request #2462 from Roasbeef/fix-length-update
lnwire: ensure we precisely encode the length for onion errors w/ cha…
This commit is contained in:
commit
55b580f2b8
@ -665,12 +665,7 @@ func (f *FailAmountBelowMinimum) Encode(w io.Writer, pver uint32) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err := WriteElement(w, uint16(f.Update.MaxPayloadLength(pver)))
|
return writeOnionErrorChanUpdate(w, &f.Update, pver)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return f.Update.Encode(w, pver)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// FailFeeInsufficient is returned if the HTLC does not pay sufficient fee, we
|
// FailFeeInsufficient is returned if the HTLC does not pay sufficient fee, we
|
||||||
@ -738,12 +733,7 @@ func (f *FailFeeInsufficient) Encode(w io.Writer, pver uint32) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err := WriteElement(w, uint16(f.Update.MaxPayloadLength(pver)))
|
return writeOnionErrorChanUpdate(w, &f.Update, pver)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return f.Update.Encode(w, pver)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// FailIncorrectCltvExpiry is returned if outgoing cltv value does not match
|
// FailIncorrectCltvExpiry is returned if outgoing cltv value does not match
|
||||||
@ -811,12 +801,7 @@ func (f *FailIncorrectCltvExpiry) Encode(w io.Writer, pver uint32) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err := WriteElement(w, uint16(f.Update.MaxPayloadLength(pver)))
|
return writeOnionErrorChanUpdate(w, &f.Update, pver)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return f.Update.Encode(w, pver)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// FailExpiryTooSoon is returned if the ctlv-expiry is too near, we tell them
|
// FailExpiryTooSoon is returned if the ctlv-expiry is too near, we tell them
|
||||||
@ -869,12 +854,7 @@ func (f *FailExpiryTooSoon) Decode(r io.Reader, pver uint32) error {
|
|||||||
//
|
//
|
||||||
// NOTE: Part of the Serializable interface.
|
// NOTE: Part of the Serializable interface.
|
||||||
func (f *FailExpiryTooSoon) Encode(w io.Writer, pver uint32) error {
|
func (f *FailExpiryTooSoon) Encode(w io.Writer, pver uint32) error {
|
||||||
err := WriteElement(w, uint16(f.Update.MaxPayloadLength(pver)))
|
return writeOnionErrorChanUpdate(w, &f.Update, pver)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return f.Update.Encode(w, pver)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// FailChannelDisabled is returned if the channel is disabled, we tell them the
|
// FailChannelDisabled is returned if the channel is disabled, we tell them the
|
||||||
@ -942,12 +922,7 @@ func (f *FailChannelDisabled) Encode(w io.Writer, pver uint32) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err := WriteElement(w, uint16(f.Update.MaxPayloadLength(pver)))
|
return writeOnionErrorChanUpdate(w, &f.Update, pver)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return f.Update.Encode(w, pver)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// FailFinalIncorrectCltvExpiry is returned if the outgoing_cltv_value does not
|
// FailFinalIncorrectCltvExpiry is returned if the outgoing_cltv_value does not
|
||||||
@ -1231,3 +1206,32 @@ func makeEmptyOnionError(code FailCode) (FailureMessage, error) {
|
|||||||
return nil, errors.Errorf("unknown error code: %v", code)
|
return nil, errors.Errorf("unknown error code: %v", code)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// writeOnionErrorChanUpdate writes out a ChannelUpdate using the onion error
|
||||||
|
// format. The format is that we first write out the true serialized length of
|
||||||
|
// the channel update, followed by the serialized channel update itself.
|
||||||
|
func writeOnionErrorChanUpdate(w io.Writer, chanUpdate *ChannelUpdate,
|
||||||
|
pver uint32) error {
|
||||||
|
|
||||||
|
// First, we encode the channel update in a temporary buffer in order
|
||||||
|
// to get the exact serialized size.
|
||||||
|
var b bytes.Buffer
|
||||||
|
if err := chanUpdate.Encode(&b, pver); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now that we know the size, we can write the length out in the main
|
||||||
|
// writer.
|
||||||
|
updateLen := b.Len()
|
||||||
|
if err := WriteElement(w, uint16(updateLen)); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// With the length written, we'll then write out the serialized channel
|
||||||
|
// update.
|
||||||
|
if _, err := w.Write(b.Bytes()); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@ -131,3 +131,38 @@ func TestChannelUpdateCompatabilityParsing(t *testing.T) {
|
|||||||
t.Fatalf("mismatched channel updates: %v", err)
|
t.Fatalf("mismatched channel updates: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestWriteOnionErrorChanUpdate tests that we write an exact size for the
|
||||||
|
// channel update in order to be more compliant with the parsers of other
|
||||||
|
// implementations.
|
||||||
|
func TestWriteOnionErrorChanUpdate(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// First, we'll write out the raw channel update so we can obtain the
|
||||||
|
// raw serialized length.
|
||||||
|
var b bytes.Buffer
|
||||||
|
update := testChannelUpdate
|
||||||
|
if err := update.Encode(&b, 0); err != nil {
|
||||||
|
t.Fatalf("unable to write update: %v", err)
|
||||||
|
}
|
||||||
|
trueUpdateLength := b.Len()
|
||||||
|
|
||||||
|
// Next, we'll use the function to encode the update as we would in a
|
||||||
|
// onion error message.
|
||||||
|
var errorBuf bytes.Buffer
|
||||||
|
err := writeOnionErrorChanUpdate(&errorBuf, &update, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unable to encode onion error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finally, read the length encoded and ensure that it matches the raw
|
||||||
|
// length.
|
||||||
|
var encodedLen uint16
|
||||||
|
if err := ReadElement(&errorBuf, &encodedLen); err != nil {
|
||||||
|
t.Fatalf("unable to read len: %v", err)
|
||||||
|
}
|
||||||
|
if uint16(trueUpdateLength) != encodedLen {
|
||||||
|
t.Fatalf("wrong length written: expected %v, got %v",
|
||||||
|
trueUpdateLength, encodedLen)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user