// Package main implements a websocket relay server.
//
// Pairs of client UUIDs are read from a JSON peer definition file (an array
// of {"p": ..., "u": ...} objects). A client connects, sends its UUID as its
// first message, and from then on every message it sends is forwarded to the
// other member of its pair, if that member is connected.
package main

import (
	"context"
	"encoding/json"
	"flag"
	"fmt"
	"net/http"
	"os"
	"time"

	log "github.com/sirupsen/logrus"
	"nhooyr.io/websocket"
)

const (
	// authTimeout bounds how long a new client may take to send its UUID.
	authTimeout = 2 * time.Second

	// writeTimeout bounds a single relayed write so that one stalled peer
	// cannot block the dispatcher goroutine (and with it every pair) forever.
	writeTimeout = 10 * time.Second

	// readHeaderTimeout bounds how long the HTTP server waits for request
	// headers. Only the header phase is limited, so long-lived websocket
	// connections are unaffected.
	readHeaderTimeout = 10 * time.Second
)

// Command-line flags.
var (
	Listen  = flag.String("listen", "0.0.0.0:8080", "listen interface:port")
	Peers   = flag.String("peers", "peers.json", "peer definition file")
	Verbose = flag.Bool("verbose", false, "verbose logging")
)

// Pair is one entry of the peer definition file: two UUIDs whose messages are
// relayed to each other. For the pair at index i, P occupies location 2*i and
// U occupies location 2*i+1.
type Pair struct {
	P string `json:"p"`
	U string `json:"u"`
}

// Intro announces a freshly authenticated connection to the dispatcher.
type Intro struct {
	Location int
	Conn     *websocket.Conn
	Addr     string
}

// Msg is a single websocket message read from the client at Location.
type Msg struct {
	Location int
	Bytes    []byte
	Type     websocket.MessageType
}

// closeReq asks the dispatcher to tear down the connection at location after
// its relay hit a read error. conn identifies which connection failed, so a
// request that arrives after the slot was cleared or re-used can be ignored.
type closeReq struct {
	location int
	conn     *websocket.Conn
}

// Handler is an http.Handler that upgrades requests to websockets,
// authenticates them against the loaded peer list and relays messages
// between paired clients.
//
// Once serving starts, conns and addrs are read and written only by the
// monitorChannels goroutine; other goroutines reach it through the intros,
// messages and closeReqs channels.
type Handler struct {
	pairs     []Pair
	locations map[string]int    // uuid -> index into conns and addrs
	conns     []*websocket.Conn // live connection per location, nil if none
	addrs     []string          // remote address per location, "" if none
	closeReqs chan closeReq
	intros    chan Intro
	messages  chan Msg
}

// NewHandler returns a Handler with its dispatcher goroutine running.
// LoadPeers must be called before the handler serves any request.
func NewHandler() *Handler {
	h := &Handler{}
	h.init()
	return h
}

// init creates the dispatcher channels and starts monitorChannels.
func (h *Handler) init() {
	h.closeReqs = make(chan closeReq)
	h.intros = make(chan Intro)
	h.messages = make(chan Msg)
	go h.monitorChannels()
}

// peer returns the location paired with location: an even location pairs
// with the following odd one and vice versa (equivalent to flipping bit 0).
func (h *Handler) peer(location int) int {
	return location ^ 1
}

// monitorChannels is the dispatcher loop. It is the only goroutine that
// touches h.conns and h.addrs while serving, which is why every state change
// is funnelled through a channel. It never returns.
func (h *Handler) monitorChannels() {
	for {
		select {
		case intro := <-h.intros:
			h.handleIntro(intro)
		case msg := <-h.messages:
			h.handleMessage(msg)
		case req := <-h.closeReqs:
			h.handleClose(req)
		}
	}
}

