package tlv

import (
	"bytes"
	"fmt"
	"io"
	"sort"

	"github.com/btcsuite/btcd/btcec"
)

// Type is a 64-bit identifier for a TLV Record.
type Type uint64

// TypeMap is a map of parsed Types. The map values are byte slices. If the
// byte slice is nil, the type was successfully parsed. Otherwise the value is
// a byte slice containing the encoded data.
//
// NOTE(review): presumably the non-nil case carries the raw bytes of types
// that had no matching Record during decoding — confirm against the stream
// decoder that populates this map.
type TypeMap map[Type][]byte

// Encoder is a signature for methods that can encode TLV values. An Encoder
// writes the encoding of val to w. An error should be returned if the Encoder
// cannot support the underlying type of val. The provided scratch buffer buf
// must be non-nil; it lets encoders avoid allocating for small fixed-width
// values.
type Encoder func(w io.Writer, val interface{}, buf *[8]byte) error

// Decoder is a signature for methods that can decode TLV values. A Decoder
// reads from r into the object referenced by val. An error should be
// returned if the Decoder cannot support the underlying type of val. The
// provided scratch buffer must be non-nil.
//
// NOTE(review): l is presumably the length in bytes of the encoded value
// being decoded (the "L" of TLV) — confirm against the stream decoder.
type Decoder func(r io.Reader, val interface{}, buf *[8]byte, l uint64) error

// ENOP is a no-op Encoder: it writes nothing to the io.Writer, ignores the
// value and scratch buffer it is given, and always returns a nil error.
func ENOP(_ io.Writer, _ interface{}, _ *[8]byte) error {
	return nil
}

// DNOP is a no-op Decoder: it reads nothing from the io.Reader, leaves the
// value it is given untouched, ignores the scratch buffer and length, and
// always returns a nil error.
func DNOP(io.Reader, interface{}, *[8]byte, uint64) error { return nil }

// SizeFunc is a function that can compute the length of a given field. Since
// the size of the underlying field can change, this allows the size of the
// field to be evaluated at the time it is needed (e.g. at encoding time)
// rather than when the Record is constructed.
type SizeFunc func() uint64

// SizeVarBytes returns a SizeFunc reporting the length of the byte slice that
// e points to. The slice is dereferenced on every call, so the reported size
// tracks any later reassignment of *e.
func SizeVarBytes(e *[]byte) SizeFunc {
	size := func() uint64 {
		current := *e
		return uint64(len(current))
	}

	return size
}

// RecordProducer is an interface for objects that can produce a Record object
// capable of encoding and/or decoding the RecordProducer as a Record.
type RecordProducer interface {
	// Record returns a Record that can be used to encode or decode the
	// backing object.
	Record() Record
}

// Record holds the required information to encode or decode a TLV record.
// The size reported by Size comes from sizeFunc when it is set, and from
// staticSize otherwise.
type Record struct {
	value      interface{} // passed as val to encoder and decoder
	typ        Type        // TLV type identifier of this record
	staticSize uint64      // used by Size when sizeFunc is nil
	sizeFunc   SizeFunc    // if non-nil, overrides staticSize
	encoder    Encoder     // invoked by Encode
	decoder    Decoder     // invoked by Decode; may be nil
}

// Size returns the size of the Record's value. A dynamic size function, when
// present, takes precedence and is evaluated on every call; otherwise the
// static size fixed at construction is returned.
func (f *Record) Size() uint64 {
	if f.sizeFunc != nil {
		return f.sizeFunc()
	}
	return f.staticSize
}

// Type reports the TLV type identifier this record was created with.
func (f *Record) Type() Type { return f.typ }

// Encode writes out the TLV record's value to the passed writer. This is
// useful when a caller wants to obtain the raw encoding of a *single* TLV
// record, outside the context of the Stream struct.
//
// A fresh zeroed scratch buffer is handed to the record's encoder on each
// call. An error is returned if the record was built without an encoder.
func (f *Record) Encode(w io.Writer) error {
	// Calling a nil Encoder would panic; report it as an error instead so
	// library callers can handle a misconfigured record gracefully.
	if f.encoder == nil {
		return fmt.Errorf("tlv: record type %d has no encoder", f.typ)
	}

	var b [8]byte
	return f.encoder(w, f.value, &b)
}

// Decode reads the TLV record's value from the passed reader. This is useful
// when a caller wants to decode a *single* TLV record, outside the context of
// the Stream struct.
//
// l is handed unchanged to the record's decoder (presumably the encoded
// length of the value in bytes). An error is returned if the record has no
// decoder, as is the case for records produced by MapToRecords.
func (f *Record) Decode(r io.Reader, l uint64) error {
	// MapToRecords intentionally builds records without a decoder; calling
	// the nil function would panic, so surface an error instead.
	if f.decoder == nil {
		return fmt.Errorf("tlv: record type %d has no decoder", f.typ)
	}

	var b [8]byte
	return f.decoder(r, f.value, &b, l)
}

