diff --git a/tlv/truncated.go b/tlv/truncated.go index 8ed9cb0d..930a2cce 100644 --- a/tlv/truncated.go +++ b/tlv/truncated.go @@ -40,6 +40,15 @@ func ETUint16(w io.Writer, val interface{}, buf *[8]byte) error { return NewTypeForEncodingErr(val, "uint16") } +// ETUint16T is an Encoder for truncated uint16 values, where leading zeros will +// be omitted. An error is returned if val is not a *uint16. +func ETUint16T(w io.Writer, val uint16, buf *[8]byte) error { + binary.BigEndian.PutUint16(buf[:2], val) + numZeros := numLeadingZeroBytes16(val) + _, err := w.Write(buf[numZeros:2]) + return err +} + // DTUint16 is an Decoder for truncated uint16 values, where leading zeros will // be resurrected. An error is returned if val is not a *uint16. func DTUint16(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { @@ -92,6 +101,15 @@ func ETUint32(w io.Writer, val interface{}, buf *[8]byte) error { return NewTypeForEncodingErr(val, "uint32") } +// ETUint32T is an Encoder for truncated uint32 values, where leading zeros will +// be omitted. An error is returned if val is not a *uint32. +func ETUint32T(w io.Writer, val uint32, buf *[8]byte) error { + binary.BigEndian.PutUint32(buf[:4], val) + numZeros := numLeadingZeroBytes32(val) + _, err := w.Write(buf[numZeros:4]) + return err +} + // DTUint32 is an Decoder for truncated uint32 values, where leading zeros will // be resurrected. An error is returned if val is not a *uint32. func DTUint32(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { @@ -154,6 +172,15 @@ func ETUint64(w io.Writer, val interface{}, buf *[8]byte) error { return NewTypeForEncodingErr(val, "uint64") } +// ETUint64T is an Encoder for truncated uint64 values, where leading zeros will +// be omitted. An error is returned if val is not a *uint64. +func ETUint64T(w io.Writer, val uint64, buf *[8]byte) error { + binary.BigEndian.PutUint64(buf[:], val) + numZeros := numLeadingZeroBytes64(val) + _, err := w.Write(buf[numZeros:]) + return err +} + // DTUint64 is an Decoder for truncated uint64 values, where leading zeros will // be resurrected. An error is returned if val is not a *uint64. func DTUint64(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { diff --git a/tlv/truncated_test.go b/tlv/truncated_test.go index d2a34562..eb0f83a7 100644 --- a/tlv/truncated_test.go +++ b/tlv/truncated_test.go @@ -60,6 +60,8 @@ func TestSizeTUint16(t *testing.T) { func TestTUint16(t *testing.T) { var buf [8]byte for _, test := range tuint16Tests { + test := test + if len(test.bytes) != int(test.size) { t.Fatalf("invalid test case, "+ "len(bytes)[%d] != size[%d]", @@ -68,6 +70,7 @@ func TestTUint16(t *testing.T) { name := fmt.Sprintf("0x%x", test.value) t.Run(name, func(t *testing.T) { + // Test generic encoder. var b bytes.Buffer err := tlv.ETUint16(&b, &test.value, &buf) if err != nil { @@ -80,6 +83,19 @@ func TestTUint16(t *testing.T) { test.bytes, b.Bytes()) } + // Test non-generic encoder. + var b2 bytes.Buffer + err = tlv.ETUint16T(&b2, test.value, &buf) + if err != nil { + t.Fatalf("unable to encode tuint16: %v", err) + } + + if !bytes.Equal(b2.Bytes(), test.bytes) { + t.Fatalf("encoding mismatch, "+ + "expected: %x, got: %x", + test.bytes, b2.Bytes()) + } + var value uint16 r := bytes.NewReader(b.Bytes()) err = tlv.DTUint16(r, &value, &buf, test.size) @@ -168,6 +184,8 @@ func TestSizeTUint32(t *testing.T) { func TestTUint32(t *testing.T) { var buf [8]byte for _, test := range tuint32Tests { + test := test + if len(test.bytes) != int(test.size) { t.Fatalf("invalid test case, "+ "len(bytes)[%d] != size[%d]", @@ -176,6 +194,7 @@ func TestTUint32(t *testing.T) { name := fmt.Sprintf("0x%x", test.value) t.Run(name, func(t *testing.T) { + // Test generic encoder. var b bytes.Buffer err := tlv.ETUint32(&b, &test.value, &buf) if err != nil { @@ -188,6 +207,19 @@ func TestTUint32(t *testing.T) { test.bytes, b.Bytes()) } + // Test non-generic encoder. + var b2 bytes.Buffer + err = tlv.ETUint32T(&b2, test.value, &buf) + if err != nil { + t.Fatalf("unable to encode tuint32: %v", err) + } + + if !bytes.Equal(b2.Bytes(), test.bytes) { + t.Fatalf("encoding mismatch, "+ + "expected: %x, got: %x", + test.bytes, b2.Bytes()) + } + var value uint32 r := bytes.NewReader(b.Bytes()) err = tlv.DTUint32(r, &value, &buf, test.size) @@ -322,6 +354,8 @@ func TestSizeTUint64(t *testing.T) { func TestTUint64(t *testing.T) { var buf [8]byte for _, test := range tuint64Tests { + test := test + if len(test.bytes) != int(test.size) { t.Fatalf("invalid test case, "+ "len(bytes)[%d] != size[%d]", @@ -330,6 +364,7 @@ func TestTUint64(t *testing.T) { name := fmt.Sprintf("0x%x", test.value) t.Run(name, func(t *testing.T) { + // Test generic encoder. var b bytes.Buffer err := tlv.ETUint64(&b, &test.value, &buf) if err != nil { @@ -342,6 +377,19 @@ func TestTUint64(t *testing.T) { test.bytes, b.Bytes()) } + // Test non-generic encoder. + var b2 bytes.Buffer + err = tlv.ETUint64T(&b2, test.value, &buf) + if err != nil { + t.Fatalf("unable to encode tuint64: %v", err) + } + + if !bytes.Equal(b2.Bytes(), test.bytes) { + t.Fatalf("encoding mismatch, "+ + "expected: %x, got: %x", + test.bytes, b2.Bytes()) + } + var value uint64 r := bytes.NewReader(b.Bytes()) err = tlv.DTUint64(r, &value, &buf, test.size)