Merge pull request #2462 from Roasbeef/fix-length-update

lnwire: ensure we precisely encode the length for onion errors w/ cha…
This commit is contained in:
Wilmer Paulino 2019-01-11 16:12:04 -08:00 committed by GitHub
commit 55b580f2b8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 69 additions and 30 deletions

@ -665,12 +665,7 @@ func (f *FailAmountBelowMinimum) Encode(w io.Writer, pver uint32) error {
return err return err
} }
err := WriteElement(w, uint16(f.Update.MaxPayloadLength(pver))) return writeOnionErrorChanUpdate(w, &f.Update, pver)
if err != nil {
return err
}
return f.Update.Encode(w, pver)
} }
// FailFeeInsufficient is returned if the HTLC does not pay sufficient fee, we // FailFeeInsufficient is returned if the HTLC does not pay sufficient fee, we
@ -738,12 +733,7 @@ func (f *FailFeeInsufficient) Encode(w io.Writer, pver uint32) error {
return err return err
} }
err := WriteElement(w, uint16(f.Update.MaxPayloadLength(pver))) return writeOnionErrorChanUpdate(w, &f.Update, pver)
if err != nil {
return err
}
return f.Update.Encode(w, pver)
} }
// FailIncorrectCltvExpiry is returned if outgoing cltv value does not match // FailIncorrectCltvExpiry is returned if outgoing cltv value does not match
@ -811,12 +801,7 @@ func (f *FailIncorrectCltvExpiry) Encode(w io.Writer, pver uint32) error {
return err return err
} }
err := WriteElement(w, uint16(f.Update.MaxPayloadLength(pver))) return writeOnionErrorChanUpdate(w, &f.Update, pver)
if err != nil {
return err
}
return f.Update.Encode(w, pver)
} }
// FailExpiryTooSoon is returned if the ctlv-expiry is too near, we tell them // FailExpiryTooSoon is returned if the ctlv-expiry is too near, we tell them
@ -869,12 +854,7 @@ func (f *FailExpiryTooSoon) Decode(r io.Reader, pver uint32) error {
// //
// NOTE: Part of the Serializable interface. // NOTE: Part of the Serializable interface.
func (f *FailExpiryTooSoon) Encode(w io.Writer, pver uint32) error { func (f *FailExpiryTooSoon) Encode(w io.Writer, pver uint32) error {
err := WriteElement(w, uint16(f.Update.MaxPayloadLength(pver))) return writeOnionErrorChanUpdate(w, &f.Update, pver)
if err != nil {
return err
}
return f.Update.Encode(w, pver)
} }
// FailChannelDisabled is returned if the channel is disabled, we tell them the // FailChannelDisabled is returned if the channel is disabled, we tell them the
@ -942,12 +922,7 @@ func (f *FailChannelDisabled) Encode(w io.Writer, pver uint32) error {
return err return err
} }
err := WriteElement(w, uint16(f.Update.MaxPayloadLength(pver))) return writeOnionErrorChanUpdate(w, &f.Update, pver)
if err != nil {
return err
}
return f.Update.Encode(w, pver)
} }
// FailFinalIncorrectCltvExpiry is returned if the outgoing_cltv_value does not // FailFinalIncorrectCltvExpiry is returned if the outgoing_cltv_value does not
@ -1231,3 +1206,32 @@ func makeEmptyOnionError(code FailCode) (FailureMessage, error) {
return nil, errors.Errorf("unknown error code: %v", code) return nil, errors.Errorf("unknown error code: %v", code)
} }
} }
// writeOnionErrorChanUpdate writes out a ChannelUpdate using the onion error
// format. The format is that we first write out the true serialized length of
// the channel update, followed by the serialized channel update itself.
func writeOnionErrorChanUpdate(w io.Writer, chanUpdate *ChannelUpdate,
pver uint32) error {
// First, we encode the channel update in a temporary buffer in order
// to get the exact serialized size.
var b bytes.Buffer
if err := chanUpdate.Encode(&b, pver); err != nil {
return err
}
// Now that we know the size, we can write the length out in the main
// writer.
updateLen := b.Len()
if err := WriteElement(w, uint16(updateLen)); err != nil {
return err
}
// With the length written, we'll then write out the serialized channel
// update.
if _, err := w.Write(b.Bytes()); err != nil {
return err
}
return nil
}

@ -131,3 +131,38 @@ func TestChannelUpdateCompatabilityParsing(t *testing.T) {
t.Fatalf("mismatched channel updates: %v", err) t.Fatalf("mismatched channel updates: %v", err)
} }
} }
// TestWriteOnionErrorChanUpdate tests that we write an exact size for the
// channel update in order to be more compliant with the parsers of other
// implementations.
func TestWriteOnionErrorChanUpdate(t *testing.T) {
t.Parallel()
// First, we'll write out the raw channel update so we can obtain the
// raw serialized length.
var b bytes.Buffer
update := testChannelUpdate
if err := update.Encode(&b, 0); err != nil {
t.Fatalf("unable to write update: %v", err)
}
trueUpdateLength := b.Len()
// Next, we'll use the function to encode the update as we would in a
// onion error message.
var errorBuf bytes.Buffer
err := writeOnionErrorChanUpdate(&errorBuf, &update, 0)
if err != nil {
t.Fatalf("unable to encode onion error: %v", err)
}
// Finally, read the length encoded and ensure that it matches the raw
// length.
var encodedLen uint16
if err := ReadElement(&errorBuf, &encodedLen); err != nil {
t.Fatalf("unable to read len: %v", err)
}
if uint16(trueUpdateLength) != encodedLen {
t.Fatalf("wrong length written: expected %v, got %v",
trueUpdateLength, encodedLen)
}
}