236 lines
6.1 KiB
Go
236 lines
6.1 KiB
Go
|
package pgxevent
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"strings"
|
||
|
"text/template"
|
||
|
|
||
|
pgx "github.com/jackc/pgx/v4"
|
||
|
pgxpool "github.com/jackc/pgx/v4/pgxpool"
|
||
|
)
|
||
|
|
||
|
// Listener binds a condition on a data table to a handler: when the
|
||
|
// condition is (or becomes) true for a row, the row ID is passed to
|
||
|
// Handler.Handle
|
||
|
type Listener struct {
|
||
|
Table string // Table name, may be double-quoted
|
||
|
PrimaryKeyColumn string // Primary key column
|
||
|
ConditionTemplate *template.Template // WHERE condition with {{.}} substituted for table alias
|
||
|
Handler EventHandler // Interface
|
||
|
}
|
||
|
|
||
|
type EventConfig struct {
|
||
|
DatabaseURL string
|
||
|
ServiceName string
|
||
|
ServiceVersion string
|
||
|
Listeners map[string]Listener
|
||
|
}
|
||
|
|
||
|
// request asks Listen's dispatch loop to run the handler of the listener
// keyed Channel for the row whose ID is Payload.
type request struct {
	Channel, Payload string
}
|
||
|
|
||
|
// completion reports that a goroutine started by Listen has finished.
// Handler goroutines fill in Channel and Payload; the notification pump and
// the initial scan leave them empty. Error is non-nil when the work failed.
type completion struct {
	Channel, Payload string
	Error            error
}
|
||
|
|
||
|
func Listen(ctx context.Context, cfg *EventConfig) error {
|
||
|
ctx, cancel := context.WithCancel(ctx)
|
||
|
_ = cancel
|
||
|
pool, err := pgxpool.Connect(ctx, cfg.DatabaseURL)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
defer pool.Close()
|
||
|
conn, err := pool.Acquire(ctx)
|
||
|
defer conn.Release()
|
||
|
|
||
|
err = conn.BeginFunc(ctx, func(tx pgx.Tx) error {
|
||
|
_, err = tx.Exec(ctx, `CREATE TABLE IF NOT EXISTS pgxevent
|
||
|
(service TEXT PRIMARY KEY, version TEXT)`)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
row := tx.QueryRow(ctx, `SELECT version FROM pgxevent
|
||
|
WHERE service = $1`, cfg.ServiceName)
|
||
|
needsUpdate := false
|
||
|
ver := ""
|
||
|
err := row.Scan(&ver)
|
||
|
if err == pgx.ErrNoRows {
|
||
|
_, err := tx.Exec(ctx, `INSERT INTO
|
||
|
pgxevent(service, version) VALUES ($1,$2)`,
|
||
|
cfg.ServiceName, cfg.ServiceVersion)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
needsUpdate = true
|
||
|
} else if err == nil {
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
needsUpdate = (ver != cfg.ServiceVersion)
|
||
|
_, err := tx.Exec(ctx, `UPDATE pgxevent
|
||
|
SET version=$1 WHERE service=$2`,
|
||
|
cfg.ServiceVersion, cfg.ServiceName)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
} else {
|
||
|
return err
|
||
|
}
|
||
|
if needsUpdate {
|
||
|
for key, value := range cfg.Listeners {
|
||
|
triggerName := "pgxevent:" + cfg.ServiceName + ":" + key
|
||
|
condBuffer := strings.Builder{}
|
||
|
err = value.ConditionTemplate.Execute(&condBuffer, "new")
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
conditionNew := condBuffer.String()
|
||
|
condBuffer = strings.Builder{}
|
||
|
err = value.ConditionTemplate.Execute(&condBuffer, "old")
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
conditionOld := condBuffer.String()
|
||
|
|
||
|
condition := "(" + conditionNew + ") and (old is null or not (" + conditionOld + "))"
|
||
|
|
||
|
_, err := tx.Exec(ctx,
|
||
|
`CREATE OR REPLACE FUNCTION "`+triggerName+`"()
|
||
|
RETURNS TRIGGER AS $$
|
||
|
DECLARE
|
||
|
BEGIN
|
||
|
IF (`+condition+`) THEN
|
||
|
PERFORM pg_notify('`+triggerName+`', CAST(NEW.`+value.PrimaryKeyColumn+` AS TEXT));
|
||
|
END IF;
|
||
|
RETURN NEW;
|
||
|
END $$ LANGUAGE PLPGSQL
|
||
|
`)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
_, err = tx.Exec(ctx, `DROP TRIGGER IF EXISTS "`+
|
||
|
triggerName+`" ON `+value.Table)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
_, err = tx.Exec(ctx, `CREATE TRIGGER "`+triggerName+`"`+`
|
||
|
AFTER INSERT OR UPDATE ON `+value.Table+`
|
||
|
FOR EACH ROW
|
||
|
EXECUTE PROCEDURE "`+triggerName+`"()
|
||
|
`)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
}
|
||
|
}
|
||
|
return nil
|
||
|
})
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
running := make(map[string]map[string]bool)
|
||
|
for key, _ := range cfg.Listeners {
|
||
|
triggerName := "pgxevent:" + cfg.ServiceName + ":" + key
|
||
|
_, err := conn.Exec(ctx, `LISTEN "`+triggerName+`"`)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
running[key] = make(map[string]bool)
|
||
|
}
|
||
|
var cascadingError error
|
||
|
startChannel := make(chan request)
|
||
|
completionChannel := make(chan completion)
|
||
|
counter := 2
|
||
|
teardown := false
|
||
|
|
||
|
go func() {
|
||
|
prefix := "pgxevent:" + cfg.ServiceName + ":"
|
||
|
for {
|
||
|
notification, err := conn.Conn().WaitForNotification(ctx)
|
||
|
if err != nil {
|
||
|
completionChannel <- completion{"", "", err}
|
||
|
return
|
||
|
}
|
||
|
startChannel <- request{notification.Channel[len(prefix):],
|
||
|
notification.Payload}
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
go func() {
|
||
|
for channel, notification := range cfg.Listeners {
|
||
|
condBuffer := strings.Builder{}
|
||
|
err := notification.ConditionTemplate.Execute(&condBuffer, "Tab")
|
||
|
if err != nil {
|
||
|
completionChannel <- completion{"", "", err}
|
||
|
return
|
||
|
}
|
||
|
rows, err := pool.Query(ctx, `SELECT CAST ( Tab.`+
|
||
|
notification.PrimaryKeyColumn+` AS text) as id FROM "`+
|
||
|
notification.Table+`" AS Tab WHERE`+
|
||
|
`(`+condBuffer.String()+`)`)
|
||
|
if err != nil {
|
||
|
completionChannel <- completion{"", "", err}
|
||
|
return
|
||
|
}
|
||
|
for rows.Next() {
|
||
|
var payload string
|
||
|
err = rows.Scan(&payload)
|
||
|
if err != nil {
|
||
|
completionChannel <- completion{"", "", err}
|
||
|
return
|
||
|
}
|
||
|
startChannel <- request{channel, payload}
|
||
|
}
|
||
|
|
||
|
}
|
||
|
completionChannel <- completion{}
|
||
|
return
|
||
|
}()
|
||
|
|
||
|
for {
|
||
|
select {
|
||
|
case start := <-startChannel:
|
||
|
if !teardown {
|
||
|
if !running[start.Channel][start.Payload] {
|
||
|
counter++
|
||
|
running[start.Channel][start.Payload] = true
|
||
|
go func(req request) {
|
||
|
err := cfg.Listeners[start.Channel].
|
||
|
Handler.Handle(ctx, pool, req.Payload)
|
||
|
completionChannel <- completion{req.Channel, req.Payload, err}
|
||
|
}(start)
|
||
|
|
||
|
}
|
||
|
}
|
||
|
case complete := <-completionChannel:
|
||
|
if complete.Error != nil {
|
||
|
if !teardown {
|
||
|
teardown = true
|
||
|
cancel()
|
||
|
cascadingError = complete.Error
|
||
|
}
|
||
|
}
|
||
|
if complete.Channel != "" {
|
||
|
delete(running[complete.Channel], complete.Payload)
|
||
|
}
|
||
|
counter--
|
||
|
if counter == 0 {
|
||
|
return cascadingError
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
type EventHandler interface {
|
||
|
Handle(ctx context.Context, pool *pgxpool.Pool, id string) error
|
||
|
}
|