diff --git a/zpay32/invoice_internal_test.go b/zpay32/invoice_internal_test.go index 5d0c2cf1..8089c69d 100644 --- a/zpay32/invoice_internal_test.go +++ b/zpay32/invoice_internal_test.go @@ -777,3 +777,75 @@ func TestParseRouteHint(t *testing.T) { } } } + +// TestParseTaggedFields checks that tagged field data is correctly parsed or +// errors as expected. +func TestParseTaggedFields(t *testing.T) { + t.Parallel() + + netParams := &chaincfg.SimNetParams + + tests := []struct { + name string + data []byte + wantErr error + }{ + { + name: "nil data", + data: nil, + }, + { + name: "empty data", + data: []byte{}, + }, + { + // Type 0xff cannot be encoded in a single 5-bit + // element, so it's technically invalid but + // parseTaggedFields doesn't error on non-5bpp + // compatible codes so we can use a code in tests which + // will never become known in the future. + name: "valid unknown field", + data: []byte{0xff, 0x00, 0x00}, + }, + { + name: "unknown field valid data", + data: []byte{0xff, 0x00, 0x01, 0xab}, + }, + { + name: "only type specified", + data: []byte{0x0d}, + wantErr: ErrBrokenTaggedField, + }, + { + name: "not enough bytes for len", + data: []byte{0x0d, 0x00}, + wantErr: ErrBrokenTaggedField, + }, + { + name: "no bytes after len", + data: []byte{0x0d, 0x00, 0x01}, + wantErr: ErrInvalidFieldLength, + }, + { + name: "not enough bytes after len", + data: []byte{0x0d, 0x00, 0x02, 0x01}, + wantErr: ErrInvalidFieldLength, + }, + { + name: "not enough bytes after len with unknown type", + data: []byte{0xff, 0x00, 0x02, 0x01}, + wantErr: ErrInvalidFieldLength, + }, + } + for _, tc := range tests { + tc := tc // pin + t.Run(tc.name, func(t *testing.T) { + var invoice Invoice + gotErr := parseTaggedFields(&invoice, tc.data, netParams) + if tc.wantErr != gotErr { + t.Fatalf("Unexpected error. want=%v got=%v", + tc.wantErr, gotErr) + } + }) + } +} diff --git a/zpay32/invoice_test.go b/zpay32/invoice_test.go index 565c3fcf..eec242f7 100644 --- a/zpay32/invoice_test.go +++ b/zpay32/invoice_test.go @@ -8,6 +8,7 @@ import ( "encoding/hex" "fmt" "reflect" + "strings" "testing" "time" @@ -843,6 +844,56 @@ func TestMaxInvoiceLength(t *testing.T) { } } +// TestInvoiceChecksumMalleability ensures that the malleability of the +// checksum in bech32 strings cannot cause a signature to become valid and +// therefore cause a wrong destination to be decoded for invoices where the +// destination is extracted from the signature. +func TestInvoiceChecksumMalleability(t *testing.T) { + privKeyHex := "a50f3bdf9b6c4b1fdd7c51a8bbf4b5855cf381f413545ed155c0282f4412a1b1" + privKeyBytes, _ := hex.DecodeString(privKeyHex) + chain := &chaincfg.SimNetParams + var payHash [32]byte + ts := time.Unix(0, 0) + + privKey, _ := btcec.PrivKeyFromBytes(btcec.S256(), privKeyBytes) + msgSigner := MessageSigner{ + SignCompact: func(hash []byte) ([]byte, error) { + return btcec.SignCompact(btcec.S256(), privKey, hash, true) + }, + } + opts := []func(*Invoice){Description("test")} + invoice, err := NewInvoice(chain, payHash, ts, opts...) + if err != nil { + t.Fatal(err) + } + + encoded, err := invoice.Encode(msgSigner) + if err != nil { + t.Fatal(err) + } + + // Changing a bech32 string which checksum ends in "p" to "(q*)p" can + // cause the checksum to return as a valid bech32 string _but_ the + // signature field immediately preceding it would be mutaded. In rare + // cases (about 3%) it is still seen as a valid signature and public + // key recovery causes a different node than the originally intended + // one to be derived. + // + // We thus modify the checksum here and verify the invoice gets broken + // enough that it fails to decode. + if !strings.HasSuffix(encoded, "p") { + t.Logf("Invoice: %s", encoded) + t.Fatalf("Generated invoice checksum does not end in 'p'") + } + encoded = encoded[:len(encoded)-1] + "qp" + + _, err = Decode(encoded, chain) + if err == nil { + t.Fatalf("Did not get expected error when decoding invoice") + } + +} + func compareInvoices(expected, actual *Invoice) error { if !reflect.DeepEqual(expected.Net, actual.Net) { return fmt.Errorf("expected net %v, got %v",