// MakePrimitiveRecord creates a record for common types. val must be a
// pointer to one of the supported primitives: uint8, uint16, uint32, uint64,
// [32]byte, [33]byte, [64]byte, *btcec.PublicKey or []byte. Fixed-width
// types receive a static size. []byte receives a size function that reads the
// slice's current length each time the size is requested. Any other type
// causes a panic, since it indicates a programming error in the caller.
func MakePrimitiveRecord(typ Type, val interface{}) Record {
	record := Record{
		value: val,
		typ:   typ,
	}

	switch v := val.(type) {
	case *uint8:
		record.staticSize, record.encoder, record.decoder = 1, EUint8, DUint8

	case *uint16:
		record.staticSize, record.encoder, record.decoder = 2, EUint16, DUint16

	case *uint32:
		record.staticSize, record.encoder, record.decoder = 4, EUint32, DUint32

	case *uint64:
		record.staticSize, record.encoder, record.decoder = 8, EUint64, DUint64

	case *[32]byte:
		record.staticSize, record.encoder, record.decoder = 32, EBytes32, DBytes32

	case *[33]byte:
		record.staticSize, record.encoder, record.decoder = 33, EBytes33, DBytes33

	case **btcec.PublicKey:
		// Public keys are serialized in 33-byte compressed form.
		record.staticSize, record.encoder, record.decoder = 33, EPubKey, DPubKey

	case *[64]byte:
		record.staticSize, record.encoder, record.decoder = 64, EBytes64, DBytes64

	case *[]byte:
		// Variable-length: size is computed lazily from the slice.
		record.sizeFunc = SizeVarBytes(v)
		record.encoder, record.decoder = EVarBytes, DVarBytes

	default:
		panic(fmt.Sprintf("unknown primitive type: %T", val))
	}

	return record
}

// MakeStaticRecord creates a record for a field of fixed size. The given size
// is reported by the record's Size method regardless of the value's contents.
func MakeStaticRecord(typ Type, val interface{}, size uint64, encoder Encoder,
	decoder Decoder) Record {

	rec := Record{
		typ:        typ,
		value:      val,
		encoder:    encoder,
		decoder:    decoder,
		staticSize: size,
	}
	return rec
}

// MakeDynamicRecord creates a record whose size may vary. Rather than being
// fixed up front, the size is computed by sizeFunc each time it is needed,
// such as at encoding time.
func MakeDynamicRecord(typ Type, val interface{}, sizeFunc SizeFunc,
	encoder Encoder, decoder Decoder) Record {

	rec := Record{
		typ:      typ,
		value:    val,
		encoder:  encoder,
		decoder:  decoder,
		sizeFunc: sizeFunc,
	}
	return rec
}

// RecordsToMap encodes each of the given TLV records and returns the results
// as a map from numeric type to the raw encoded value. If several records
// share a type, the one appearing last wins. The first encoding error aborts
// the conversion and is returned alongside a nil map.
func RecordsToMap(records []Record) (map[uint64][]byte, error) {
	encoded := make(map[uint64][]byte, len(records))

	for i := range records {
		// A fresh buffer per record, since the map keeps a reference to
		// each buffer's bytes.
		var buf bytes.Buffer
		err := records[i].Encode(&buf)
		if err != nil {
			return nil, err
		}

		key := uint64(records[i].Type())
		encoded[key] = buf.Bytes()
	}

	return encoded, nil
}

// StubEncoder is a factory function that makes a stub tlv.Encoder out of a
// raw, already-serialized value. It is useful for building a record that can
// be encoded when only its serialization, not its true underlying value, is
// known. The returned Encoder ignores its val and scratch-buffer arguments
// and writes v verbatim.
func StubEncoder(v []byte) Encoder {
	enc := func(w io.Writer, _ interface{}, _ *[8]byte) error {
		if _, err := w.Write(v); err != nil {
			return err
		}
		return nil
	}

	return enc
}

// MapToRecords converts the passed TLV map into a series of regular
// tlv.Record instances, each re-emitting its raw bytes when encoded. The
// resulting records are returned sorted by type.
func MapToRecords(tlvMap map[uint64][]byte) []Record {
	out := make([]Record, 0, len(tlvMap))

	for rawType, value := range tlvMap {
		// No decoder is attached because the concrete type behind the
		// raw bytes is unknown; these records are only meant for display
		// and encoding purposes.
		out = append(out, MakeStaticRecord(
			Type(rawType), nil, uint64(len(value)),
			StubEncoder(value), nil,
		))
	}

	// Map iteration order is random, so sort for a deterministic result.
	SortRecords(out)
	return out
}

// SortRecords sorts the given slice of records in place in ascending order
// of their type. An empty or nil slice is left untouched.
func SortRecords(records []Record) {
	if len(records) != 0 {
		byType := func(i, j int) bool {
			return records[i].Type() < records[j].Type()
		}
		sort.Slice(records, byType)
	}
}