Added Error message type to wire protocol

This commit is contained in:
Joseph Poon 2016-01-05 08:53:42 -08:00 committed by Olaoluwa Osuntokun
parent f3849f5c10
commit b4c644c99a
5 changed files with 144 additions and 6 deletions

74
lnwire/error_generic.go Normal file

@ -0,0 +1,74 @@
package lnwire
import (
"fmt"
"io"
)
//Multiple Clearing Requests are possible by putting this inside an array of
//clearing requests
type ErrorGeneric struct {
//We can use a different data type for this if necessary...
ChannelID uint64
//Some kind of message
//Max length 8192
Problem string
}
func (c *ErrorGeneric) Decode(r io.Reader, pver uint32) error {
//ChannelID(8)
//Problem
err := readElements(r,
&c.ChannelID,
&c.Problem,
)
if err != nil {
return err
}
return nil
}
//Creates a new ErrorGeneric
func NewErrorGeneric() *ErrorGeneric {
return &ErrorGeneric{}
}
//Serializes the item from the ErrorGeneric struct
//Writes the data to w
func (c *ErrorGeneric) Encode(w io.Writer, pver uint32) error {
err := writeElements(w,
c.ChannelID,
c.Problem,
)
if err != nil {
return err
}
return nil
}
func (c *ErrorGeneric) Command() uint32 {
return CmdErrorGeneric
}
func (c *ErrorGeneric) MaxPayloadLength(uint32) uint32 {
//8+8192
return 8208
}
//Makes sure the struct data is valid (e.g. no negatives or invalid pkscripts)
func (c *ErrorGeneric) Validate() error {
if len(c.Problem) > 8192 {
return fmt.Errorf("Problem string length too long")
}
//We're good!
return nil
}
func (c *ErrorGeneric) String() string {
return fmt.Sprintf("\n--- Begin ErrorGeneric ---\n") +
fmt.Sprintf("ChannelID:\t%d\n", c.ChannelID) +
fmt.Sprintf("Problem:\t%s\n", c.Problem) +
fmt.Sprintf("--- End ErrorGeneric ---\n")
}

@ -0,0 +1,32 @@
package lnwire
import (
"testing"
)
var (
errorGeneric = &ErrorGeneric{
ChannelID: uint64(12345678),
Problem: "Hello world!",
}
errorGenericSerializedString = "0000000000bc614e000c48656c6c6f20776f726c6421"
errorGenericSerializedMessage = "0709110b00000fa0000000160000000000bc614e000c48656c6c6f20776f726c6421"
)
func TestErrorGenericEncodeDecode(t *testing.T) {
//All of these types being passed are of the message interface type
//Test serialization, runs: message.Encode(b, 0)
//Returns bytes
//Compares the expected serialized string from the original
s := SerializeTest(t, errorGeneric, errorGenericSerializedString, filename)
//Test deserialization, runs: message.Decode(s, 0)
//Makes sure the deserialized struct is the same as the original
newMessage := NewErrorGeneric()
DeserializeTest(t, s, newMessage, errorGeneric)
//Test message using Message interface
//Serializes into buf: WriteMessage(buf, message, uint32(1), wire.TestNet3)
//Deserializes into msg: _, msg, _ , err := ReadMessage(buf, uint32(1), wire.TestNet3)
MessageSerializeDeserializeTest(t, errorGeneric, errorGenericSerializedMessage)
}

@ -18,12 +18,6 @@ type HTLCSettleAccept struct {
func (c *HTLCSettleAccept) Decode(r io.Reader, pver uint32) error { func (c *HTLCSettleAccept) Decode(r io.Reader, pver uint32) error {
//ChannelID(8) //ChannelID(8)
//StagingID(8) //StagingID(8)
//Expiry(4)
//Amount(4)
//NextHop(20)
//ContractType(1)
//RedemptionHashes (numOfHashes * 20 + numOfHashes)
//Blob(2+blobsize)
err := readElements(r, err := readElements(r,
&c.ChannelID, &c.ChannelID,
&c.StagingID, &c.StagingID,

@ -202,6 +202,21 @@ func writeElement(w io.Writer, element interface{}) error {
return err return err
} }
return nil return nil
case string:
strlen := len(e)
if strlen > 65535 {
return fmt.Errorf("String too long!")
}
//Write the size (2-bytes)
err = writeElement(w, uint16(strlen))
if err != nil {
return err
}
//Write the data
_, err = w.Write([]byte(e))
if err != nil {
return err
}
case []*wire.TxIn: case []*wire.TxIn:
//Append the unsigned(!!!) txins //Append the unsigned(!!!) txins
//Write the size (1-byte) //Write the size (1-byte)
@ -453,6 +468,24 @@ func readElement(r io.Reader, element interface{}) error {
return fmt.Errorf("EOF: Signature length mismatch.") return fmt.Errorf("EOF: Signature length mismatch.")
} }
return nil return nil
case *string:
//Get the string length first
var strlen uint16
err = readElement(r, &strlen)
if err != nil {
return err
}
//Read the string for the length
l := io.LimitReader(r, int64(strlen))
b, err := ioutil.ReadAll(l)
if len(b) != int(strlen) {
return fmt.Errorf("EOF: String length mismatch.")
}
*e = string(b)
if err != nil {
return err
}
return nil
case *[]*wire.TxIn: case *[]*wire.TxIn:
//Read the size (1-byte number of txins) //Read the size (1-byte number of txins)
var numScripts uint8 var numScripts uint8

@ -57,6 +57,9 @@ const (
//Commitments //Commitments
CmdCommitSignature = uint32(2000) CmdCommitSignature = uint32(2000)
CmdCommitRevocation = uint32(2010) CmdCommitRevocation = uint32(2010)
//Error
CmdErrorGeneric = uint32(4000)
) )
//Every message has these functions: //Every message has these functions:
@ -103,6 +106,8 @@ func makeEmptyMessage(command uint32) (Message, error) {
msg = &CommitSignature{} msg = &CommitSignature{}
case CmdCommitRevocation: case CmdCommitRevocation:
msg = &CommitRevocation{} msg = &CommitRevocation{}
case CmdErrorGeneric:
msg = &ErrorGeneric{}
default: default:
return nil, fmt.Errorf("unhandled command [%d]", command) return nil, fmt.Errorf("unhandled command [%d]", command)
} }