package tlv import ( "bytes" "errors" "io" "io/ioutil" "math" ) // MaxRecordSize is the maximum size of a particular record that will be parsed // by a stream decoder. This value is currently chosen to the be equal to the // maximum message size permitted by BOLT 1, as no record should be bigger than // an entire message. const MaxRecordSize = 65535 // 65KB // ErrStreamNotCanonical signals that a decoded stream does not contain records // sorting by monotonically-increasing type. var ErrStreamNotCanonical = errors.New("tlv stream is not canonical") // ErrRecordTooLarge signals that a decoded record has a length that is too // long to parse. var ErrRecordTooLarge = errors.New("record is too large") // Stream defines a TLV stream that can be used for encoding or decoding a set // of TLV Records. type Stream struct { records []Record buf [8]byte } // NewStream creates a new TLV Stream given an encoding codec, a decoding codec, // and a set of known records. func NewStream(records ...Record) (*Stream, error) { // Assert that the ordering of the Records is canonical and appear in // ascending order of type. var ( min Type overflow bool ) for _, record := range records { if overflow || record.typ < min { return nil, ErrStreamNotCanonical } if record.encoder == nil { record.encoder = ENOP } if record.decoder == nil { record.decoder = DNOP } if record.typ == math.MaxUint64 { overflow = true } min = record.typ + 1 } return &Stream{ records: records, }, nil } // MustNewStream creates a new TLV Stream given an encoding codec, a decoding // codec, and a set of known records. If an error is encountered in creating the // stream, this method will panic instead of returning the error. func MustNewStream(records ...Record) *Stream { stream, err := NewStream(records...) if err != nil { panic(err.Error()) } return stream } // Encode writes a Stream to the passed io.Writer. Each of the Records known to // the Stream is written in ascending order of their type so as to be canonical. // // The stream is constructed by concatenating the individual, serialized Records // where each record has the following format: // [varint: type] // [varint: length] // [length: value] // // An error is returned if the io.Writer fails to accept bytes from the // encoding, and nothing else. The ordering of the Records is asserted upon the // creation of a Stream, and thus the output will be by definition canonical. func (s *Stream) Encode(w io.Writer) error { // Iterate through all known records, if any, serializing each record's // type, length and value. for i := range s.records { rec := &s.records[i] // Write the record's type as a varint. err := WriteVarInt(w, uint64(rec.typ), &s.buf) if err != nil { return err } // Write the record's length as a varint. err = WriteVarInt(w, rec.Size(), &s.buf) if err != nil { return err } // Encode the current record's value using the stream's codec. err = rec.encoder(w, rec.value, &s.buf) if err != nil { return err } } return nil } // Decode deserializes TLV Stream from the passed io.Reader. The Stream will // inspect each record that is parsed and check to see if it has a corresponding // Record to facilitate deserialization of that field. If the record is unknown, // the Stream will discard the record's bytes and proceed to the subsequent // record. // // Each record has the following format: // [varint: type] // [varint: length] // [length: value] // // A series of (possibly zero) records are concatenated into a stream, this // example contains two records: // // (t: 0x01, l: 0x04, v: 0xff, 0xff, 0xff, 0xff) // (t: 0x02, l: 0x01, v: 0x01) // // This method asserts that the byte stream is canonical, namely that each // record is unique and that all records are sorted in ascending order. An // ErrNotCanonicalStream error is returned if the encoded TLV stream is not. // // We permit an io.EOF error only when reading the type byte which signals that // the last record was read cleanly and we should stop parsing. All other io.EOF // or io.ErrUnexpectedEOF errors are returned. func (s *Stream) Decode(r io.Reader) error { _, err := s.decode(r, nil) return err } // DecodeWithParsedTypes is identical to Decode, but if successful, returns a // TypeMap containing the types of all records that were decoded or ignored from // the stream. func (s *Stream) DecodeWithParsedTypes(r io.Reader) (TypeMap, error) { return s.decode(r, make(TypeMap)) } // decode is a helper function that performs the basis of stream decoding. If // the caller needs the set of parsed types, it must provide an initialized // parsedTypes, otherwise the returned TypeMap will be nil. func (s *Stream) decode(r io.Reader, parsedTypes TypeMap) (TypeMap, error) { var ( typ Type min Type recordIdx int overflow bool ) // Iterate through all possible type identifiers. As types are read from // the io.Reader, min will skip forward to the last read type. for { // Read the next varint type. t, err := ReadVarInt(r, &s.buf) switch { // We'll silence an EOF when zero bytes remain, meaning the // stream was cleanly encoded. case err == io.EOF: return parsedTypes, nil // Other unexpected errors. case err != nil: return nil, err } typ = Type(t) // Assert that this type is greater than any previously read. // If we've already overflowed and we parsed another type, the // stream is not canonical. This check prevents us from accepts // encodings that have duplicate records or from accepting an // unsorted series. if overflow || typ < min { return nil, ErrStreamNotCanonical } // Read the varint length. length, err := ReadVarInt(r, &s.buf) switch { // We'll convert any EOFs to ErrUnexpectedEOF, since this // results in an invalid record. case err == io.EOF: return nil, io.ErrUnexpectedEOF // Other unexpected errors. case err != nil: return nil, err } // Place a soft limit on the size of a sane record, which // prevents malicious encoders from causing us to allocate an // unbounded amount of memory when decoding variable-sized // fields. if length > MaxRecordSize { return nil, ErrRecordTooLarge } // Search the records known to the stream for this type. We'll // begin the search and recordIdx and walk forward until we find // it or the next record's type is larger. rec, newIdx, ok := s.getRecord(typ, recordIdx) switch { // We know of this record type, proceed to decode the value. // This method asserts that length bytes are read in the // process, and returns an error if the number of bytes is not // exactly length. case ok: err := rec.decoder(r, rec.value, &s.buf, length) switch { // We'll convert any EOFs to ErrUnexpectedEOF, since this // results in an invalid record. case err == io.EOF: return nil, io.ErrUnexpectedEOF // Other unexpected errors. case err != nil: return nil, err } // Record the successfully decoded type if the caller // provided an initialized TypeMap. if parsedTypes != nil { parsedTypes[typ] = nil } // Otherwise, the record type is unknown and is odd, discard the // number of bytes specified by length. default: // If the caller provided an initialized TypeMap, record // the encoded bytes. var b *bytes.Buffer writer := ioutil.Discard if parsedTypes != nil { b = bytes.NewBuffer(make([]byte, 0, length)) writer = b } _, err := io.CopyN(writer, r, int64(length)) switch { // We'll convert any EOFs to ErrUnexpectedEOF, since this // results in an invalid record. case err == io.EOF: return nil, io.ErrUnexpectedEOF // Other unexpected errors. case err != nil: return nil, err } if parsedTypes != nil { parsedTypes[typ] = b.Bytes() } } // Update our record index so that we can begin our next search // from where we left off. recordIdx = newIdx // If we've parsed the largest possible type, the next loop will // overflow back to zero. However, we need to attempt parsing // the next type to ensure that the stream is empty. if typ == math.MaxUint64 { overflow = true } // Finally, set our lower bound on the next accepted type. min = typ + 1 } } // getRecord searches for a record matching typ known to the stream. The boolean // return value indicates whether the record is known to the stream. The integer // return value carries the index from where getRecord should be invoked on the // subsequent call. The first call to getRecord should always use an idx of 0. func (s *Stream) getRecord(typ Type, idx int) (Record, int, bool) { for idx < len(s.records) { record := s.records[idx] switch { // Found target record, return it to the caller. The next index // returned points to the immediately following record. case record.typ == typ: return record, idx + 1, true // This record's type is lower than the target. Advance our // index and continue to the next record which will have a // strictly higher type. case record.typ < typ: idx++ continue // This record's type is larger than the target, hence we have // no record matching the current type. Return the current index // so that we can start our search from here when processing the // next tlv record. default: return Record{}, idx, false } } // All known records are exhausted. return Record{}, idx, false }