@@ -172,7 +172,7 @@ func (mc *mysqlConn) close() {
172172}
173173
174174// Closes the network connection and unsets internal variables. Do not call this
175- // function after successfully authentication, call Close instead. This function
175+ // function after successful authentication, call Close instead. This function
176176// is called before auth or on auth failure because MySQL will have already
177177// closed the network connection.
178178func (mc * mysqlConn ) cleanup () {
@@ -245,9 +245,105 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
245245 return stmt , err
246246}
247247
248+ // findParamPositions returns the positions of real parameter holders ('?') in the query, ignoring those in comments, strings, or backticks.
249+ func findParamPositions (query string , noBackslashEscapes bool ) []int {
250+ const (
251+ stateNormal = iota
252+ stateString
253+ stateEscape
254+ stateEOLComment
255+ stateSlashStarComment
256+ stateBacktick
257+ )
258+
259+ var (
260+ QUOTE_BYTE = byte ('\'' )
261+ DBL_QUOTE_BYTE = byte ('"' )
262+ BACKSLASH_BYTE = byte ('\\' )
263+ QUESTION_MARK_BYTE = byte ('?' )
264+ SLASH_BYTE = byte ('/' )
265+ STAR_BYTE = byte ('*' )
266+ HASH_BYTE = byte ('#' )
267+ MINUS_BYTE = byte ('-' )
268+ LINE_FEED_BYTE = byte ('\n' )
269+ RADICAL_BYTE = byte ('`' )
270+ )
271+
272+ paramPositions := make ([]int , 0 )
273+ state := stateNormal
274+ singleQuotes := false
275+ lastChar := byte (0 )
276+ lenq := len (query )
277+ for i := 0 ; i < lenq ; i ++ {
278+ currentChar := query [i ]
279+ if state == stateEscape && ! ((currentChar == QUOTE_BYTE && singleQuotes ) || (currentChar == DBL_QUOTE_BYTE && ! singleQuotes )) {
280+ state = stateString
281+ lastChar = currentChar
282+ continue
283+ }
284+ switch currentChar {
285+ case STAR_BYTE :
286+ if state == stateNormal && lastChar == SLASH_BYTE {
287+ state = stateSlashStarComment
288+ }
289+ case SLASH_BYTE :
290+ if state == stateSlashStarComment && lastChar == STAR_BYTE {
291+ state = stateNormal
292+ }
293+ case HASH_BYTE :
294+ if state == stateNormal {
295+ state = stateEOLComment
296+ }
297+ case MINUS_BYTE :
298+ if state == stateNormal && lastChar == MINUS_BYTE {
299+ state = stateEOLComment
300+ }
301+ case LINE_FEED_BYTE :
302+ if state == stateEOLComment {
303+ state = stateNormal
304+ }
305+ case DBL_QUOTE_BYTE :
306+ if state == stateNormal {
307+ state = stateString
308+ singleQuotes = false
309+ } else if state == stateString && ! singleQuotes {
310+ state = stateNormal
311+ } else if state == stateEscape {
312+ state = stateString
313+ }
314+ case QUOTE_BYTE :
315+ if state == stateNormal {
316+ state = stateString
317+ singleQuotes = true
318+ } else if state == stateString && singleQuotes {
319+ state = stateNormal
320+ } else if state == stateEscape {
321+ state = stateString
322+ }
323+ case BACKSLASH_BYTE :
324+ if state == stateString && ! noBackslashEscapes {
325+ state = stateEscape
326+ }
327+ case QUESTION_MARK_BYTE :
328+ if state == stateNormal {
329+ paramPositions = append (paramPositions , i )
330+ }
331+ case RADICAL_BYTE :
332+ if state == stateBacktick {
333+ state = stateNormal
334+ } else if state == stateNormal {
335+ state = stateBacktick
336+ }
337+ }
338+ lastChar = currentChar
339+ }
340+ return paramPositions
341+ }
342+
248343func (mc * mysqlConn ) interpolateParams (query string , args []driver.Value ) (string , error ) {
249- // Number of ? should be same to len(args)
250- if strings .Count (query , "?" ) != len (args ) {
344+ noBackslashEscapes := (mc .status & statusNoBackslashEscapes ) != 0
345+ paramPositions := findParamPositions (query , noBackslashEscapes )
346+ if len (paramPositions ) != len (args ) {
251347 return "" , driver .ErrSkip
252348 }
253349
@@ -261,21 +357,16 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
261357 }
262358 buf = buf [:0 ]
263359 argPos := 0
360+ lastIdx := 0
264361
265- for i := 0 ; i < len (query ); i ++ {
266- q := strings .IndexByte (query [i :], '?' )
267- if q == - 1 {
268- buf = append (buf , query [i :]... )
269- break
270- }
271- buf = append (buf , query [i :i + q ]... )
272- i += q
273-
362+ for _ , qmIdx := range paramPositions {
363+ buf = append (buf , query [lastIdx :qmIdx ]... )
274364 arg := args [argPos ]
275365 argPos ++
276366
277367 if arg == nil {
278368 buf = append (buf , "NULL" ... )
369+ lastIdx = qmIdx + 1
279370 continue
280371 }
281372
@@ -306,30 +397,30 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
306397 }
307398 case json.RawMessage :
308399 buf = append (buf , '\'' )
309- if mc .status & statusNoBackslashEscapes == 0 {
310- buf = escapeBytesBackslash (buf , v )
311- } else {
400+ if noBackslashEscapes {
312401 buf = escapeBytesQuotes (buf , v )
402+ } else {
403+ buf = escapeBytesBackslash (buf , v )
313404 }
314405 buf = append (buf , '\'' )
315406 case []byte :
316407 if v == nil {
317408 buf = append (buf , "NULL" ... )
318409 } else {
319410 buf = append (buf , "_binary'" ... )
320- if mc .status & statusNoBackslashEscapes == 0 {
321- buf = escapeBytesBackslash (buf , v )
322- } else {
411+ if noBackslashEscapes {
323412 buf = escapeBytesQuotes (buf , v )
413+ } else {
414+ buf = escapeBytesBackslash (buf , v )
324415 }
325416 buf = append (buf , '\'' )
326417 }
327418 case string :
328419 buf = append (buf , '\'' )
329- if mc .status & statusNoBackslashEscapes == 0 {
330- buf = escapeStringBackslash (buf , v )
331- } else {
420+ if noBackslashEscapes {
332421 buf = escapeStringQuotes (buf , v )
422+ } else {
423+ buf = escapeStringBackslash (buf , v )
333424 }
334425 buf = append (buf , '\'' )
335426 default :
@@ -339,7 +430,9 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
339430 if len (buf )+ 4 > mc .maxAllowedPacket {
340431 return "" , driver .ErrSkip
341432 }
433+ lastIdx = qmIdx + 1
342434 }
435+ buf = append (buf , query [lastIdx :]... )
343436 if argPos != len (args ) {
344437 return "" , driver .ErrSkip
345438 }
0 commit comments