pgxevent/pgxevent.go
2022-04-23 22:25:40 +03:00

236 lines
6.1 KiB
Go

package pgxevent
import (
"context"
"strings"
"text/template"
pgx "github.com/jackc/pgx/v4"
pgxpool "github.com/jackc/pgx/v4/pgxpool"
)
// Listener binds a condition on a data table to a handler: when the
// condition is (or becomes) true for a row, that row's primary key, cast to
// text, is passed to Handler.Handle.
type Listener struct {
	Table             string             // Table name, may be double-quoted
	PrimaryKeyColumn  string             // Primary key column; its value, cast to text, is the row ID
	ConditionTemplate *template.Template // SQL boolean WHERE condition with {{.}} substituted for the row alias
	Handler           EventHandler       // Receives the ID of every row that matches
}
// EventConfig is the input to Listen: where to connect, which service owns
// the installed triggers, and which listeners to install.
type EventConfig struct {
	// DatabaseURL is the connection string handed to pgxpool.Connect.
	DatabaseURL string
	// ServiceName keys this service's row in the pgxevent table and is part
	// of every trigger, trigger-function and notification channel name.
	ServiceName string
	// ServiceVersion is compared with the stored version; triggers are
	// (re)created when the stored version is missing or different.
	ServiceVersion string
	// Listeners maps a listener key (also part of the channel name) to its
	// definition.
	Listeners map[string]Listener
}
// request asks Listen's dispatch loop to run the handler registered under
// listener key Channel for the row whose ID (as text) is Payload.
type request struct {
	Channel, Payload string
}
// completion reports that a goroutine started by Listen has finished.
// Handler goroutines fill Channel and Payload with the request they served;
// the producer goroutines leave them empty. Error is the goroutine's
// failure, or nil.
type completion struct {
	Channel, Payload string
	Error            error
}
// Listen connects to cfg.DatabaseURL, makes sure a trigger exists for every
// entry of cfg.Listeners, and dispatches matching row IDs to the listeners'
// handlers until something fails or ctx is cancelled.
//
// Triggers are (re)created only when the version stored for cfg.ServiceName
// in the pgxevent table is missing or differs from cfg.ServiceVersion.
//
// Row IDs come from two sources: rows that already satisfy a listener's
// condition when Listen starts, and NOTIFY events raised by the triggers
// afterwards. Each handler call runs in its own goroutine; a
// (listener key, ID) pair that is already being handled is dropped rather
// than queued.
//
// On the first error (from a handler, the initial scan, or the notification
// wait) Listen cancels the context passed to handlers, waits for every
// goroutine it started, and returns that error. Cancelling ctx surfaces as
// an error from the notification wait.
func Listen(ctx context.Context, cfg *EventConfig) error {
	ctx, cancel := context.WithCancel(ctx)
	// Release the derived context on every path, including early returns.
	defer cancel()

	pool, err := pgxpool.Connect(ctx, cfg.DatabaseURL)
	if err != nil {
		return err
	}
	defer pool.Close()

	// This connection runs the setup transaction and then stays dedicated
	// to LISTEN / WaitForNotification.
	conn, err := pool.Acquire(ctx)
	if err != nil {
		return err
	}
	// Deferred after pool.Close so it runs first: pgxpool's Close waits for
	// acquired connections to be released.
	defer conn.Release()

	err = conn.BeginFunc(ctx, func(tx pgx.Tx) error {
		return installTriggers(ctx, tx, cfg)
	})
	if err != nil {
		return err
	}

	// Subscribe before scanning existing rows so that a row changed in
	// between is reported at least once (possibly twice).
	running := make(map[string]map[string]bool, len(cfg.Listeners))
	for key := range cfg.Listeners {
		if _, err := conn.Exec(ctx, `LISTEN "`+channelName(cfg, key)+`"`); err != nil {
			return err
		}
		running[key] = make(map[string]bool)
	}

	startChannel := make(chan request)
	completionChannel := make(chan completion)
	// Every goroutine started below sends exactly one completion; counter is
	// the number still outstanding (the two producers to begin with).
	counter := 2
	teardown := false
	var cascadingError error

	// Producer: forwards trigger notifications until the wait fails.
	go func() {
		prefix := channelName(cfg, "")
		for {
			notification, err := conn.Conn().WaitForNotification(ctx)
			if err != nil {
				completionChannel <- completion{"", "", err}
				return
			}
			startChannel <- request{strings.TrimPrefix(notification.Channel, prefix), notification.Payload}
		}
	}()

	// Producer: reports rows that already match when Listen starts.
	go func() {
		for key, listener := range cfg.Listeners {
			if err := scanListener(ctx, pool, key, listener, startChannel); err != nil {
				completionChannel <- completion{"", "", err}
				return
			}
		}
		completionChannel <- completion{}
	}()

	for {
		select {
		case req := <-startChannel:
			inFlight, known := running[req.Channel]
			if teardown || !known || inFlight[req.Payload] {
				// Shutting down, not one of our channels, or this ID is
				// already being handled: drop the request.
				continue
			}
			inFlight[req.Payload] = true
			counter++
			handler := cfg.Listeners[req.Channel].Handler
			go func(req request) {
				err := handler.Handle(ctx, pool, req.Payload)
				completionChannel <- completion{req.Channel, req.Payload, err}
			}(req)
		case done := <-completionChannel:
			if done.Error != nil && !teardown {
				// First failure: stop accepting work and cancel the rest.
				teardown = true
				cascadingError = done.Error
				cancel()
			}
			if done.Channel != "" {
				delete(running[done.Channel], done.Payload)
			}
			counter--
			if counter == 0 {
				return cascadingError
			}
		}
	}
}

