222 lines
4.9 KiB
Go
222 lines
4.9 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"flag"
|
|
"fmt"
|
|
"net/http"
|
|
"os"
|
|
"time"
|
|
|
|
log "github.com/sirupsen/logrus"
|
|
"nhooyr.io/websocket"
|
|
)
|
|
|
|
// Listen is the interface:port the relay's HTTP server binds to.
var Listen = flag.String("listen", "0.0.0.0:8080", "listen interface:port")

// Peers is the path of the JSON file that defines the uuid pairs.
var Peers = flag.String("peers", "peers.json", "peer definition file")

// Verbose switches logging to Trace level when set.
var Verbose = flag.Bool("verbose", false, "verbose logging")
|
|
|
|
// Pair is one entry of the peers file: two uuids whose connections are
// relayed to each other. LoadPeers gives P the even slot of the pair and
// U the odd slot.
type Pair struct {
	// P is the uuid of the first member (JSON key "p").
	P string `json:"p"`
	// U is the uuid of the second member (JSON key "u").
	U string `json:"u"`
}
|
|
|
|
type Intro struct {
|
|
Location int
|
|
Conn *websocket.Conn
|
|
Addr string
|
|
}
|
|
|
|
type Msg struct {
|
|
Location int
|
|
Bytes []byte
|
|
Type websocket.MessageType
|
|
}
|
|
|
|
type Handler struct {
|
|
pairs []Pair
|
|
locations map[string]int
|
|
conns []*websocket.Conn
|
|
addrs []string
|
|
closeReqs chan int
|
|
intros chan Intro
|
|
messages chan Msg
|
|
}
|
|
|
|
func NewHandler() *Handler {
|
|
h := &Handler{}
|
|
h.init()
|
|
return h
|
|
}
|
|
|
|
func (h *Handler) init() {
|
|
h.closeReqs = make(chan int)
|
|
h.intros = make(chan Intro)
|
|
h.messages = make(chan Msg)
|
|
go h.monitorChannels()
|
|
}
|
|
|
|
func (h *Handler) peer(location int) int {
|
|
if location%2 == 0 {
|
|
return location + 1
|
|
} else {
|
|
return location - 1
|
|
}
|
|
}
|
|
|
|
func (h *Handler) monitorChannels() {
|
|
for {
|
|
select {
|
|
case intro := <-h.intros:
|
|
old := h.conns[intro.Location]
|
|
if old != nil {
|
|
log.WithField("old", h.addrs[intro.Location]).
|
|
WithField("new", intro.Addr).
|
|
WithField("uuid", h.uuidOf(intro.Location)).Info("duplicate user")
|
|
intro.Conn.Close(websocket.StatusNormalClosure, "duplicate user")
|
|
break
|
|
}
|
|
h.conns[intro.Location] = intro.Conn
|
|
h.addrs[intro.Location] = intro.Addr
|
|
log.WithField("client", intro.Addr).
|
|
WithField("uuid", h.uuidOf(intro.Location)).
|
|
Debug("drained")
|
|
go h.relay(intro.Conn, intro.Location)
|
|
case msg := <-h.messages:
|
|
peerLocation := h.peer(msg.Location)
|
|
peer := h.conns[peerLocation]
|
|
log.WithField("client", h.addrs[msg.Location]).
|
|
WithField("uuid", h.uuidOf(msg.Location)).
|
|
WithField("message", string(msg.Bytes)).Debug("received")
|
|
if peer != nil {
|
|
log.WithField("client", h.addrs[peerLocation]).
|
|
WithField("uuid", h.uuidOf(peerLocation)).
|
|
Debug("sent")
|
|
err := peer.Write(context.Background(),
|
|
msg.Type, msg.Bytes)
|
|
if err != nil {
|
|
err := peer.Close(websocket.StatusNormalClosure, "write error")
|
|
if err != nil {
|
|
log.Warning(err)
|
|
}
|
|
h.conns[peerLocation] = nil
|
|
h.addrs[peerLocation] = ""
|
|
}
|
|
}
|
|
case closeLoc := <-h.closeReqs:
|
|
err := h.conns[closeLoc].
|
|
Close(websocket.StatusNormalClosure, "")
|
|
if err != nil {
|
|
log.Warning(err)
|
|
}
|
|
|
|
h.conns[closeLoc] = nil
|
|
h.addrs[closeLoc] = ""
|
|
log.Trace(h.conns)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (h *Handler) uuidOf(location int) string {
|
|
if location%2 == 0 {
|
|
return h.pairs[location/2].P
|
|
} else {
|
|
return h.pairs[location/2].U
|
|
}
|
|
}
|
|
|
|
func (h *Handler) relay(from *websocket.Conn, location int) {
|
|
for {
|
|
msgType, msgBytes, err := from.Read(context.Background())
|
|
if err != nil {
|
|
log.Debug("sending close message ", location)
|
|
h.closeReqs <- location
|
|
log.Debug("returning from Relay")
|
|
return
|
|
}
|
|
h.messages <- Msg{
|
|
Location: location,
|
|
Type: msgType,
|
|
Bytes: msgBytes,
|
|
}
|
|
}
|
|
}
|
|
|
|
func (h *Handler) LoadPeers(filename string) error {
|
|
file, err := os.Open(filename)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
decoder := json.NewDecoder(file)
|
|
err = decoder.Decode(&h.pairs)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
h.locations = map[string]int{}
|
|
for i, pair := range h.pairs {
|
|
if _, ok := h.locations[pair.P]; ok {
|
|
return fmt.Errorf("duplicate uuid %v", pair.P)
|
|
}
|
|
if _, ok := h.locations[pair.U]; ok {
|
|
return fmt.Errorf("duplicate uuid %v", pair.U)
|
|
}
|
|
h.locations[pair.P] = i * 2
|
|
h.locations[pair.U] = i*2 + 1
|
|
}
|
|
h.conns = make([]*websocket.Conn, len(h.pairs)*2)
|
|
h.addrs = make([]string, len(h.pairs)*2)
|
|
return nil
|
|
}
|
|
|
|
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
wsConn, err := websocket.Accept(w, r, nil)
|
|
if err != nil {
|
|
log.WithField("client", r.RemoteAddr).Error(err)
|
|
return
|
|
}
|
|
wsConn.SetReadLimit(-1)
|
|
go h.ServeWebSocket(wsConn, r.RemoteAddr)
|
|
}
|
|
|
|
func (h *Handler) ServeWebSocket(wsConn *websocket.Conn, rAddr string) {
|
|
authLog := log.WithField("client", rAddr)
|
|
authLog.Trace("waiting for auth")
|
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
defer cancel()
|
|
_, msgBytes, err := wsConn.Read(ctx)
|
|
if err != nil {
|
|
authLog.Error(err)
|
|
wsConn.Close(websocket.StatusNormalClosure, "")
|
|
return
|
|
}
|
|
uuid := string(msgBytes)
|
|
loc, ok := h.locations[uuid]
|
|
if !ok {
|
|
authLog.WithField("uuid", uuid).Error("unknown user")
|
|
wsConn.CloseNow()
|
|
return
|
|
} else {
|
|
authLog.WithField("uuid", uuid).Info("logged in")
|
|
}
|
|
h.intros <- Intro{loc, wsConn, rAddr}
|
|
}
|
|
|
|
func main() {
|
|
flag.Parse()
|
|
if *Verbose {
|
|
log.SetLevel(log.TraceLevel)
|
|
log.Trace("verbose logging")
|
|
}
|
|
h := NewHandler()
|
|
log.Info("Parsing peers from ", *Peers)
|
|
err := h.LoadPeers(*Peers)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
log.Info("Starting relay on ", *Listen)
|
|
log.Fatal(http.ListenAndServe(*Listen, h))
|
|
}
|