package pgxevent

import (
	"context"
	"errors"
	"fmt"
	"strings"
	"text/template"

	pgx "github.com/jackc/pgx/v4"
	pgxpool "github.com/jackc/pgx/v4/pgxpool"
)

// Listener binds a condition on a data table to a handler: when the
// condition is (or becomes) true for a row, the row's primary key, cast to
// text, is passed to Handler.Handle.
type Listener struct {
	Table             string             // Table name, spliced into SQL verbatim; may be double-quoted
	PrimaryKeyColumn  string             // Primary key column
	ConditionTemplate *template.Template // Boolean SQL condition; {{.}} is replaced by the row alias
	Handler           EventHandler       // Receives the IDs of matching rows
}

// EventConfig configures Listen.
type EventConfig struct {
	DatabaseURL    string              // Connection string for pgxpool.Connect
	ServiceName    string              // Identifies this service in the pgxevent table and in trigger/channel names
	ServiceVersion string              // Triggers are (re)installed when this differs from the stored version
	Listeners      map[string]Listener // Keyed by listener name, which becomes part of the trigger/channel name
}

// request asks the dispatch loop to run the Handler of listener Channel for
// the row whose ID is Payload.
type request struct {
	Channel string
	Payload string
}

// completion reports that a goroutine started by Listen has finished.
// Channel and Payload are empty for the notification and initial-scan
// goroutines, and identify the finished handler call otherwise.
type completion struct {
	Channel string
	Payload string
	Error   error
}

// Listen dispatches row events for cfg.Listeners to their handlers.
//
// On startup it connects to cfg.DatabaseURL. It (re)installs one trigger per
// listener when cfg.ServiceVersion differs from the version stored for
// cfg.ServiceName, or when none is stored. It then LISTENs on every
// listener's channel. After that:
//
//   - every row that already satisfies a listener's condition is queued for
//     handling;
//   - every row whose condition becomes true afterwards is queued when its
//     trigger notification arrives.
//
// At most one Handle call runs at a time per (listener, row id). A request
// for a row whose handler is still running is dropped.
//
// Listen blocks until ctx is cancelled or until a handler, the initial scan,
// or the notification wait fails. It then cancels the remaining work, waits
// for every goroutine it started to finish, and returns the first error it
// observed. Setup errors are wrapped with %w, so they can be tested with
// errors.Is.
func Listen(ctx context.Context, cfg *EventConfig) error {
	ctx, cancel := context.WithCancel(ctx)
	// Release the derived context on every return path, including the
	// early setup failures below.
	defer cancel()

	pool, err := pgxpool.Connect(ctx, cfg.DatabaseURL)
	if err != nil {
		return fmt.Errorf("connecting to database: %w", err)
	}
	defer pool.Close()

	// conn stays checked out for the whole lifetime of Listen because the
	// LISTEN registrations are bound to this one session.
	conn, err := pool.Acquire(ctx)
	if err != nil {
		return fmt.Errorf("acquiring connection: %w", err)
	}
	defer conn.Release()

	err = conn.BeginFunc(ctx, func(tx pgx.Tx) error {
		return installTriggers(ctx, tx, cfg)
	})
	if err != nil {
		return fmt.Errorf("installing triggers: %w", err)
	}

	prefix := channelPrefix(cfg)
	running := make(map[string]map[string]bool, len(cfg.Listeners))
	for key := range cfg.Listeners {
		if _, err := conn.Exec(ctx, `LISTEN "`+prefix+key+`"`); err != nil {
			return fmt.Errorf("listening for %q: %w", key, err)
		}
		running[key] = make(map[string]bool)
	}

	startChannel := make(chan request)
	completionChannel := make(chan completion)

	// Forward trigger notifications to the dispatch loop. This goroutine
	// exits, reporting the error, only when the wait fails (for example
	// after ctx is cancelled).
	go func() {
		for {
			n, err := conn.Conn().WaitForNotification(ctx)
			if err != nil {
				completionChannel <- completion{Error: err}
				return
			}
			startChannel <- request{strings.TrimPrefix(n.Channel, prefix), n.Payload}
		}
	}()

	// Queue the rows that already match. LISTEN was issued above, before this
	// scan starts, so a row that begins to match concurrently is seen either
	// here or as a notification.
	go func() {
		completionChannel <- completion{Error: scanExisting(ctx, pool, cfg, startChannel)}
	}()

	// Count of goroutines that have not yet reported a completion: the two
	// started above, plus one per in-flight handler call.
	outstanding := 2
	teardown := false
	var firstErr error
	for {
		select {
		case req := <-startChannel:
			if teardown {
				continue
			}
			inFlight, known := running[req.Channel]
			if !known || inFlight[req.Payload] {
				// Unknown channel, or this row is already being handled.
				continue
			}
			inFlight[req.Payload] = true
			outstanding++
			handler := cfg.Listeners[req.Channel].Handler
			go func(req request) {
				err := handler.Handle(ctx, pool, req.Payload)
				completionChannel <- completion{req.Channel, req.Payload, err}
			}(req)

		case done := <-completionChannel:
			if done.Error != nil && !teardown {
				// Keep the first error and stop everything else. Later
				// errors are usually just fallout from the cancellation.
				teardown = true
				firstErr = done.Error
				cancel()
			}
			if done.Channel != "" {
				delete(running[done.Channel], done.Payload)
			}
			outstanding--
			if outstanding == 0 {
				return firstErr
			}
		}
	}
}

// channelPrefix returns the prefix shared by the trigger, trigger function
// and notification channel names of every listener of the service.
//
// NOTE(review): PostgreSQL by default truncates identifiers to 63 bytes.
// pg_notify also rejects longer channel names. Long service/listener names
// will therefore break notifications. Consider validating the length.
func channelPrefix(cfg *EventConfig) string {
	return "pgxevent:" + cfg.ServiceName + ":"
}