// installTriggers records cfg.ServiceVersion in the pgxevent bookkeeping
// table and, when the stored version was missing or different, (re)creates
// the trigger function and trigger for every listener in cfg.Listeners.
func installTriggers(ctx context.Context, tx pgx.Tx, cfg *EventConfig) error {
	if _, err := tx.Exec(ctx, `CREATE TABLE IF NOT EXISTS pgxevent
	(service TEXT PRIMARY KEY, version TEXT)`); err != nil {
		return err
	}

	var stored string
	err := tx.QueryRow(ctx, `SELECT version FROM pgxevent
	WHERE service = $1`, cfg.ServiceName).Scan(&stored)
	needsUpdate := false
	switch {
	case err == pgx.ErrNoRows:
		if _, err := tx.Exec(ctx, `INSERT INTO
	pgxevent(service, version) VALUES ($1,$2)`,
			cfg.ServiceName, cfg.ServiceVersion); err != nil {
			return err
		}
		needsUpdate = true
	case err != nil:
		return err
	default:
		needsUpdate = stored != cfg.ServiceVersion
		if _, err := tx.Exec(ctx, `UPDATE pgxevent
	SET version=$1 WHERE service=$2`,
			cfg.ServiceVersion, cfg.ServiceName); err != nil {
			return err
		}
	}

	if !needsUpdate {
		return nil
	}
	for key, listener := range cfg.Listeners {
		if err := installTrigger(ctx, tx, channelName(cfg, key), listener); err != nil {
			return err
		}
	}
	return nil
}

// installTrigger (re)creates the PL/pgSQL function and AFTER INSERT OR
// UPDATE row trigger named name on listener.Table. The function notifies
// channel name with the row's primary key as text.
func installTrigger(ctx context.Context, tx pgx.Tx, name string, listener Listener) error {
	conditionNew, err := renderCondition(listener.ConditionTemplate, "new")
	if err != nil {
		return err
	}
	conditionOld, err := renderCondition(listener.ConditionTemplate, "old")
	if err != nil {
		return err
	}
	// Notify only on a transition into the condition: when there is no old
	// row (INSERT), or when the old row did not already match.
	condition := "(" + conditionNew + ") and (old is null or not (" + conditionOld + "))"

	if _, err := tx.Exec(ctx, `CREATE OR REPLACE FUNCTION "`+name+`"()
RETURNS TRIGGER AS $$
BEGIN
	IF (`+condition+`) THEN
		PERFORM pg_notify('`+name+`', CAST(NEW.`+listener.PrimaryKeyColumn+` AS TEXT));
	END IF;
	RETURN NEW;
END $$ LANGUAGE PLPGSQL
`); err != nil {
		return err
	}
	if _, err := tx.Exec(ctx, `DROP TRIGGER IF EXISTS "`+name+`" ON `+listener.Table); err != nil {
		return err
	}
	_, err = tx.Exec(ctx, `CREATE TRIGGER "`+name+`"
	AFTER INSERT OR UPDATE ON `+listener.Table+`
	FOR EACH ROW
	EXECUTE PROCEDURE "`+name+`"()
`)
	return err
}

// scanListener sends a request on out for every row of listener.Table that
// currently satisfies listener's condition. The table name is used verbatim,
// exactly as in the trigger, so a double-quoted name works in both places.
func scanListener(ctx context.Context, pool *pgxpool.Pool, key string, listener Listener, out chan<- request) error {
	condition, err := renderCondition(listener.ConditionTemplate, "Tab")
	if err != nil {
		return err
	}
	rows, err := pool.Query(ctx, `SELECT CAST ( Tab.`+listener.PrimaryKeyColumn+
		` AS text) as id FROM `+listener.Table+` AS Tab WHERE (`+condition+`)`)
	if err != nil {
		return err
	}
	// Closing rows returns its connection to the pool; leaking it would make
	// the deferred pool.Close in Listen block forever.
	defer rows.Close()
	for rows.Next() {
		var id string
		if err := rows.Scan(&id); err != nil {
			return err
		}
		out <- request{key, id}
	}
	return rows.Err()
}

// channelName is the identifier shared by the trigger function, the trigger
// and the NOTIFY channel of listener key for cfg's service.
func channelName(cfg *EventConfig, key string) string {
	return "pgxevent:" + cfg.ServiceName + ":" + key
}

// renderCondition executes tmpl with alias as its data, yielding an SQL
// boolean expression over the row referred to by alias.
func renderCondition(tmpl *template.Template, alias string) (string, error) {
	var b strings.Builder
	if err := tmpl.Execute(&b, alias); err != nil {
		return "", err
	}
	return b.String(), nil
}
// EventHandler processes row IDs delivered by Listen. Listen invokes Handle
// from separate goroutines, so calls for different IDs may run concurrently.
type EventHandler interface {
	// Handle receives Listen's (cancellable) context, the pool Listen
	// opened, and the row's primary key rendered as text. A non-nil error
	// makes Listen cancel its context and return that error.
	Handle(
		ctx context.Context,
		pool *pgxpool.Pool,
		id string,
	) error
}