// handleIntro registers a newly authenticated connection and starts its
// relay, or rejects it when the location already has a live connection.
func (h *Handler) handleIntro(intro Intro) {
	if h.conns[intro.Location] != nil {
		log.WithField("old", h.addrs[intro.Location]).
			WithField("new", intro.Addr).
			WithField("uuid", h.uuidOf(intro.Location)).Info("duplicate user")
		// Closing runs the close handshake; keep it off the dispatcher.
		go closeConn(intro.Conn, "duplicate user")
		return
	}
	h.conns[intro.Location] = intro.Conn
	h.addrs[intro.Location] = intro.Addr
	log.WithField("client", intro.Addr).
		WithField("uuid", h.uuidOf(intro.Location)).
		Debug("drained")
	go h.relay(intro.Conn, intro.Location)
}

// handleMessage forwards msg to the partner of its sender when the partner
// is connected. A failed write unregisters and closes the partner.
func (h *Handler) handleMessage(msg Msg) {
	peerLocation := h.peer(msg.Location)
	peer := h.conns[peerLocation]

	// The "message" field copies the whole payload into a string; only pay
	// for that when debug output is actually emitted.
	if log.IsLevelEnabled(log.DebugLevel) {
		log.WithField("client", h.addrs[msg.Location]).
			WithField("uuid", h.uuidOf(msg.Location)).
			WithField("message", string(msg.Bytes)).Debug("received")
	}
	if peer == nil {
		return
	}

	log.WithField("client", h.addrs[peerLocation]).
		WithField("uuid", h.uuidOf(peerLocation)).
		Debug("sent")
	// Bound the write: an unbounded write to a peer that stopped reading
	// would stall this goroutine and therefore every other pair.
	ctx, cancel := context.WithTimeout(context.Background(), writeTimeout)
	err := peer.Write(ctx, msg.Type, msg.Bytes)
	cancel()
	if err != nil {
		log.WithField("client", h.addrs[peerLocation]).
			WithField("uuid", h.uuidOf(peerLocation)).Warning(err)
		h.drop(peerLocation, "write error")
	}
}

// handleClose processes a relay's request to tear down its connection. The
// request is ignored when that connection is no longer the one registered
// at the location: it was already dropped (e.g. after a write error) or a
// newer connection for the same user took its place. Acting on it anyway
// would call Close on a nil connection or kill the newer one.
func (h *Handler) handleClose(req closeReq) {
	if h.conns[req.location] != req.conn {
		log.Debug("ignoring stale close request for location ", req.location)
		return
	}
	h.drop(req.location, "")
	log.Trace(h.conns)
}

// drop unregisters the connection at location, if any, and closes it in the
// background with a normal-closure status and the given reason.
func (h *Handler) drop(location int, reason string) {
	conn := h.conns[location]
	h.conns[location] = nil
	h.addrs[location] = ""
	if conn != nil {
		go closeConn(conn, reason)
	}
}

// closeConn closes conn with a normal-closure status and reason, logging any
// error. Close performs the websocket close handshake and can block while
// waiting on the remote side, so dispatcher code runs this in a goroutine.
func closeConn(conn *websocket.Conn, reason string) {
	if err := conn.Close(websocket.StatusNormalClosure, reason); err != nil {
		log.Warning(err)
	}
}

// uuidOf returns the UUID configured for location: the pair's P for even
// locations and U for odd ones.
func (h *Handler) uuidOf(location int) string {
	if location%2 == 0 {
		return h.pairs[location/2].P
	}
	return h.pairs[location/2].U
}

// relay reads messages from the connection at location and hands them to
// the dispatcher until a read fails. It then asks the dispatcher to tear
// the connection down and returns.
func (h *Handler) relay(from *websocket.Conn, location int) {
	for {
		msgType, msgBytes, err := from.Read(context.Background())
		if err != nil {
			log.Debug("sending close message ", location)
			h.closeReqs <- closeReq{location: location, conn: from}
			log.Debug("returning from Relay")
			return
		}
		h.messages <- Msg{
			Location: location,
			Type:     msgType,
			Bytes:    msgBytes,
		}
	}
}