// renderCondition executes tmpl with alias as its data, producing the SQL
// condition for a row referred to by that alias.
func renderCondition(tmpl *template.Template, alias string) (string, error) {
	var b strings.Builder
	if err := tmpl.Execute(&b, alias); err != nil {
		return "", err
	}
	return b.String(), nil
}

// installTriggers ensures the pgxevent bookkeeping table exists and stores
// cfg.ServiceVersion for cfg.ServiceName. When no version was stored, or the
// stored one differs, it (re)creates the trigger of every listener.
// It runs inside tx, so any error makes the caller's BeginFunc roll the
// whole installation back.
func installTriggers(ctx context.Context, tx pgx.Tx, cfg *EventConfig) error {
	if _, err := tx.Exec(ctx, `CREATE TABLE IF NOT EXISTS pgxevent (service TEXT PRIMARY KEY, version TEXT)`); err != nil {
		return err
	}

	var storedVersion string
	var needsUpdate bool
	err := tx.QueryRow(ctx, `SELECT version FROM pgxevent WHERE service = $1`, cfg.ServiceName).Scan(&storedVersion)
	switch {
	case errors.Is(err, pgx.ErrNoRows):
		// First run of this service: record it and install everything.
		if _, err := tx.Exec(ctx, `INSERT INTO pgxevent(service, version) VALUES ($1,$2)`,
			cfg.ServiceName, cfg.ServiceVersion); err != nil {
			return err
		}
		needsUpdate = true
	case err != nil:
		return err
	default:
		needsUpdate = storedVersion != cfg.ServiceVersion
		if _, err := tx.Exec(ctx, `UPDATE pgxevent SET version=$1 WHERE service=$2`,
			cfg.ServiceVersion, cfg.ServiceName); err != nil {
			return err
		}
	}
	if !needsUpdate {
		return nil
	}

	prefix := channelPrefix(cfg)
	for key, l := range cfg.Listeners {
		if err := installTrigger(ctx, tx, prefix+key, l); err != nil {
			return fmt.Errorf("listener %q: %w", key, err)
		}
	}
	return nil
}

// installTrigger (re)creates the trigger function and the AFTER INSERT OR
// UPDATE row trigger, both called name, on l.Table. The trigger calls
// pg_notify(name, <primary key as text>) when the new row satisfies the
// condition and the old row either is null or did not satisfy it.
func installTrigger(ctx context.Context, tx pgx.Tx, name string, l Listener) error {
	condNew, err := renderCondition(l.ConditionTemplate, "new")
	if err != nil {
		return err
	}
	condOld, err := renderCondition(l.ConditionTemplate, "old")
	if err != nil {
		return err
	}
	condition := "(" + condNew + ") and (old is null or not (" + condOld + "))"

	// Names, column and condition are spliced into the DDL verbatim. They are
	// presumably trusted configuration and must already be valid SQL.
	_, err = tx.Exec(ctx, `CREATE OR REPLACE FUNCTION "`+name+`"() RETURNS TRIGGER AS $$
DECLARE
BEGIN
	IF (`+condition+`) THEN
		PERFORM pg_notify('`+name+`', CAST(NEW.`+l.PrimaryKeyColumn+` AS TEXT));
	END IF;
	RETURN NEW;
END
$$ LANGUAGE PLPGSQL
`)
	if err != nil {
		return err
	}

	if _, err := tx.Exec(ctx, `DROP TRIGGER IF EXISTS "`+name+`" ON `+l.Table); err != nil {
		return err
	}

	_, err = tx.Exec(ctx, `CREATE TRIGGER "`+name+`" AFTER INSERT OR UPDATE ON `+l.Table+
		` FOR EACH ROW EXECUTE PROCEDURE "`+name+`"()`)
	return err
}

// scanExisting sends a request on out for every row that currently
// satisfies some listener's condition. It returns the first error
// encountered, or nil once every listener has been scanned.
func scanExisting(ctx context.Context, pool *pgxpool.Pool, cfg *EventConfig, out chan<- request) error {
	for key, l := range cfg.Listeners {
		if err := scanListener(ctx, pool, key, l, out); err != nil {
			return err
		}
	}
	return nil
}

// scanListener sends a request{key, id} on out for each row of l.Table that
// satisfies l's condition.
func scanListener(ctx context.Context, pool *pgxpool.Pool, key string, l Listener, out chan<- request) error {
	cond, err := renderCondition(l.ConditionTemplate, "Tab")
	if err != nil {
		return err
	}
	// l.Table is used verbatim, exactly as in the trigger DDL, so that quoted
	// and schema-qualified names refer to the same table in both places.
	rows, err := pool.Query(ctx, `SELECT CAST (Tab.`+l.PrimaryKeyColumn+` AS text) AS id FROM `+
		l.Table+` AS Tab WHERE (`+cond+`)`)
	if err != nil {
		return err
	}
	// Close returns the connection to the pool even when we stop early.
	defer rows.Close()

	for rows.Next() {
		var id string
		if err := rows.Scan(&id); err != nil {
			return err
		}
		out <- request{key, id}
	}
	// Errors raised while iterating are only reported here.
	return rows.Err()
}

// EventHandler processes one row event. Handle receives a context derived
// from the one passed to Listen, which is cancelled when Listen tears down.
// It also receives the pool Listen created and the row's primary key cast to
// text. A non-nil error makes Listen shut down and return that error, unless
// another error was reported first.
type EventHandler interface {
	Handle(ctx context.Context, pool *pgxpool.Pool, id string) error
}