// Code derived from https:// github.com/btcsuite/btcd/blob/master/wire/message.go
package lnwire

import (
	"bytes"
	"fmt"
	"io"

	"github.com/btcsuite/btcd/wire"
)

//  message type identifyer bytes
const (
	MSGID_FUNDREQUEST  = 0x30
	MSGID_FUNDRESPONSE = 0x31

	MSGID_CLOSEREQUEST  = 0x40
	MSGID_CLOSERESPONSE = 0x41

	MSGID_TEXTCHAT = 0x70

	MSGID_FWDMSG     = 0x20
	MSGID_FWDAUTHREQ = 0x21
)

// 4-byte network + 4-byte message id + payload-length 4-byte
const MessageHeaderSize = 12

const MaxMessagePayload = 1024 * 1024 * 32 //  32MB

const (
	// Funding channel open
	CmdFundingRequest      = uint32(200)
	CmdFundingResponse     = uint32(210)
	CmdFundingSignAccept   = uint32(220)
	CmdFundingSignComplete = uint32(230)

	// Close channel
	CmdCloseRequest  = uint32(300)
	CmdCloseComplete = uint32(310)

	// TODO Renumber to 1100
	// HTLC payment
	CmdHTLCAddRequest = uint32(1000)
	CmdHTLCAddAccept  = uint32(1010)
	CmdHTLCAddReject  = uint32(1020)

	// TODO Renumber to 1200
	// HTLC settlement
	CmdHTLCSettleRequest = uint32(1100)
	CmdHTLCSettleAccept  = uint32(1110)

	// HTLC timeout
	CmdHTLCTimeoutRequest = uint32(1300)
	CmdHTLCTimeoutAccept  = uint32(1310)

	// Commitments
	CmdCommitSignature  = uint32(2000)
	CmdCommitRevocation = uint32(2010)

	// Error
	CmdErrorGeneric = uint32(4000)
)

// Every message has these functions:
type Message interface {
	Decode(io.Reader, uint32) error // (io, protocol version)
	Encode(io.Writer, uint32) error // (io, protocol version)
	Command() uint32                // returns ID of the message
	MaxPayloadLength(uint32) uint32 // (version) maxpayloadsize
	Validate() error                // Validates the data struct
	String() string
}

func makeEmptyMessage(command uint32) (Message, error) {
	var msg Message

	switch command {
	case CmdFundingRequest:
		msg = &FundingRequest{}
	case CmdFundingResponse:
		msg = &FundingResponse{}
	case CmdFundingSignAccept:
		msg = &FundingSignAccept{}
	case CmdFundingSignComplete:
		msg = &FundingSignComplete{}
	case CmdCloseRequest:
		msg = &CloseRequest{}
	case CmdCloseComplete:
		msg = &CloseComplete{}
	case CmdHTLCAddRequest:
		msg = &HTLCAddRequest{}
	case CmdHTLCAddAccept:
		msg = &HTLCAddAccept{}
	case CmdHTLCAddReject:
		msg = &HTLCAddReject{}
	case CmdHTLCSettleRequest:
		msg = &HTLCSettleRequest{}
	case CmdHTLCSettleAccept:
		msg = &HTLCSettleAccept{}
	case CmdHTLCTimeoutRequest:
		msg = &HTLCTimeoutRequest{}
	case CmdHTLCTimeoutAccept:
		msg = &HTLCTimeoutAccept{}
	case CmdCommitSignature:
		msg = &CommitSignature{}
	case CmdCommitRevocation:
		msg = &CommitRevocation{}
	case CmdErrorGeneric:
		msg = &ErrorGeneric{}
	default:
		return nil, fmt.Errorf("unhandled command [%d]", command)
	}

	return msg, nil
}

type messageHeader struct {
	// NOTE(j): We don't need to worry about the magic overlapping with
	// bitcoin since this is inside encrypted comms anyway, but maybe we
	// should use the XOR (^wire.TestNet3) just in case???
	magic   wire.BitcoinNet // which Blockchain Technology(TM) to use
	command uint32
	length  uint32
}

func readMessageHeader(r io.Reader) (int, *messageHeader, error) {
	var headerBytes [MessageHeaderSize]byte
	n, err := io.ReadFull(r, headerBytes[:])
	if err != nil {
		return n, nil, err
	}
	hr := bytes.NewReader(headerBytes[:])

	hdr := messageHeader{}

	err = readElements(hr,
		&hdr.magic,
		&hdr.command,
		&hdr.length)
	if err != nil {
		return n, nil, err
	}

	return n, &hdr, nil
}