// LoadPeers reads the JSON peer definition file (an array of Pair objects)
// and builds the uuid->location index and the per-location connection
// tables.
//
// It returns an error if the file cannot be read or decoded, if any pair
// has an empty UUID, or if a UUID appears more than once (including a pair
// whose two members are the same UUID). On error the previous index is
// left in place.
func (h *Handler) LoadPeers(filename string) error {
	file, err := os.Open(filename)
	if err != nil {
		return err
	}
	defer file.Close()

	if err := json.NewDecoder(file).Decode(&h.pairs); err != nil {
		return fmt.Errorf("decoding %s: %w", filename, err)
	}

	locations := make(map[string]int, len(h.pairs)*2)
	for i, pair := range h.pairs {
		// A missing "p"/"u" key decodes as "", which would let any client
		// that sends an empty first message authenticate as that member.
		if pair.P == "" || pair.U == "" {
			return fmt.Errorf("pair %d: empty uuid", i)
		}
		// Without this check U would silently overwrite P in the index.
		if pair.P == pair.U {
			return fmt.Errorf("duplicate uuid %v", pair.P)
		}
		if _, ok := locations[pair.P]; ok {
			return fmt.Errorf("duplicate uuid %v", pair.P)
		}
		if _, ok := locations[pair.U]; ok {
			return fmt.Errorf("duplicate uuid %v", pair.U)
		}
		locations[pair.P] = i * 2
		locations[pair.U] = i*2 + 1
	}

	h.locations = locations
	h.conns = make([]*websocket.Conn, len(h.pairs)*2)
	h.addrs = make([]string, len(h.pairs)*2)
	return nil
}

// ServeHTTP upgrades the request to a websocket and hands the connection to
// ServeWebSocket on its own goroutine.
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	wsConn, err := websocket.Accept(w, r, nil)
	if err != nil {
		log.WithField("client", r.RemoteAddr).Error(err)
		return
	}
	// NOTE(review): -1 presumably disables the per-message read limit, and
	// it applies to the unauthenticated first message too; consider a small
	// limit until auth succeeds. Confirm semantics for the library version.
	wsConn.SetReadLimit(-1)
	go h.ServeWebSocket(wsConn, r.RemoteAddr)
}

// ServeWebSocket authenticates a new connection: the first message, which
// must arrive within authTimeout, is taken as the client's UUID. Known
// UUIDs are passed to the dispatcher; anything else closes the connection.
func (h *Handler) ServeWebSocket(wsConn *websocket.Conn, rAddr string) {
	authLog := log.WithField("client", rAddr)
	authLog.Trace("waiting for auth")

	ctx, cancel := context.WithTimeout(context.Background(), authTimeout)
	defer cancel()
	_, msgBytes, err := wsConn.Read(ctx)
	if err != nil {
		authLog.Error(err)
		// Best effort: the client is being dropped either way.
		_ = wsConn.Close(websocket.StatusNormalClosure, "")
		return
	}

	uuid := string(msgBytes)
	loc, ok := h.locations[uuid]
	if !ok {
		authLog.WithField("uuid", uuid).Error("unknown user")
		wsConn.CloseNow()
		return
	}
	authLog.WithField("uuid", uuid).Info("logged in")
	h.intros <- Intro{Location: loc, Conn: wsConn, Addr: rAddr}
}

func main() {
	flag.Parse()
	if *Verbose {
		log.SetLevel(log.TraceLevel)
		log.Trace("verbose logging")
	}

	h := NewHandler()
	log.Info("Parsing peers from ", *Peers)
	if err := h.LoadPeers(*Peers); err != nil {
		log.Fatal(err)
	}

	log.Info("Starting relay on ", *Listen)
	// Only the header phase gets a deadline: a ReadTimeout/WriteTimeout would
	// also apply to the hijacked websocket connections.
	srv := &http.Server{
		Addr:              *Listen,
		Handler:           h,
		ReadHeaderTimeout: readHeaderTimeout,
	}
	log.Fatal(srv.ListenAndServe())
}