// Go MySQL Driver - A MySQL-Driver for Go's database/sql package // // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at http://mozilla.org/MPL/2.0/. package mysql import ( "context" "database/sql" "database/sql/driver" "encoding/json" "io" "net" "strconv" "strings" "time" ) type mysqlConn struct { buf buffer netConn net.Conn rawConn net.Conn // underlying connection when netConn is TLS connection. result mysqlResult // managed by clearResult() and handleOkPacket(). cfg *Config connector *connector maxAllowedPacket int maxWriteSize int writeTimeout time.Duration flags clientFlag status statusFlag sequence uint8 parseTime bool // for context support (Go 1.8+) watching bool watcher chan<- context.Context closech chan struct{} finished chan<- struct{} canceled atomicError // set non-nil if conn is canceled closed atomicBool // set when conn is closed, before closech is closed } // Helper function to call per-connection logger. func (mc *mysqlConn) log(v ...any) { mc.cfg.Logger.Print(v...) } // Handles parameters set in DSN after the connection is established func (mc *mysqlConn) handleParams() (err error) { var cmdSet strings.Builder for param, val := range mc.cfg.Params { switch param { // Charset: character_set_connection, character_set_client, character_set_results case "charset": charsets := strings.Split(val, ",") for _, cs := range charsets { // ignore errors here - a charset may not exist if mc.cfg.Collation != "" { err = mc.exec("SET NAMES " + cs + " COLLATE " + mc.cfg.Collation) } else { err = mc.exec("SET NAMES " + cs) } if err == nil { break } } if err != nil { return } // Other system vars accumulated in a single SET command default: if cmdSet.Len() == 0 { // Heuristic: 29 chars for each other key=value to reduce reallocations cmdSet.Grow(4 + len(param) + 3 + len(val) + 30*(len(mc.cfg.Params)-1)) cmdSet.WriteString("SET ") } else { cmdSet.WriteString(", ") } cmdSet.WriteString(param) cmdSet.WriteString(" = ") cmdSet.WriteString(val) } } if cmdSet.Len() > 0 { err = mc.exec(cmdSet.String()) if err != nil { return } } return } func (mc *mysqlConn) markBadConn(err error) error { if mc == nil { return err } if err != errBadConnNoWrite { return err } return driver.ErrBadConn } func (mc *mysqlConn) Begin() (driver.Tx, error) { return mc.begin(false) } func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) { if mc.closed.Load() { mc.log(ErrInvalidConn) return nil, driver.ErrBadConn } var q string if readOnly { q = "START TRANSACTION READ ONLY" } else { q = "START TRANSACTION" } err := mc.exec(q) if err == nil { return &mysqlTx{mc}, err } return nil, mc.markBadConn(err) } func (mc *mysqlConn) Close() (err error) { // Makes Close idempotent if !mc.closed.Load() { err = mc.writeCommandPacket(comQuit) } mc.cleanup() mc.clearResult() return } // Closes the network connection and unsets internal variables. Do not call this // function after successfully authentication, call Close instead. This function // is called before auth or on auth failure because MySQL will have already // closed the network connection. func (mc *mysqlConn) cleanup() { if mc.closed.Swap(true) { return } // Makes cleanup idempotent close(mc.closech) conn := mc.rawConn if conn == nil { return } if err := conn.Close(); err != nil { mc.log(err) } // This function can be called from multiple goroutines. // So we can not mc.clearResult() here. // Caller should do it if they are in safe goroutine. } func (mc *mysqlConn) error() error { if mc.closed.Load() { if err := mc.canceled.Value(); err != nil { return err } return ErrInvalidConn } return nil } func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { if mc.closed.Load() { mc.log(ErrInvalidConn) return nil, driver.ErrBadConn } // Send command err := mc.writeCommandPacketStr(comStmtPrepare, query) if err != nil { // STMT_PREPARE is safe to retry. So we can return ErrBadConn here. mc.log(err) return nil, driver.ErrBadConn } stmt := &mysqlStmt{ mc: mc, } // Read Result columnCount, err := stmt.readPrepareResultPacket() if err == nil { if stmt.paramCount > 0 { if err = mc.readUntilEOF(); err != nil { return nil, err } } if columnCount > 0 { err = mc.readUntilEOF() } } return stmt, err } func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) { // Number of ? should be same to len(args) if strings.Count(query, "?") != len(args) { return "", driver.ErrSkip } buf, err := mc.buf.takeCompleteBuffer() if err != nil { // can not take the buffer. Something must be wrong with the connection mc.log(err) return "", ErrInvalidConn } buf = buf[:0] argPos := 0 for i := 0; i < len(query); i++ { q := strings.IndexByte(query[i:], '?') if q == -1 { buf = append(buf, query[i:]...) break } buf = append(buf, query[i:i+q]...) i += q arg := args[argPos] argPos++ if arg == nil { buf = append(buf, "NULL"...) continue } switch v := arg.(type) { case int64: buf = strconv.AppendInt(buf, v, 10) case uint64: // Handle uint64 explicitly because our custom ConvertValue emits unsigned values buf = strconv.AppendUint(buf, v, 10) case float64: buf = strconv.AppendFloat(buf, v, 'g', -1, 64) case bool: if v { buf = append(buf, '1') } else { buf = append(buf, '0') } case time.Time: if v.IsZero() { buf = append(buf, "'0000-00-00'"...) } else { buf = append(buf, '\'') buf, err = appendDateTime(buf, v.In(mc.cfg.Loc), mc.cfg.timeTruncate) if err != nil { return "", err } buf = append(buf, '\'') } case json.RawMessage: buf = append(buf, '\'') if mc.status&statusNoBackslashEscapes == 0 { buf = escapeBytesBackslash(buf, v) } else { buf = escapeBytesQuotes(buf, v) } buf = append(buf, '\'') case []byte: if v == nil { buf = append(buf, "NULL"...) } else { buf = append(buf, "_binary'"...) if mc.status&statusNoBackslashEscapes == 0 { buf = escapeBytesBackslash(buf, v) } else { buf = escapeBytesQuotes(buf, v) } buf = append(buf, '\'') } case string: buf = append(buf, '\'') if mc.status&statusNoBackslashEscapes == 0 { buf = escapeStringBackslash(buf, v) } else { buf = escapeStringQuotes(buf, v) } buf = append(buf, '\'') default: return "", driver.ErrSkip } if len(buf)+4 > mc.maxAllowedPacket { return "", driver.ErrSkip } } if argPos != len(args) { return "", driver.ErrSkip } return string(buf), nil } func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { if mc.closed.Load() { mc.log(ErrInvalidConn) return nil, driver.ErrBadConn } if len(args) != 0 { if !mc.cfg.InterpolateParams { return nil, driver.ErrSkip } // try to interpolate the parameters to save extra roundtrips for preparing and closing a statement prepared, err := mc.interpolateParams(query, args) if err != nil { return nil, err } query = prepared } err := mc.exec(query) if err == nil { copied := mc.result return &copied, err } return nil, mc.markBadConn(err) } // Internal function to execute commands func (mc *mysqlConn) exec(query string) error { handleOk := mc.clearResult() // Send command if err := mc.writeCommandPacketStr(comQuery, query); err != nil { return mc.markBadConn(err) } // Read Result resLen, err := handleOk.readResultSetHeaderPacket() if err != nil { return err } if resLen > 0 { // columns if err := mc.readUntilEOF(); err != nil { return err } // rows if err := mc.readUntilEOF(); err != nil { return err } } return handleOk.discardResults() } func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) { return mc.query(query, args) } func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) { handleOk := mc.clearResult() if mc.closed.Load() { mc.log(ErrInvalidConn) return nil, driver.ErrBadConn } if len(args) != 0 { if !mc.cfg.InterpolateParams { return nil, driver.ErrSkip } // try client-side prepare to reduce roundtrip prepared, err := mc.interpolateParams(query, args) if err != nil { return nil, err } query = prepared } // Send command err := mc.writeCommandPacketStr(comQuery, query) if err == nil { // Read Result var resLen int resLen, err = handleOk.readResultSetHeaderPacket() if err == nil { rows := new(textRows) rows.mc = mc if resLen == 0 { rows.rs.done = true switch err := rows.NextResultSet(); err { case nil, io.EOF: return rows, nil default: return nil, err } } // Columns rows.rs.columns, err = mc.readColumns(resLen) return rows, err } } return nil, mc.markBadConn(err) } // Gets the value of the given MySQL System Variable // The returned byte slice is only valid until the next read func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { // Send command handleOk := mc.clearResult() if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil { return nil, err } // Read Result resLen, err := handleOk.readResultSetHeaderPacket() if err == nil { rows := new(textRows) rows.mc = mc rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}} if resLen > 0 { // Columns if err := mc.readUntilEOF(); err != nil { return nil, err } } dest := make([]driver.Value, resLen) if err = rows.readRow(dest); err == nil { return dest[0].([]byte), mc.readUntilEOF() } } return nil, err } // finish is called when the query has canceled. func (mc *mysqlConn) cancel(err error) { mc.canceled.Set(err) mc.cleanup() } // finish is called when the query has succeeded. func (mc *mysqlConn) finish() { if !mc.watching || mc.finished == nil { return } select { case mc.finished <- struct{}{}: mc.watching = false case <-mc.closech: } } // Ping implements driver.Pinger interface func (mc *mysqlConn) Ping(ctx context.Context) (err error) { if mc.closed.Load() { mc.log(ErrInvalidConn) return driver.ErrBadConn } if err = mc.watchCancel(ctx); err != nil { return } defer mc.finish() handleOk := mc.clearResult() if err = mc.writeCommandPacket(comPing); err != nil { return mc.markBadConn(err) } return handleOk.readResultOK() } // BeginTx implements driver.ConnBeginTx interface func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { if mc.closed.Load() { return nil, driver.ErrBadConn } if err := mc.watchCancel(ctx); err != nil { return nil, err } defer mc.finish() if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault { level, err := mapIsolationLevel(opts.Isolation) if err != nil { return nil, err } err = mc.exec("SET TRANSACTION ISOLATION LEVEL " + level) if err != nil { return nil, err } } return mc.begin(opts.ReadOnly) } func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { dargs, err := namedValueToValue(args) if err != nil { return nil, err } if err := mc.watchCancel(ctx); err != nil { return nil, err } rows, err := mc.query(query, dargs) if err != nil { mc.finish() return nil, err } rows.finish = mc.finish return rows, err } func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { dargs, err := namedValueToValue(args) if err != nil { return nil, err } if err := mc.watchCancel(ctx); err != nil { return nil, err } defer mc.finish() return mc.Exec(query, dargs) } func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { if err := mc.watchCancel(ctx); err != nil { return nil, err } stmt, err := mc.Prepare(query) mc.finish() if err != nil { return nil, err } select { default: case <-ctx.Done(): stmt.Close() return nil, ctx.Err() } return stmt, nil } func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { dargs, err := namedValueToValue(args) if err != nil { return nil, err } if err := stmt.mc.watchCancel(ctx); err != nil { return nil, err } rows, err := stmt.query(dargs) if err != nil { stmt.mc.finish() return nil, err } rows.finish = stmt.mc.finish return rows, err } func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { dargs, err := namedValueToValue(args) if err != nil { return nil, err } if err := stmt.mc.watchCancel(ctx); err != nil { return nil, err } defer stmt.mc.finish() return stmt.Exec(dargs) } func (mc *mysqlConn) watchCancel(ctx context.Context) error { if mc.watching { // Reach here if canceled, // so the connection is already invalid mc.cleanup() return nil } // When ctx is already cancelled, don't watch it. if err := ctx.Err(); err != nil { return err } // When ctx is not cancellable, don't watch it. if ctx.Done() == nil { return nil } // When watcher is not alive, can't watch it. if mc.watcher == nil { return nil } mc.watching = true mc.watcher <- ctx return nil } func (mc *mysqlConn) startWatcher() { watcher := make(chan context.Context, 1) mc.watcher = watcher finished := make(chan struct{}) mc.finished = finished go func() { for { var ctx context.Context select { case ctx = <-watcher: case <-mc.closech: return } select { case <-ctx.Done(): mc.cancel(ctx.Err()) case <-finished: case <-mc.closech: return } } }() } func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) { nv.Value, err = converter{}.ConvertValue(nv.Value) return } // ResetSession implements driver.SessionResetter. // (From Go 1.10) func (mc *mysqlConn) ResetSession(ctx context.Context) error { if mc.closed.Load() { return driver.ErrBadConn } // Perform a stale connection check. We only perform this check for // the first query on a connection that has been checked out of the // connection pool: a fresh connection from the pool is more likely // to be stale, and it has not performed any previous writes that // could cause data corruption, so it's safe to return ErrBadConn // if the check fails. if mc.cfg.CheckConnLiveness { conn := mc.netConn if mc.rawConn != nil { conn = mc.rawConn } var err error if mc.cfg.ReadTimeout != 0 { err = conn.SetReadDeadline(time.Now().Add(mc.cfg.ReadTimeout)) } if err == nil { err = connCheck(conn) } if err != nil { mc.log("closing bad idle connection: ", err) return driver.ErrBadConn } } return nil } // IsValid implements driver.Validator interface // (From Go 1.15) func (mc *mysqlConn) IsValid() bool { return !mc.closed.Load() }