From 67ab52b38a2002116920fb4711008af14c1f633e Mon Sep 17 00:00:00 2001 From: ljluestc Date: Sun, 2 Nov 2025 00:21:22 -0700 Subject: [PATCH] adding tests --- connection.go | 7 +- reprepare_test.go | 187 ++++++++++++++++++++++++++++++++++++++++++ statement.go | 204 +++++++++++++++++++++++++++++----------------- 3 files changed, 319 insertions(+), 79 deletions(-) create mode 100644 reprepare_test.go diff --git a/connection.go b/connection.go index 5648e47d..207844ef 100644 --- a/connection.go +++ b/connection.go @@ -216,9 +216,10 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { return nil, driver.ErrBadConn } - stmt := &mysqlStmt{ - mc: mc, - } + stmt := &mysqlStmt{ + mc: mc, + queryString: query, + } // Read Result columnCount, err := stmt.readPrepareResultPacket() diff --git a/reprepare_test.go b/reprepare_test.go new file mode 100644 index 00000000..588014a3 --- /dev/null +++ b/reprepare_test.go @@ -0,0 +1,187 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2025 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "database/sql" + "testing" + "time" +) + +// Ensures that executing a prepared statement still returns correct data +// after a DDL that changes a column type. This validates automatic +// reprepare on ER_NEED_REPREPARE-capable servers and correctness in general. +func TestPreparedStmtReprepareAfterDDL(t *testing.T) { + runTests(t, dsn+"&parseTime=true", func(dbt *DBTest) { + db := dbt.db + + dbt.mustExec("DROP TABLE IF EXISTS reprepare_test") + dbt.mustExec(` + CREATE TABLE reprepare_test ( + id INT AUTO_INCREMENT PRIMARY KEY, + state TINYINT, + round TINYINT NOT NULL DEFAULT 0, + remark TEXT, + ctime TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + )`) + t.Cleanup(func() { db.Exec("DROP TABLE IF EXISTS reprepare_test") }) + + dbt.mustExec("INSERT INTO reprepare_test(state, round, remark) VALUES (1, 1, 'hello')") + + stmt, err := db.Prepare("SELECT state, round, remark, ctime FROM reprepare_test WHERE id=?") + if err != nil { + t.Fatalf("prepare failed: %v", err) + } + defer stmt.Close() + + var ( + s1, r1 int + rem1 string + ct1 time.Time + ) + if err := stmt.QueryRow(1).Scan(&s1, &r1, &rem1, &ct1); err != nil { + t.Fatalf("first scan failed: %v", err) + } + if s1 != 1 || r1 != 1 || rem1 != "hello" || ct1.IsZero() { + t.Fatalf("unexpected first row values: (%d,%d,%q,%v)", s1, r1, rem1, ct1) + } + + // Change the column type that participates in the prepared statement's result set. + dbt.mustExec("ALTER TABLE reprepare_test MODIFY state INT") + + var ( + s2, r2 int + rem2 string + ct2 time.Time + ) + // This used to fail or return incorrect data on some servers without reprepare handling. + if err := stmt.QueryRow(1).Scan(&s2, &r2, &rem2, &ct2); err != nil { + // Some environments may not reproduce ER_NEED_REPREPARE, so avoid flakiness by surfacing the error. + t.Fatalf("second scan failed: %v", err) + } + + if s2 != s1 || r2 != r1 || rem2 != rem1 || ct2.IsZero() { + t.Fatalf("unexpected second row values after DDL: got (%d,%d,%q,%v), want (%d,%d,%q,)", + s2, r2, rem2, ct2, s1, r1, rem1, + ) + } + }) +} + +// Validates Exec path also reprovisions the prepared statement after DDL. +func TestPreparedStmtExecReprepareAfterDDL(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + db := dbt.db + + dbt.mustExec("DROP TABLE IF EXISTS reprepare_exec_test") + dbt.mustExec(` + CREATE TABLE reprepare_exec_test ( + id INT AUTO_INCREMENT PRIMARY KEY, + value INT NOT NULL + )`) + t.Cleanup(func() { db.Exec("DROP TABLE IF EXISTS reprepare_exec_test") }) + + stmt, err := db.Prepare("INSERT INTO reprepare_exec_test(value) VALUES (?)") + if err != nil { + t.Fatalf("prepare failed: %v", err) + } + defer stmt.Close() + + if _, err := stmt.Exec(1); err != nil { + t.Fatalf("first exec failed: %v", err) + } + + // Change the column type to trigger metadata invalidation on some servers. + dbt.mustExec("ALTER TABLE reprepare_exec_test MODIFY value BIGINT") + + if _, err := stmt.Exec(2); err != nil { + t.Fatalf("second exec (after DDL) failed: %v", err) + } + + // Verify both rows are present and correct. + rows := dbt.mustQuery("SELECT value FROM reprepare_exec_test ORDER BY id") + defer rows.Close() + var got []int + for rows.Next() { + var v int + if err := rows.Scan(&v); err != nil { + t.Fatalf("scan values failed: %v", err) + } + got = append(got, v) + } + if len(got) != 2 || got[0] != 1 || got[1] != 2 { + t.Fatalf("unexpected values: %v", got) + } + }) +} + +// Ensures repeated scans using the same prepared statement remain correct across DDL, scanning into sql.NullTime. +func TestPreparedStmtReprepareMultipleScansAfterDDL_NullTime(t *testing.T) { + runTests(t, dsn+"&parseTime=true", func(dbt *DBTest) { + db := dbt.db + + dbt.mustExec("DROP TABLE IF EXISTS reprepare_multi_test") + dbt.mustExec(` + CREATE TABLE reprepare_multi_test ( + id INT AUTO_INCREMENT PRIMARY KEY, + state TINYINT, + ctime TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + )`) + t.Cleanup(func() { db.Exec("DROP TABLE IF EXISTS reprepare_multi_test") }) + + dbt.mustExec("INSERT INTO reprepare_multi_test(state) VALUES (5)") + + stmt, err := db.Prepare("SELECT state, ctime FROM reprepare_multi_test WHERE id=?") + if err != nil { + t.Fatalf("prepare failed: %v", err) + } + defer stmt.Close() + + // First scan + { + var s int + var ct sql.NullTime + if err := stmt.QueryRow(1).Scan(&s, &ct); err != nil { + t.Fatalf("first scan failed: %v", err) + } + if s != 5 || !ct.Valid || ct.Time.IsZero() { + t.Fatalf("unexpected first values: (%d,%v)", s, ct) + } + } + + // DDL change that alters one of the selected column types + dbt.mustExec("ALTER TABLE reprepare_multi_test MODIFY state INT") + + // Second scan after DDL + { + var s int + var ct sql.NullTime + if err := stmt.QueryRow(1).Scan(&s, &ct); err != nil { + t.Fatalf("second scan failed: %v", err) + } + if s != 5 || !ct.Valid || ct.Time.IsZero() { + t.Fatalf("unexpected second values after DDL: (%d,%v)", s, ct) + } + } + + // Third scan to ensure continued usability + { + var s int + var ct sql.NullTime + if err := stmt.QueryRow(1).Scan(&s, &ct); err != nil { + t.Fatalf("third scan failed: %v", err) + } + if s != 5 || !ct.Valid || ct.Time.IsZero() { + t.Fatalf("unexpected third values after DDL: (%d,%v)", s, ct) + } + } + }) +} + + diff --git a/statement.go b/statement.go index 2db8960e..4fc2b6f9 100644 --- a/statement.go +++ b/statement.go @@ -21,6 +21,7 @@ type mysqlStmt struct { id uint32 paramCount int columns []mysqlField + queryString string } func (stmt *mysqlStmt) Close() error { @@ -52,49 +53,61 @@ func (stmt *mysqlStmt) CheckNamedValue(nv *driver.NamedValue) (err error) { } func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { - if stmt.mc.closed.Load() { - return nil, driver.ErrBadConn - } - // Send command - err := stmt.writeExecutePacket(args) - if err != nil { - return nil, stmt.mc.markBadConn(err) - } + for attempt := 0; attempt < 2; attempt++ { + if stmt.mc.closed.Load() { + return nil, driver.ErrBadConn + } + // Send command + err := stmt.writeExecutePacket(args) + if err != nil { + return nil, stmt.mc.markBadConn(err) + } - mc := stmt.mc - handleOk := stmt.mc.clearResult() + mc := stmt.mc + handleOk := stmt.mc.clearResult() - // Read Result - resLen, metadataFollows, err := handleOk.readResultSetHeaderPacket() - if err != nil { - return nil, err - } + // Read Result + resLen, metadataFollows, err := handleOk.readResultSetHeaderPacket() + if err != nil { + if me, ok := err.(*MySQLError); ok && me.Number == 1615 /* ER_NEED_REPREPARE */ { + if rerr := stmt.reprepare(); rerr == nil { + // Retry once after successful reprepare + continue + } else { + return nil, rerr + } + } + return nil, err + } - if resLen > 0 { - // Columns - if metadataFollows && stmt.mc.extCapabilities&clientCacheMetadata != 0 { - // we can not skip column metadata because next stmt.Query() may use it. - if stmt.columns, err = mc.readColumns(resLen, stmt.columns); err != nil { - return nil, err - } - } else { - if err = mc.skipColumns(resLen); err != nil { - return nil, err - } - } + if resLen > 0 { + // Columns + if metadataFollows && stmt.mc.extCapabilities&clientCacheMetadata != 0 { + // we can not skip column metadata because next stmt.Query() may use it. + if stmt.columns, err = mc.readColumns(resLen, stmt.columns); err != nil { + return nil, err + } + } else { + if err = mc.skipColumns(resLen); err != nil { + return nil, err + } + } - // Rows - if err = mc.skipRows(); err != nil { - return nil, err - } - } + // Rows + if err = mc.skipRows(); err != nil { + return nil, err + } + } - if err := handleOk.discardResults(); err != nil { - return nil, err - } + if err := handleOk.discardResults(); err != nil { + return nil, err + } - copied := mc.result - return &copied, nil + copied := mc.result + return &copied, nil + } + // Should not reach here + return nil, ErrInvalidConn } func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { @@ -102,51 +115,90 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { } func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { - if stmt.mc.closed.Load() { - return nil, driver.ErrBadConn - } - // Send command - err := stmt.writeExecutePacket(args) - if err != nil { - return nil, stmt.mc.markBadConn(err) - } + for attempt := 0; attempt < 2; attempt++ { + if stmt.mc.closed.Load() { + return nil, driver.ErrBadConn + } + // Send command + err := stmt.writeExecutePacket(args) + if err != nil { + return nil, stmt.mc.markBadConn(err) + } - mc := stmt.mc + mc := stmt.mc - // Read Result - handleOk := stmt.mc.clearResult() - resLen, metadataFollows, err := handleOk.readResultSetHeaderPacket() - if err != nil { - return nil, err - } + // Read Result + handleOk := stmt.mc.clearResult() + resLen, metadataFollows, err := handleOk.readResultSetHeaderPacket() + if err != nil { + if me, ok := err.(*MySQLError); ok && me.Number == 1615 /* ER_NEED_REPREPARE */ { + if rerr := stmt.reprepare(); rerr == nil { + // Retry once after successful reprepare + continue + } else { + return nil, rerr + } + } + return nil, err + } - rows := new(binaryRows) + rows := new(binaryRows) - if resLen > 0 { - rows.mc = mc - if metadataFollows { - if rows.rs.columns, err = mc.readColumns(resLen, stmt.columns); err != nil { - return nil, err - } - stmt.columns = rows.rs.columns - } else { - if err = mc.skipEof(); err != nil { - return nil, err - } - rows.rs.columns = stmt.columns - } - } else { - rows.rs.done = true + if resLen > 0 { + rows.mc = mc + if metadataFollows { + if rows.rs.columns, err = mc.readColumns(resLen, stmt.columns); err != nil { + return nil, err + } + stmt.columns = rows.rs.columns + } else { + if err = mc.skipEof(); err != nil { + return nil, err + } + rows.rs.columns = stmt.columns + } + } else { + rows.rs.done = true - switch err := rows.NextResultSet(); err { - case nil, io.EOF: - return rows, nil - default: - return nil, err - } - } + switch err := rows.NextResultSet(); err { + case nil, io.EOF: + return rows, nil + default: + return nil, err + } + } + + return rows, err + } + return nil, ErrInvalidConn +} + +// reprepare closes the current prepared statement on the server and prepares it again. +// It preserves the receiver so that callers can continue to use the same *mysqlStmt. +func (stmt *mysqlStmt) reprepare() error { + if stmt.mc == nil || stmt.queryString == "" { + return ErrInvalidConn + } + // Prepare a new statement on the same connection. + newStmt, err := stmt.mc.Prepare(stmt.queryString) + if err != nil { + return err + } + // Close the old statement id on the server. + oldID := stmt.id + _ = stmt.mc.writeCommandPacketUint32(comStmtClose, oldID) - return rows, err + // Copy fields from the newly prepared statement into this receiver. + if ns, ok := newStmt.(*mysqlStmt); ok { + stmt.id = ns.id + stmt.paramCount = ns.paramCount + stmt.columns = ns.columns + // Detach the temporary stmt to avoid double-close semantics. + ns.mc = nil + return nil + } + // Should not happen as we control Prepare, but guard anyway. + return ErrInvalidConn } var jsonType = reflect.TypeOf(json.RawMessage{})