|
| 1 | +package streaming |
| 2 | + |
| 3 | +import ( |
| 4 | + "errors" |
| 5 | + "time" |
| 6 | + |
| 7 | + "github.com/redis/go-redis/v9/auth" |
| 8 | + "github.com/redis/go-redis/v9/internal/pool" |
| 9 | +) |
| 10 | + |
| 11 | +// Manager coordinates streaming credentials and re-authentication for a connection pool. |
| 12 | +// |
| 13 | +// The manager is responsible for: |
| 14 | +// - Creating and managing per-connection credentials listeners |
| 15 | +// - Providing the pool hook for re-authentication |
| 16 | +// - Coordinating between credentials updates and pool operations |
| 17 | +// |
| 18 | +// When credentials change via a StreamingCredentialsProvider: |
| 19 | +// 1. The credentials listener (ConnReAuthCredentialsListener) receives the update |
| 20 | +// 2. It calls MarkForReAuth on the manager |
| 21 | +// 3. The manager delegates to the pool hook |
| 22 | +// 4. The pool hook schedules background re-authentication |
| 23 | +// |
| 24 | +// The manager maintains a registry of credentials listeners indexed by connection ID, |
| 25 | +// allowing listener reuse when connections are reinitialized (e.g., after handoff). |
| 26 | +type Manager struct { |
| 27 | + // credentialsListeners maps connection ID to credentials listener |
| 28 | + credentialsListeners *CredentialsListeners |
| 29 | + |
| 30 | + // pool is the connection pool being managed |
| 31 | + pool pool.Pooler |
| 32 | + |
| 33 | + // poolHookRef is the re-authentication pool hook |
| 34 | + poolHookRef *ReAuthPoolHook |
| 35 | +} |
| 36 | + |
| 37 | +// NewManager creates a new streaming credentials manager. |
| 38 | +// |
| 39 | +// Parameters: |
| 40 | +// - pl: The connection pool to manage |
| 41 | +// - reAuthTimeout: Maximum time to wait for acquiring a connection for re-authentication |
| 42 | +// |
| 43 | +// The manager creates a ReAuthPoolHook sized to match the pool size, ensuring that |
| 44 | +// re-auth operations don't exhaust the connection pool. |
| 45 | +func NewManager(pl pool.Pooler, reAuthTimeout time.Duration) *Manager { |
| 46 | + m := &Manager{ |
| 47 | + pool: pl, |
| 48 | + poolHookRef: NewReAuthPoolHook(pl.Size(), reAuthTimeout), |
| 49 | + credentialsListeners: NewCredentialsListeners(), |
| 50 | + } |
| 51 | + m.poolHookRef.manager = m |
| 52 | + return m |
| 53 | +} |
| 54 | + |
| 55 | +// PoolHook returns the pool hook for re-authentication. |
| 56 | +// |
| 57 | +// This hook should be registered with the connection pool to enable |
| 58 | +// automatic re-authentication when credentials change. |
| 59 | +func (m *Manager) PoolHook() pool.PoolHook { |
| 60 | + return m.poolHookRef |
| 61 | +} |
| 62 | + |
| 63 | +// Listener returns or creates a credentials listener for a connection. |
| 64 | +// |
| 65 | +// This method is called during connection initialization to set up the |
| 66 | +// credentials listener. If a listener already exists for the connection ID |
| 67 | +// (e.g., after a handoff), it is reused. |
| 68 | +// |
| 69 | +// Parameters: |
| 70 | +// - poolCn: The connection to create/get a listener for |
| 71 | +// - reAuth: Function to re-authenticate the connection with new credentials |
| 72 | +// - onErr: Function to call when re-authentication fails |
| 73 | +// |
| 74 | +// Returns: |
| 75 | +// - auth.CredentialsListener: The listener to subscribe to the credentials provider |
| 76 | +// - error: Non-nil if poolCn is nil |
| 77 | +// |
| 78 | +// Note: The reAuth and onErr callbacks are captured once when the listener is |
| 79 | +// created and reused for the connection's lifetime. They should not change. |
| 80 | +// |
| 81 | +// Thread-safe: Can be called concurrently during connection initialization. |
| 82 | +func (m *Manager) Listener( |
| 83 | + poolCn *pool.Conn, |
| 84 | + reAuth func(*pool.Conn, auth.Credentials) error, |
| 85 | + onErr func(*pool.Conn, error), |
| 86 | +) (auth.CredentialsListener, error) { |
| 87 | + if poolCn == nil { |
| 88 | + return nil, errors.New("poolCn cannot be nil") |
| 89 | + } |
| 90 | + connID := poolCn.GetID() |
| 91 | + // if we reconnect the underlying network connection, the streaming credentials listener will continue to work |
| 92 | + // so we can get the old listener from the cache and use it. |
| 93 | + // subscribing the same (an already subscribed) listener for a StreamingCredentialsProvider SHOULD be a no-op |
| 94 | + listener, ok := m.credentialsListeners.Get(connID) |
| 95 | + if !ok || listener == nil { |
| 96 | + // Create new listener for this connection |
| 97 | + // Note: Callbacks (reAuth, onErr) are captured once and reused for the connection's lifetime |
| 98 | + newCredListener := &ConnReAuthCredentialsListener{ |
| 99 | + conn: poolCn, |
| 100 | + reAuth: reAuth, |
| 101 | + onErr: onErr, |
| 102 | + manager: m, |
| 103 | + } |
| 104 | + |
| 105 | + m.credentialsListeners.Add(connID, newCredListener) |
| 106 | + listener = newCredListener |
| 107 | + } |
| 108 | + return listener, nil |
| 109 | +} |
| 110 | + |
| 111 | +// MarkForReAuth marks a connection for re-authentication. |
| 112 | +// |
| 113 | +// This method is called by the credentials listener when new credentials are |
| 114 | +// received. It delegates to the pool hook to schedule background re-authentication. |
| 115 | +// |
| 116 | +// Parameters: |
| 117 | +// - poolCn: The connection to re-authenticate |
| 118 | +// - reAuthFn: Function to call for re-authentication, receives error if acquisition fails |
| 119 | +// |
| 120 | +// Thread-safe: Called by credentials listeners when credentials change. |
| 121 | +func (m *Manager) MarkForReAuth(poolCn *pool.Conn, reAuthFn func(error)) { |
| 122 | + connID := poolCn.GetID() |
| 123 | + m.poolHookRef.MarkForReAuth(connID, reAuthFn) |
| 124 | +} |
| 125 | + |
| 126 | +// RemoveListener removes the credentials listener for a connection. |
| 127 | +// |
| 128 | +// This method is called by the pool hook's OnRemove to clean up listeners |
| 129 | +// when connections are removed from the pool. |
| 130 | +// |
| 131 | +// Parameters: |
| 132 | +// - connID: The connection ID whose listener should be removed |
| 133 | +// |
| 134 | +// Thread-safe: Called during connection removal. |
| 135 | +func (m *Manager) RemoveListener(connID uint64) { |
| 136 | + m.credentialsListeners.Remove(connID) |
| 137 | +} |
0 commit comments