diff --git a/internal/pool/conn.go b/internal/pool/conn.go index 735e6369f4..ce5647aedf 100644 --- a/internal/pool/conn.go +++ b/internal/pool/conn.go @@ -78,18 +78,21 @@ type Conn struct { // background operations that may execute commands, like re-authentication. used atomic.Bool - // Inited flag to mark connection as initialized, this is almost the same as usable + // inited flag to mark connection as initialized, this is almost the same as usable // but it is used to make sure we don't initialize a network connection twice // On handoff, the network connection is replaced, but the Conn struct is reused // this flag will be set to false when the network connection is replaced and // set to true after the new network connection is initialized - Inited atomic.Bool + inited atomic.Bool - pooled bool - pubsub bool - closed atomic.Bool - createdAt time.Time - expiresAt time.Time + // Initializing flag to mark connection as initializing + // This is used to prevent concurrent initialization of the network connection + initializing atomic.Bool + pooled bool + pubsub bool + closed atomic.Bool + createdAt time.Time + expiresAt time.Time // maintenanceNotifications upgrade support: relaxed timeouts during migrations/failovers // Using atomic operations for lock-free access to avoid mutex contention @@ -306,7 +309,27 @@ func (cn *Conn) IsPubSub() bool { } func (cn *Conn) IsInited() bool { - return cn.Inited.Load() + return cn.inited.Load() +} + +func (cn *Conn) SetInited(inited bool) { + cn.inited.Store(inited) +} + +func (cn *Conn) CompareAndSwapInited(old, new bool) bool { + return cn.inited.CompareAndSwap(old, new) +} + +func (cn *Conn) IsInitializing() bool { + return cn.initializing.Load() +} + +func (cn *Conn) SetInitializing(initializing bool) { + cn.initializing.Store(initializing) +} + +func (cn *Conn) CompareAndSwapInitializing(old, new bool) bool { + return cn.initializing.CompareAndSwap(old, new) } // SetRelaxedTimeout sets relaxed timeouts for this connection during maintenanceNotifications upgrades. @@ -478,8 +501,17 @@ func (cn *Conn) GetNetConn() net.Conn { // SetNetConnAndInitConn replaces the underlying connection and executes the initialization. func (cn *Conn) SetNetConnAndInitConn(ctx context.Context, netConn net.Conn) error { + // max retries of 100ms * 20 = 2 second + maxRetries := 20 + for cn.IsInitializing() || cn.IsUsed() { + time.Sleep(100 * time.Millisecond) + maxRetries-- + if maxRetries <= 0 { + return fmt.Errorf("failed to set net conn after %d attempts due to high contention", maxRetries) + } + } // New connection is not initialized yet - cn.Inited.Store(false) + cn.SetInited(false) // Replace the underlying connection cn.SetNetConn(netConn) return cn.ExecuteInitConn(ctx) diff --git a/redis.go b/redis.go index dcd7b59a78..0505ffe376 100644 --- a/redis.go +++ b/redis.go @@ -366,9 +366,20 @@ func (c *baseClient) wrappedOnClose(newOnClose func() error) func() error { } func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { - if !cn.Inited.CompareAndSwap(false, true) { + if !cn.CompareAndSwapInited(false, true) { return nil } + + defer func() { + // if the initialization did not complete successfully + // we need to mark the connection as not initialized + if cn.CompareAndSwapInitializing(true, false) { + internal.Logger.Printf(ctx, "redis: failed to initialize connection conn[%d]", cn.GetID()) + cn.SetInited(false) + } + }() + + cn.SetInitializing(true) var err error connPool := pool.NewSingleConnPool(c.connPool, cn) conn := newConn(c.opt, connPool, &c.hooksMixin) @@ -510,14 +521,14 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { } } + // Set the connection initialization function for potential reconnections + cn.SetInitConnFunc(c.createInitConnFunc()) + // mark the connection as usable and inited // once returned to the pool as idle, this connection can be used by other clients cn.SetUsable(true) cn.SetUsed(false) - cn.Inited.Store(true) - - // Set the connection initialization function for potential reconnections - cn.SetInitConnFunc(c.createInitConnFunc()) + cn.SetInitializing(false) if c.opt.OnConnect != nil { return c.opt.OnConnect(ctx, conn)