//  discardInput reads n bytes from reader r in chunks and discards the read
//  bytes.  This is used to skip payloads when various errors occur and helps
//  prevent rogue nodes from causing massive memory allocation through forging
//  header length.
func discardInput(r io.Reader, n uint32) {
	maxSize := uint32(10 * 1024) // 10k at a time
	numReads := n / maxSize
	bytesRemaining := n % maxSize
	if n > 0 {
		buf := make([]byte, maxSize)
		for i := uint32(0); i < numReads; i++ {
			io.ReadFull(r, buf)
		}
	}
	if bytesRemaining > 0 {
		buf := make([]byte, bytesRemaining)
		io.ReadFull(r, buf)
	}
}

func WriteMessage(w io.Writer, msg Message, pver uint32, btcnet wire.BitcoinNet) (int, error) {
	totalBytes := 0

	cmd := msg.Command()

	// Encode the message payload
	var bw bytes.Buffer
	err := msg.Encode(&bw, pver)
	if err != nil {
		return totalBytes, err
	}
	payload := bw.Bytes()
	lenp := len(payload)

	// Enforce maximum overall message payload
	if lenp > MaxMessagePayload {
		return totalBytes, fmt.Errorf("message payload is too large - encoded %d bytes, but maximum message payload is %d bytes", lenp, MaxMessagePayload)
	}

	// Enforce maximum message payload on the message type
	mpl := msg.MaxPayloadLength(pver)
	if uint32(lenp) > mpl {
		return totalBytes, fmt.Errorf("message payload is too large - encoded %d bytes, but maximum message payload of type %x is %d bytes", lenp, cmd, mpl)
	}

	// Create header for the message
	hdr := messageHeader{}
	hdr.magic = btcnet
	hdr.command = cmd
	hdr.length = uint32(lenp)

	//  Encode the header for the message.  This is done to a buffer
	//  rather than directly to the writer since writeElements doesn't
	//  return the number of bytes written.
	hw := bytes.NewBuffer(make([]byte, 0, MessageHeaderSize))
	writeElements(hw, hdr.magic, hdr.command, hdr.length)

	// Write header
	n, err := w.Write(hw.Bytes())
	totalBytes += n
	if err != nil {
		return totalBytes, err
	}

	// Write payload
	n, err = w.Write(payload)
	totalBytes += n
	if err != nil {
		return totalBytes, err
	}

	return totalBytes, nil
}

func ReadMessage(r io.Reader, pver uint32, btcnet wire.BitcoinNet) (int, Message, []byte, error) {
	totalBytes := 0
	n, hdr, err := readMessageHeader(r)
	totalBytes += n
	if err != nil {
		return totalBytes, nil, nil, err
	}

	// Enforce maximum message payload
	if hdr.length > MaxMessagePayload {
		return totalBytes, nil, nil, fmt.Errorf("message payload is too large - header indicates %d bytes, but max message payload is %d bytes.", hdr.length, MaxMessagePayload)
	}

	// Check for messages in the wrong bitcoin network
	if hdr.magic != btcnet {
		discardInput(r, hdr.length)
		return totalBytes, nil, nil, fmt.Errorf("message from other network [%v]", hdr.magic)
	}

	// Create struct of appropriate message type based on the command
	command := hdr.command
	msg, err := makeEmptyMessage(command)
	if err != nil {
		discardInput(r, hdr.length)
		return totalBytes, nil, nil, fmt.Errorf("ReadMessage %s", err.Error())
	}

	// Check for maximum length based on the message type
	mpl := msg.MaxPayloadLength(pver)
	if hdr.length > mpl {
		discardInput(r, hdr.length)
		return totalBytes, nil, nil, fmt.Errorf("payload exceeds max length. indicates %v bytes, but max of message type %v is %v.", hdr.length, command, mpl)
	}

	// Read payload
	payload := make([]byte, hdr.length)
	n, err = io.ReadFull(r, payload)
	totalBytes += n
	if err != nil {
		return totalBytes, nil, nil, err
	}

	// Unmarshal message
	pr := bytes.NewBuffer(payload)
	err = msg.Decode(pr, pver)
	if err != nil {
		return totalBytes, nil, nil, err
	}

	// Validate the data
	err = msg.Validate()
	if err != nil {
		return totalBytes, nil, nil, err
	}

	// We're good!
	return totalBytes, msg, payload, nil
}