Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions notify.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package pq
// This module contains support for Postgres LISTEN/NOTIFY.

import (
"context"
"database/sql/driver"
"errors"
"fmt"
Expand Down Expand Up @@ -40,6 +41,51 @@ func SetNotificationHandler(c driver.Conn, handler func(*Notification)) {
c.(*conn).notificationHandler = handler
}

// NotificationHandlerConnector wraps a regular connector and sets a notification handler
// on it.
//
// The wrapped connector is embedded, so its methods are promoted onto this
// type; Connect is overridden to install the handler on each connection the
// wrapped connector produces.
type NotificationHandlerConnector struct {
// Connector is the wrapped connector that Connect delegates to.
driver.Connector
// notificationHandler is passed to SetNotificationHandler for every
// connection returned by a successful Connect. It may be nil.
notificationHandler func(*Notification)
}

// Connect opens a connection through the wrapped connector and, when that
// succeeds, installs the configured notification handler on the new
// connection before returning it. If the wrapped connector reports an error,
// its results are returned unchanged and no handler is installed.
func (n *NotificationHandlerConnector) Connect(ctx context.Context) (driver.Conn, error) {
	cn, err := n.Connector.Connect(ctx)
	if err != nil {
		return cn, err
	}
	SetNotificationHandler(cn, n.notificationHandler)
	return cn, nil
}

// ConnectorNotificationHandler reports the notification handler currently
// stored on c. It returns nil when c is not a *NotificationHandlerConnector
// (that is, was not produced by ConnectorWithNotificationHandler) or when no
// handler is set.
func ConnectorNotificationHandler(c driver.Connector) func(*Notification) {
	wrapped, ok := c.(*NotificationHandlerConnector)
	if !ok {
		return nil
	}
	return wrapped.notificationHandler
}

// ConnectorWithNotificationHandler returns a connector that installs handler
// on every connection it opens.
//
// If c is already a *NotificationHandlerConnector (i.e. the result of an
// earlier call to this function), its handler is replaced in place and the
// same connector is returned. Otherwise a new connector wrapping c is
// returned. Passing a nil handler unsets it.
//
// The returned connector is intended to be used with database/sql.OpenDB.
//
// Note: Notification handlers are executed synchronously by pq meaning commands
// won't continue to be processed until the handler returns.
func ConnectorWithNotificationHandler(c driver.Connector, handler func(*Notification)) *NotificationHandlerConnector {
	existing, ok := c.(*NotificationHandlerConnector)
	if !ok {
		// Not one of ours yet: wrap it.
		return &NotificationHandlerConnector{Connector: c, notificationHandler: handler}
	}
	// Already wrapped: swap the handler on the existing connector.
	existing.notificationHandler = handler
	return existing
}

const (
connStateIdle int32 = iota
connStateExpectResponse
Expand Down
42 changes: 42 additions & 0 deletions notify_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package pq

import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -568,3 +570,43 @@ func TestListenerPing(t *testing.T) {
t.Fatalf("expected errListenerClosed; got %v", err)
}
}

// TestConnectorWithNotificationHandler_Simple checks that a handler set via
// ConnectorWithNotificationHandler receives notifications, that setting a nil
// handler on the same connector stops delivery, and that setting a handler
// again resumes it -- all without a new wrapping connector being created.
func TestConnectorWithNotificationHandler_Simple(t *testing.T) {
	base, err := NewConnector("")
	if err != nil {
		t.Fatal(err)
	}

	// received holds the last notification delivered to the handler.
	var received *Notification
	record := func(n *Notification) { received = n }

	// Step 1: wrap the base connector with a handler that records into received.
	wrapped := ConnectorWithNotificationHandler(base, record)
	sendNotification(wrapped, t, "Test notification #1")
	if received == nil || received.Extra != "Test notification #1" {
		t.Fatalf("Expected notification w/ message, got %v", received)
	}

	// Step 2: unset the handler; the very same connector must come back and
	// the previously recorded notification must stay untouched.
	cleared := ConnectorWithNotificationHandler(wrapped, nil)
	if cleared != wrapped {
		t.Fatalf("Expected to not create new connector but did")
	}
	sendNotification(cleared, t, "Test notification #2")
	if received == nil || received.Extra != "Test notification #1" {
		t.Fatalf("Expected notification to not change, got %v", received)
	}

	// Step 3: set a handler again on the same connector and expect delivery
	// to resume.
	restored := ConnectorWithNotificationHandler(cleared, record)
	if restored != wrapped {
		t.Fatal("Expected to not create new connector but did")
	}
	sendNotification(restored, t, "Test notification #3")
	if received == nil || received.Extra != "Test notification #3" {
		t.Fatalf("Expected notification w/ message, got %v", received)
	}
}

// sendNotification opens a database handle on connector c, issues a LISTEN
// followed by a NOTIFY on channel "foo" in a single Exec, and fails the test
// if that Exec returns an error. The handle is closed before returning.
//
// escapedNotification is interpolated verbatim into the SQL text as the
// NOTIFY payload literal, so the caller must supply a value that is already
// safe to place between single quotes.
func sendNotification(c driver.Connector, t *testing.T, escapedNotification string) {
	// Mark as a helper so a failure is attributed to the calling test line
	// instead of a line inside this function.
	t.Helper()

	db := sql.OpenDB(c)
	defer db.Close()

	// Named query rather than sql so the database/sql package is not shadowed.
	query := fmt.Sprintf("LISTEN foo; NOTIFY foo, '%s';", escapedNotification)
	if _, err := db.Exec(query); err != nil {
		t.Fatal(err)
	}
}