diff options
-rw-r--r-- | auth.go | 438 | ||||
-rw-r--r-- | auth_test.go | 1330 | ||||
-rw-r--r-- | benchmark_test.go | 374 | ||||
-rw-r--r-- | buffer.go | 182 | ||||
-rw-r--r-- | collations.go | 265 | ||||
-rw-r--r-- | conncheck.go | 54 | ||||
-rw-r--r-- | conncheck_dummy.go | 17 | ||||
-rw-r--r-- | conncheck_test.go | 38 | ||||
-rw-r--r-- | connection.go | 650 | ||||
-rw-r--r-- | connection_test.go | 203 | ||||
-rw-r--r-- | connector.go | 146 | ||||
-rw-r--r-- | connector_test.go | 30 | ||||
-rw-r--r-- | const.go | 174 | ||||
-rw-r--r-- | driver.go | 107 | ||||
-rw-r--r-- | driver_test.go | 3211 | ||||
-rw-r--r-- | dsn.go | 559 | ||||
-rw-r--r-- | dsn_test.go | 416 | ||||
-rw-r--r-- | errors.go | 77 | ||||
-rw-r--r-- | errors_test.go | 61 | ||||
-rw-r--r-- | fields.go | 206 | ||||
-rw-r--r-- | fuzz.go | 24 | ||||
-rw-r--r-- | go.mod | 8 | ||||
-rw-r--r-- | go.sum | 4 | ||||
-rw-r--r-- | infile.go | 184 | ||||
-rw-r--r-- | nulltime.go | 71 | ||||
-rw-r--r-- | nulltime_test.go | 62 | ||||
-rw-r--r-- | packets.go | 1349 | ||||
-rw-r--r-- | packets_test.go | 336 | ||||
-rw-r--r-- | result.go | 22 | ||||
-rw-r--r-- | rows.go | 223 | ||||
-rw-r--r-- | statement.go | 220 | ||||
-rw-r--r-- | statement_test.go | 151 | ||||
-rw-r--r-- | transaction.go | 31 | ||||
-rw-r--r-- | utils.go | 874 | ||||
-rw-r--r-- | utils_test.go | 510 |
35 files changed, 12607 insertions, 0 deletions
@@ -0,0 +1,438 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mariadb + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/sha1" + "crypto/sha256" + "crypto/x509" + "encoding/pem" + "fmt" + "sync" +) + +// server pub keys registry +var ( + serverPubKeyLock sync.RWMutex + serverPubKeyRegistry map[string]*rsa.PublicKey +) + +// RegisterServerPubKey registers a server RSA public key which can be used to +// send data in a secure manner to the server without receiving the public key +// in a potentially insecure way from the server first. +// Registered keys can afterwards be used adding serverPubKey=<name> to the DSN. +// +// Note: The provided rsa.PublicKey instance is exclusively owned by the driver +// after registering it and may not be modified. +// +// data, err := ioutil.ReadFile("mykey.pem") +// if err != nil { +// log.Fatal(err) +// } +// +// block, _ := pem.Decode(data) +// if block == nil || block.Type != "PUBLIC KEY" { +// log.Fatal("failed to decode PEM block containing public key") +// } +// +// pub, err := x509.ParsePKIXPublicKey(block.Bytes) +// if err != nil { +// log.Fatal(err) +// } +// +// if rsaPubKey, ok := pub.(*rsa.PublicKey); ok { +// mysql.RegisterServerPubKey("mykey", rsaPubKey) +// } else { +// log.Fatal("not a RSA public key") +// } +// +func RegisterServerPubKey(name string, pubKey *rsa.PublicKey) { + serverPubKeyLock.Lock() + if serverPubKeyRegistry == nil { + serverPubKeyRegistry = make(map[string]*rsa.PublicKey) + } + + serverPubKeyRegistry[name] = pubKey + serverPubKeyLock.Unlock() +} + +// DeregisterServerPubKey removes the public key registered with the given name. +func DeregisterServerPubKey(name string) { + serverPubKeyLock.Lock() + if serverPubKeyRegistry != nil { + delete(serverPubKeyRegistry, name) + } + serverPubKeyLock.Unlock() +} + +func getServerPubKey(name string) (pubKey *rsa.PublicKey) { + serverPubKeyLock.RLock() + if v, ok := serverPubKeyRegistry[name]; ok { + pubKey = v + } + serverPubKeyLock.RUnlock() + return +} + +// Hash password using pre 4.1 (old password) method +// https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c +type myRnd struct { + seed1, seed2 uint32 +} + +const myRndMaxVal = 0x3FFFFFFF + +// Pseudo random number generator +func newMyRnd(seed1, seed2 uint32) *myRnd { + return &myRnd{ + seed1: seed1 % myRndMaxVal, + seed2: seed2 % myRndMaxVal, + } +} + +// Tested to be equivalent to MariaDB's floating point variant +// http://play.golang.org/p/QHvhd4qved +// http://play.golang.org/p/RG0q4ElWDx +func (r *myRnd) NextByte() byte { + r.seed1 = (r.seed1*3 + r.seed2) % myRndMaxVal + r.seed2 = (r.seed1 + r.seed2 + 33) % myRndMaxVal + + return byte(uint64(r.seed1) * 31 / myRndMaxVal) +} + +// Generate binary hash from byte string using insecure pre 4.1 method +func pwHash(password []byte) (result [2]uint32) { + var add uint32 = 7 + var tmp uint32 + + result[0] = 1345345333 + result[1] = 0x12345671 + + for _, c := range password { + // skip spaces and tabs in password + if c == ' ' || c == '\t' { + continue + } + + tmp = uint32(c) + result[0] ^= (((result[0] & 63) + add) * tmp) + (result[0] << 8) + result[1] += (result[1] << 8) ^ result[0] + add += tmp + } + + // Remove sign bit (1<<31)-1) + result[0] &= 0x7FFFFFFF + result[1] &= 0x7FFFFFFF + + return +} + +// Hash password using insecure pre 4.1 method +func scrambleOldPassword(scramble []byte, password string) []byte { + scramble = scramble[:8] + + hashPw := pwHash([]byte(password)) + hashSc := pwHash(scramble) + + r := newMyRnd(hashPw[0]^hashSc[0], hashPw[1]^hashSc[1]) + + var out [8]byte + for i := range out { + out[i] = r.NextByte() + 64 + } + + mask := r.NextByte() + for i := range out { + out[i] ^= mask + } + + return out[:] +} + +// Hash password using 4.1+ method (SHA1) +func scramblePassword(scramble []byte, password string) []byte { + if len(password) == 0 { + return nil + } + + // stage1Hash = SHA1(password) + crypt := sha1.New() + crypt.Write([]byte(password)) + stage1 := crypt.Sum(nil) + + // scrambleHash = SHA1(scramble + SHA1(stage1Hash)) + // inner Hash + crypt.Reset() + crypt.Write(stage1) + hash := crypt.Sum(nil) + + // outer Hash + crypt.Reset() + crypt.Write(scramble) + crypt.Write(hash) + scramble = crypt.Sum(nil) + + // token = scrambleHash XOR stage1Hash + for i := range scramble { + scramble[i] ^= stage1[i] + } + return scramble +} + +// Hash password using MySQL 8+ method (SHA256) +func scrambleSHA256Password(scramble []byte, password string) []byte { + if len(password) == 0 { + return nil + } + + // XOR(SHA256(password), SHA256(SHA256(SHA256(password)), scramble)) + + crypt := sha256.New() + crypt.Write([]byte(password)) + message1 := crypt.Sum(nil) + + crypt.Reset() + crypt.Write(message1) + message1Hash := crypt.Sum(nil) + + crypt.Reset() + crypt.Write(message1Hash) + crypt.Write(scramble) + message2 := crypt.Sum(nil) + + for i := range message1 { + message1[i] ^= message2[i] + } + + return message1 +} + +func encryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte, error) { + plain := make([]byte, len(password)+1) + copy(plain, password) + for i := range plain { + j := i % len(seed) + plain[i] ^= seed[j] + } + sha1 := sha1.New() + return rsa.EncryptOAEP(sha1, rand.Reader, pub, plain, nil) +} + +func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) error { + enc, err := encryptPassword(mc.cfg.Passwd, seed, pub) + if err != nil { + return err + } + return mc.writeAuthSwitchPacket(enc) +} + +func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { + switch plugin { + case "caching_sha2_password": + authResp := scrambleSHA256Password(authData, mc.cfg.Passwd) + return authResp, nil + + case "mysql_old_password": + if !mc.cfg.AllowOldPasswords { + return nil, ErrOldPassword + } + if len(mc.cfg.Passwd) == 0 { + return nil, nil + } + // Note: there are edge cases where this should work but doesn't; + // this is currently "wontfix": + // https://github.com/go-sql-driver/mysql/issues/184 + authResp := append(scrambleOldPassword(authData[:8], mc.cfg.Passwd), 0) + return authResp, nil + + case "mysql_clear_password": + if !mc.cfg.AllowCleartextPasswords { + return nil, ErrCleartextPassword + } + // http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html + // http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html + return append([]byte(mc.cfg.Passwd), 0), nil + + case "mysql_native_password": + if !mc.cfg.AllowNativePasswords { + return nil, ErrNativePassword + } + // https://dev.mysql.com/doc/internals/en/secure-password-authentication.html + // Native password authentication only need and will need 20-byte challenge. + authResp := scramblePassword(authData[:20], mc.cfg.Passwd) + return authResp, nil + + case "sha256_password": + if len(mc.cfg.Passwd) == 0 { + return []byte{0}, nil + } + // unlike caching_sha2_password, sha256_password does not accept + // cleartext password on unix transport. + if mc.cfg.tls != nil { + // write cleartext auth packet + return append([]byte(mc.cfg.Passwd), 0), nil + } + + pubKey := mc.cfg.pubKey + if pubKey == nil { + // request public key from server + return []byte{1}, nil + } + + // encrypted password + enc, err := encryptPassword(mc.cfg.Passwd, authData, pubKey) + return enc, err + + default: + errLog.Print("unknown auth plugin:", plugin) + return nil, ErrUnknownPlugin + } +} + +func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { + // Read Result Packet + authData, newPlugin, err := mc.readAuthResult() + if err != nil { + return err + } + + // handle auth plugin switch, if requested + if newPlugin != "" { + // If CLIENT_PLUGIN_AUTH capability is not supported, no new cipher is + // sent and we have to keep using the cipher sent in the init packet. + if authData == nil { + authData = oldAuthData + } else { + // copy data from read buffer to owned slice + copy(oldAuthData, authData) + } + + plugin = newPlugin + + authResp, err := mc.auth(authData, plugin) + if err != nil { + return err + } + if err = mc.writeAuthSwitchPacket(authResp); err != nil { + return err + } + + // Read Result Packet + authData, newPlugin, err = mc.readAuthResult() + if err != nil { + return err + } + + // Do not allow to change the auth plugin more than once + if newPlugin != "" { + return ErrMalformPkt + } + } + + switch plugin { + + // https://insidemysql.com/preparing-your-community-connector-for-mysql-8-part-2-sha256/ + case "caching_sha2_password": + switch len(authData) { + case 0: + return nil // auth successful + case 1: + switch authData[0] { + case cachingSha2PasswordFastAuthSuccess: + if err = mc.readResultOK(); err == nil { + return nil // auth successful + } + + case cachingSha2PasswordPerformFullAuthentication: + if mc.cfg.tls != nil || mc.cfg.Net == "unix" { + // write cleartext auth packet + err = mc.writeAuthSwitchPacket(append([]byte(mc.cfg.Passwd), 0)) + if err != nil { + return err + } + } else { + pubKey := mc.cfg.pubKey + if pubKey == nil { + // request public key from server + data, err := mc.buf.takeSmallBuffer(4 + 1) + if err != nil { + return err + } + data[4] = cachingSha2PasswordRequestPublicKey + err = mc.writePacket(data) + if err != nil { + return err + } + + if data, err = mc.readPacket(); err != nil { + return err + } + + if data[0] != iAuthMoreData { + return fmt.Errorf("unexpect resp from server for caching_sha2_password perform full authentication") + } + + // parse public key + block, rest := pem.Decode(data[1:]) + if block == nil { + return fmt.Errorf("No Pem data found, data: %s", rest) + } + pkix, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return err + } + pubKey = pkix.(*rsa.PublicKey) + } + + // send encrypted password + err = mc.sendEncryptedPassword(oldAuthData, pubKey) + if err != nil { + return err + } + } + return mc.readResultOK() + + default: + return ErrMalformPkt + } + default: + return ErrMalformPkt + } + + case "sha256_password": + switch len(authData) { + case 0: + return nil // auth successful + default: + block, _ := pem.Decode(authData) + if block == nil { + return fmt.Errorf("no Pem data found, data: %s", authData) + } + + pub, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return err + } + + // send encrypted password + err = mc.sendEncryptedPassword(oldAuthData, pub.(*rsa.PublicKey)) + if err != nil { + return err + } + return mc.readResultOK() + } + + default: + return nil // auth successful + } + + return err +} diff --git a/auth_test.go b/auth_test.go new file mode 100644 index 0000000..feba305 --- /dev/null +++ b/auth_test.go @@ -0,0 +1,1330 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mariadb + +import ( + "bytes" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "encoding/pem" + "fmt" + "testing" +) + +var testPubKey = []byte("-----BEGIN PUBLIC KEY-----\n" + + "MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAol0Z8G8U+25Btxk/g/fm\n" + + "UAW/wEKjQCTjkibDE4B+qkuWeiumg6miIRhtilU6m9BFmLQSy1ltYQuu4k17A4tQ\n" + + "rIPpOQYZges/qsDFkZh3wyK5jL5WEFVdOasf6wsfszExnPmcZS4axxoYJfiuilrN\n" + + "hnwinBAqfi3S0sw5MpSI4Zl1AbOrHG4zDI62Gti2PKiMGyYDZTS9xPrBLbN95Kby\n" + + "FFclQLEzA9RJcS1nHFsWtRgHjGPhhjCQxEm9NQ1nePFhCfBfApyfH1VM2VCOQum6\n" + + "Ci9bMuHWjTjckC84mzF99kOxOWVU7mwS6gnJqBzpuz8t3zq8/iQ2y7QrmZV+jTJP\n" + + "WQIDAQAB\n" + + "-----END PUBLIC KEY-----\n") + +var testPubKeyRSA *rsa.PublicKey + +func init() { + block, _ := pem.Decode(testPubKey) + pub, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + panic(err) + } + testPubKeyRSA = pub.(*rsa.PublicKey) +} + +func TestScrambleOldPass(t *testing.T) { + scramble := []byte{9, 8, 7, 6, 5, 4, 3, 2} + vectors := []struct { + pass string + out string + }{ + {" pass", "47575c5a435b4251"}, + {"pass ", "47575c5a435b4251"}, + {"123\t456", "575c47505b5b5559"}, + {"C0mpl!ca ted#PASS123", "5d5d554849584a45"}, + } + for _, tuple := range vectors { + ours := scrambleOldPassword(scramble, tuple.pass) + if tuple.out != fmt.Sprintf("%x", ours) { + t.Errorf("Failed old password %q", tuple.pass) + } + } +} + +func TestScrambleSHA256Pass(t *testing.T) { + scramble := []byte{10, 47, 74, 111, 75, 73, 34, 48, 88, 76, 114, 74, 37, 13, 3, 80, 82, 2, 23, 21} + vectors := []struct { + pass string + out string + }{ + {"secret", "f490e76f66d9d86665ce54d98c78d0acfe2fb0b08b423da807144873d30b312c"}, + {"secret2", "abc3934a012cf342e876071c8ee202de51785b430258a7a0138bc79c4d800bc6"}, + } + for _, tuple := range vectors { + ours := scrambleSHA256Password(scramble, tuple.pass) + if tuple.out != fmt.Sprintf("%x", ours) { + t.Errorf("Failed SHA256 password %q", tuple.pass) + } + } +} + +func TestAuthFastCachingSHA256PasswordCached(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + + authData := []byte{90, 105, 74, 126, 30, 48, 37, 56, 3, 23, 115, 127, 69, + 22, 41, 84, 32, 123, 43, 118} + plugin := "caching_sha2_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + expectedAuthResp := []byte{102, 32, 5, 35, 143, 161, 140, 241, 171, 232, 56, + 139, 43, 14, 107, 196, 249, 170, 147, 60, 220, 204, 120, 178, 214, 15, + 184, 150, 26, 61, 57, 235} + if writtenAuthRespLen != 32 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response + conn.data = []byte{ + 2, 0, 0, 2, 1, 3, // Fast Auth Success + 7, 0, 0, 3, 0, 0, 0, 2, 0, 0, 0, // OK + } + conn.maxReads = 1 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } +} + +func TestAuthFastCachingSHA256PasswordEmpty(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "" + + authData := []byte{90, 105, 74, 126, 30, 48, 37, 56, 3, 23, 115, 127, 69, + 22, 41, 84, 32, 123, 43, 118} + plugin := "caching_sha2_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + if writtenAuthRespLen != 0 { + t.Fatalf("unexpected written auth response (%d bytes): %v", + writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response + conn.data = []byte{ + 7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, // OK + } + conn.maxReads = 1 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } +} + +func TestAuthFastCachingSHA256PasswordFullRSA(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + + authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, + 62, 94, 83, 80, 52, 85} + plugin := "caching_sha2_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + expectedAuthResp := []byte{171, 201, 138, 146, 89, 159, 11, 170, 0, 67, 165, + 49, 175, 94, 218, 68, 177, 109, 110, 86, 34, 33, 44, 190, 67, 240, 70, + 110, 40, 139, 124, 41} + if writtenAuthRespLen != 32 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response + conn.data = []byte{ + 2, 0, 0, 2, 1, 4, // Perform Full Authentication + } + conn.queuedReplies = [][]byte{ + // pub key response + append([]byte{byte(1 + len(testPubKey)), 1, 0, 4, 1}, testPubKey...), + + // OK + {7, 0, 0, 6, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 3 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + if !bytes.HasPrefix(conn.written, []byte{1, 0, 0, 3, 2, 0, 1, 0, 5}) { + t.Errorf("unexpected written data: %v", conn.written) + } +} + +func TestAuthFastCachingSHA256PasswordFullRSAWithKey(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + mc.cfg.pubKey = testPubKeyRSA + + authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, + 62, 94, 83, 80, 52, 85} + plugin := "caching_sha2_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + expectedAuthResp := []byte{171, 201, 138, 146, 89, 159, 11, 170, 0, 67, 165, + 49, 175, 94, 218, 68, 177, 109, 110, 86, 34, 33, 44, 190, 67, 240, 70, + 110, 40, 139, 124, 41} + if writtenAuthRespLen != 32 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response + conn.data = []byte{ + 2, 0, 0, 2, 1, 4, // Perform Full Authentication + } + conn.queuedReplies = [][]byte{ + // OK + {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 2 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + if !bytes.HasPrefix(conn.written, []byte{0, 1, 0, 3}) { + t.Errorf("unexpected written data: %v", conn.written) + } +} + +func TestAuthFastCachingSHA256PasswordFullSecure(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + + authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, + 62, 94, 83, 80, 52, 85} + plugin := "caching_sha2_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // Hack to make the caching_sha2_password plugin believe that the connection + // is secure + mc.cfg.tls = &tls.Config{InsecureSkipVerify: true} + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + expectedAuthResp := []byte{171, 201, 138, 146, 89, 159, 11, 170, 0, 67, 165, + 49, 175, 94, 218, 68, 177, 109, 110, 86, 34, 33, 44, 190, 67, 240, 70, + 110, 40, 139, 124, 41} + if writtenAuthRespLen != 32 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response + conn.data = []byte{ + 2, 0, 0, 2, 1, 4, // Perform Full Authentication + } + conn.queuedReplies = [][]byte{ + // OK + {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 3 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + if !bytes.Equal(conn.written, []byte{7, 0, 0, 3, 115, 101, 99, 114, 101, 116, 0}) { + t.Errorf("unexpected written data: %v", conn.written) + } +} + +func TestAuthFastCleartextPasswordNotAllowed(t *testing.T) { + _, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + + authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126, + 103, 26, 95, 81, 17, 24, 21} + plugin := "mysql_clear_password" + + // Send Client Authentication Packet + _, err := mc.auth(authData, plugin) + if err != ErrCleartextPassword { + t.Errorf("expected ErrCleartextPassword, got %v", err) + } +} + +func TestAuthFastCleartextPassword(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + mc.cfg.AllowCleartextPasswords = true + + authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126, + 103, 26, 95, 81, 17, 24, 21} + plugin := "mysql_clear_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + expectedAuthResp := []byte{115, 101, 99, 114, 101, 116, 0} + if writtenAuthRespLen != 7 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response + conn.data = []byte{ + 7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, // OK + } + conn.maxReads = 1 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } +} + +func TestAuthFastCleartextPasswordEmpty(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "" + mc.cfg.AllowCleartextPasswords = true + + authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126, + 103, 26, 95, 81, 17, 24, 21} + plugin := "mysql_clear_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + expectedAuthResp := []byte{0} + if writtenAuthRespLen != 1 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response + conn.data = []byte{ + 7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, // OK + } + conn.maxReads = 1 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } +} + +func TestAuthFastNativePasswordNotAllowed(t *testing.T) { + _, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + mc.cfg.AllowNativePasswords = false + + authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126, + 103, 26, 95, 81, 17, 24, 21} + plugin := "mysql_native_password" + + // Send Client Authentication Packet + _, err := mc.auth(authData, plugin) + if err != ErrNativePassword { + t.Errorf("expected ErrNativePassword, got %v", err) + } +} + +func TestAuthFastNativePassword(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + + authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126, + 103, 26, 95, 81, 17, 24, 21} + plugin := "mysql_native_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + expectedAuthResp := []byte{53, 177, 140, 159, 251, 189, 127, 53, 109, 252, + 172, 50, 211, 192, 240, 164, 26, 48, 207, 45} + if writtenAuthRespLen != 20 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response + conn.data = []byte{ + 7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, // OK + } + conn.maxReads = 1 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } +} + +func TestAuthFastNativePasswordEmpty(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "" + + authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126, + 103, 26, 95, 81, 17, 24, 21} + plugin := "mysql_native_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + if writtenAuthRespLen != 0 { + t.Fatalf("unexpected written auth response (%d bytes): %v", + writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response + conn.data = []byte{ + 7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, // OK + } + conn.maxReads = 1 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } +} + +func TestAuthFastSHA256PasswordEmpty(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "" + + authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, + 62, 94, 83, 80, 52, 85} + plugin := "sha256_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + expectedAuthResp := []byte{0} + if writtenAuthRespLen != 1 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response (pub key response) + conn.data = append([]byte{byte(1 + len(testPubKey)), 1, 0, 2, 1}, testPubKey...) + conn.queuedReplies = [][]byte{ + // OK + {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 2 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + if !bytes.HasPrefix(conn.written, []byte{0, 1, 0, 3}) { + t.Errorf("unexpected written data: %v", conn.written) + } +} + +func TestAuthFastSHA256PasswordRSA(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + + authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, + 62, 94, 83, 80, 52, 85} + plugin := "sha256_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + expectedAuthResp := []byte{1} + if writtenAuthRespLen != 1 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response (pub key response) + conn.data = append([]byte{byte(1 + len(testPubKey)), 1, 0, 2, 1}, testPubKey...) + conn.queuedReplies = [][]byte{ + // OK + {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 2 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + if !bytes.HasPrefix(conn.written, []byte{0, 1, 0, 3}) { + t.Errorf("unexpected written data: %v", conn.written) + } +} + +func TestAuthFastSHA256PasswordRSAWithKey(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + mc.cfg.pubKey = testPubKeyRSA + + authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, + 62, 94, 83, 80, 52, 85} + plugin := "sha256_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // auth response (OK) + conn.data = []byte{7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0} + conn.maxReads = 1 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } +} + +func TestAuthFastSHA256PasswordSecure(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + + // hack to make the caching_sha2_password plugin believe that the connection + // is secure + mc.cfg.tls = &tls.Config{InsecureSkipVerify: true} + + authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, + 62, 94, 83, 80, 52, 85} + plugin := "sha256_password" + + // send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + + // unset TLS config to prevent the actual establishment of a TLS wrapper + mc.cfg.tls = nil + + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + expectedAuthResp := []byte{115, 101, 99, 114, 101, 116, 0} + if writtenAuthRespLen != 7 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response (OK) + conn.data = []byte{7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0} + conn.maxReads = 1 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + if !bytes.Equal(conn.written, []byte{}) { + t.Errorf("unexpected written data: %v", conn.written) + } +} + +func TestAuthSwitchCachingSHA256PasswordCached(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.Passwd = "secret" + + // auth switch request + conn.data = []byte{44, 0, 0, 2, 254, 99, 97, 99, 104, 105, 110, 103, 95, + 115, 104, 97, 50, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 101, + 11, 26, 18, 94, 97, 22, 72, 2, 46, 70, 106, 29, 55, 45, 94, 76, 90, 84, + 50, 0} + + // auth response + conn.queuedReplies = [][]byte{ + {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, // OK + } + conn.maxReads = 3 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{ + // 1. Packet: Hash + 32, 0, 0, 3, 129, 93, 132, 95, 114, 48, 79, 215, 128, 62, 193, 118, 128, + 54, 75, 208, 159, 252, 227, 215, 129, 15, 242, 97, 19, 159, 31, 20, 58, + 153, 9, 130, + } + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchCachingSHA256PasswordEmpty(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.Passwd = "" + + // auth switch request + conn.data = []byte{44, 0, 0, 2, 254, 99, 97, 99, 104, 105, 110, 103, 95, + 115, 104, 97, 50, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 101, + 11, 26, 18, 94, 97, 22, 72, 2, 46, 70, 106, 29, 55, 45, 94, 76, 90, 84, + 50, 0} + + // auth response + conn.queuedReplies = [][]byte{{7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}} + conn.maxReads = 2 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{0, 0, 0, 3} + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchCachingSHA256PasswordFullRSA(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.Passwd = "secret" + + // auth switch request + conn.data = []byte{44, 0, 0, 2, 254, 99, 97, 99, 104, 105, 110, 103, 95, + 115, 104, 97, 50, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 101, + 11, 26, 18, 94, 97, 22, 72, 2, 46, 70, 106, 29, 55, 45, 94, 76, 90, 84, + 50, 0} + + conn.queuedReplies = [][]byte{ + // Perform Full Authentication + {2, 0, 0, 4, 1, 4}, + + // Pub Key Response + append([]byte{byte(1 + len(testPubKey)), 1, 0, 6, 1}, testPubKey...), + + // OK + {7, 0, 0, 8, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 4 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReplyPrefix := []byte{ + // 1. Packet: Hash + 32, 0, 0, 3, 129, 93, 132, 95, 114, 48, 79, 215, 128, 62, 193, 118, 128, + 54, 75, 208, 159, 252, 227, 215, 129, 15, 242, 97, 19, 159, 31, 20, 58, + 153, 9, 130, + + // 2. Packet: Pub Key Request + 1, 0, 0, 5, 2, + + // 3. Packet: Encrypted Password + 0, 1, 0, 7, // [changing bytes] + } + if !bytes.HasPrefix(conn.written, expectedReplyPrefix) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchCachingSHA256PasswordFullRSAWithKey(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.Passwd = "secret" + mc.cfg.pubKey = testPubKeyRSA + + // auth switch request + conn.data = []byte{44, 0, 0, 2, 254, 99, 97, 99, 104, 105, 110, 103, 95, + 115, 104, 97, 50, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 101, + 11, 26, 18, 94, 97, 22, 72, 2, 46, 70, 106, 29, 55, 45, 94, 76, 90, 84, + 50, 0} + + conn.queuedReplies = [][]byte{ + // Perform Full Authentication + {2, 0, 0, 4, 1, 4}, + + // OK + {7, 0, 0, 6, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 3 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReplyPrefix := []byte{ + // 1. Packet: Hash + 32, 0, 0, 3, 129, 93, 132, 95, 114, 48, 79, 215, 128, 62, 193, 118, 128, + 54, 75, 208, 159, 252, 227, 215, 129, 15, 242, 97, 19, 159, 31, 20, 58, + 153, 9, 130, + + // 2. Packet: Encrypted Password + 0, 1, 0, 5, // [changing bytes] + } + if !bytes.HasPrefix(conn.written, expectedReplyPrefix) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchCachingSHA256PasswordFullSecure(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.Passwd = "secret" + + // Hack to make the caching_sha2_password plugin believe that the connection + // is secure + mc.cfg.tls = &tls.Config{InsecureSkipVerify: true} + + // auth switch request + conn.data = []byte{44, 0, 0, 2, 254, 99, 97, 99, 104, 105, 110, 103, 95, + 115, 104, 97, 50, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 101, + 11, 26, 18, 94, 97, 22, 72, 2, 46, 70, 106, 29, 55, 45, 94, 76, 90, 84, + 50, 0} + + // auth response + conn.queuedReplies = [][]byte{ + {2, 0, 0, 4, 1, 4}, // Perform Full Authentication + {7, 0, 0, 6, 0, 0, 0, 2, 0, 0, 0}, // OK + } + conn.maxReads = 3 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{ + // 1. Packet: Hash + 32, 0, 0, 3, 129, 93, 132, 95, 114, 48, 79, 215, 128, 62, 193, 118, 128, + 54, 75, 208, 159, 252, 227, 215, 129, 15, 242, 97, 19, 159, 31, 20, 58, + 153, 9, 130, + + // 2. Packet: Cleartext password + 7, 0, 0, 5, 115, 101, 99, 114, 101, 116, 0, + } + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchCleartextPasswordNotAllowed(t *testing.T) { + conn, mc := newRWMockConn(2) + + conn.data = []byte{22, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 99, 108, + 101, 97, 114, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0} + conn.maxReads = 1 + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + err := mc.handleAuthResult(authData, plugin) + if err != ErrCleartextPassword { + t.Errorf("expected ErrCleartextPassword, got %v", err) + } +} + +func TestAuthSwitchCleartextPassword(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.AllowCleartextPasswords = true + mc.cfg.Passwd = "secret" + + // auth switch request + conn.data = []byte{22, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 99, 108, + 101, 97, 114, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0} + + // auth response + conn.queuedReplies = [][]byte{{7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}} + conn.maxReads = 2 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{7, 0, 0, 3, 115, 101, 99, 114, 101, 116, 0} + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchCleartextPasswordEmpty(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.AllowCleartextPasswords = true + mc.cfg.Passwd = "" + + // auth switch request + conn.data = []byte{22, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 99, 108, + 101, 97, 114, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0} + + // auth response + conn.queuedReplies = [][]byte{{7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}} + conn.maxReads = 2 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{1, 0, 0, 3, 0} + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchNativePasswordNotAllowed(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.AllowNativePasswords = false + + conn.data = []byte{44, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 110, 97, + 116, 105, 118, 101, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 96, + 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, 48, 31, 89, 39, 55, + 31, 0} + conn.maxReads = 1 + authData := []byte{96, 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, + 48, 31, 89, 39, 55, 31} + plugin := "caching_sha2_password" + err := mc.handleAuthResult(authData, plugin) + if err != ErrNativePassword { + t.Errorf("expected ErrNativePassword, got %v", err) + } +} + +func TestAuthSwitchNativePassword(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.AllowNativePasswords = true + mc.cfg.Passwd = "secret" + + // auth switch request + conn.data = []byte{44, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 110, 97, + 116, 105, 118, 101, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 96, + 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, 48, 31, 89, 39, 55, + 31, 0} + + // auth response + conn.queuedReplies = [][]byte{{7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}} + conn.maxReads = 2 + + authData := []byte{96, 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, + 48, 31, 89, 39, 55, 31} + plugin := "caching_sha2_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{20, 0, 0, 3, 202, 41, 195, 164, 34, 226, 49, 103, + 21, 211, 167, 199, 227, 116, 8, 48, 57, 71, 149, 146} + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchNativePasswordEmpty(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.AllowNativePasswords = true + mc.cfg.Passwd = "" + + // auth switch request + conn.data = []byte{44, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 110, 97, + 116, 105, 118, 101, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 96, + 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, 48, 31, 89, 39, 55, + 31, 0} + + // auth response + conn.queuedReplies = [][]byte{{7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}} + conn.maxReads = 2 + + authData := []byte{96, 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, + 48, 31, 89, 39, 55, 31} + plugin := "caching_sha2_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{0, 0, 0, 3} + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchOldPasswordNotAllowed(t *testing.T) { + conn, mc := newRWMockConn(2) + + conn.data = []byte{41, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 111, 108, + 100, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 95, 84, 103, 43, 61, + 49, 123, 61, 91, 50, 40, 113, 35, 84, 96, 101, 92, 123, 121, 107, 0} + conn.maxReads = 1 + authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, + 84, 96, 101, 92, 123, 121, 107} + plugin := "mysql_native_password" + err := mc.handleAuthResult(authData, plugin) + if err != ErrOldPassword { + t.Errorf("expected ErrOldPassword, got %v", err) + } +} + +// Same to TestAuthSwitchOldPasswordNotAllowed, but use OldAuthSwitch request. +func TestOldAuthSwitchNotAllowed(t *testing.T) { + conn, mc := newRWMockConn(2) + + // OldAuthSwitch request + conn.data = []byte{1, 0, 0, 2, 0xfe} + conn.maxReads = 1 + authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, + 84, 96, 101, 92, 123, 121, 107} + plugin := "mysql_native_password" + err := mc.handleAuthResult(authData, plugin) + if err != ErrOldPassword { + t.Errorf("expected ErrOldPassword, got %v", err) + } +} + +func TestAuthSwitchOldPassword(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.AllowOldPasswords = true + mc.cfg.Passwd = "secret" + + // auth switch request + conn.data = []byte{41, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 111, 108, + 100, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 95, 84, 103, 43, 61, + 49, 123, 61, 91, 50, 40, 113, 35, 84, 96, 101, 92, 123, 121, 107, 0} + + // auth response + conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}} + conn.maxReads = 2 + + authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, + 84, 96, 101, 92, 123, 121, 107} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{9, 0, 0, 3, 86, 83, 83, 79, 74, 78, 65, 66, 0} + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +// Same to TestAuthSwitchOldPassword, but use OldAuthSwitch request. +func TestOldAuthSwitch(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.AllowOldPasswords = true + mc.cfg.Passwd = "secret" + + // OldAuthSwitch request + conn.data = []byte{1, 0, 0, 2, 0xfe} + + // auth response + conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}} + conn.maxReads = 2 + + authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, + 84, 96, 101, 92, 123, 121, 107} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{9, 0, 0, 3, 86, 83, 83, 79, 74, 78, 65, 66, 0} + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} +func TestAuthSwitchOldPasswordEmpty(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.AllowOldPasswords = true + mc.cfg.Passwd = "" + + // auth switch request + conn.data = []byte{41, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 111, 108, + 100, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 95, 84, 103, 43, 61, + 49, 123, 61, 91, 50, 40, 113, 35, 84, 96, 101, 92, 123, 121, 107, 0} + + // auth response + conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}} + conn.maxReads = 2 + + authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, + 84, 96, 101, 92, 123, 121, 107} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{0, 0, 0, 3} + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +// Same to TestAuthSwitchOldPasswordEmpty, but use OldAuthSwitch request. +func TestOldAuthSwitchPasswordEmpty(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.AllowOldPasswords = true + mc.cfg.Passwd = "" + + // OldAuthSwitch request. + conn.data = []byte{1, 0, 0, 2, 0xfe} + + // auth response + conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}} + conn.maxReads = 2 + + authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, + 84, 96, 101, 92, 123, 121, 107} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{0, 0, 0, 3} + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchSHA256PasswordEmpty(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.Passwd = "" + + // auth switch request + conn.data = []byte{38, 0, 0, 2, 254, 115, 104, 97, 50, 53, 54, 95, 112, 97, + 115, 115, 119, 111, 114, 100, 0, 78, 82, 62, 40, 100, 1, 59, 31, 44, 69, + 33, 112, 8, 81, 51, 96, 65, 82, 16, 114, 0} + + conn.queuedReplies = [][]byte{ + // OK + {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 3 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReplyPrefix := []byte{ + // 1. Packet: Empty Password + 1, 0, 0, 3, 0, + } + if !bytes.HasPrefix(conn.written, expectedReplyPrefix) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchSHA256PasswordRSA(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.Passwd = "secret" + + // auth switch request + conn.data = []byte{38, 0, 0, 2, 254, 115, 104, 97, 50, 53, 54, 95, 112, 97, + 115, 115, 119, 111, 114, 100, 0, 78, 82, 62, 40, 100, 1, 59, 31, 44, 69, + 33, 112, 8, 81, 51, 96, 65, 82, 16, 114, 0} + + conn.queuedReplies = [][]byte{ + // Pub Key Response + append([]byte{byte(1 + len(testPubKey)), 1, 0, 4, 1}, testPubKey...), + + // OK + {7, 0, 0, 6, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 3 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReplyPrefix := []byte{ + // 1. Packet: Pub Key Request + 1, 0, 0, 3, 1, + + // 2. Packet: Encrypted Password + 0, 1, 0, 5, // [changing bytes] + } + if !bytes.HasPrefix(conn.written, expectedReplyPrefix) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchSHA256PasswordRSAWithKey(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.Passwd = "secret" + mc.cfg.pubKey = testPubKeyRSA + + // auth switch request + conn.data = []byte{38, 0, 0, 2, 254, 115, 104, 97, 50, 53, 54, 95, 112, 97, + 115, 115, 119, 111, 114, 100, 0, 78, 82, 62, 40, 100, 1, 59, 31, 44, 69, + 33, 112, 8, 81, 51, 96, 65, 82, 16, 114, 0} + + conn.queuedReplies = [][]byte{ + // OK + {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 2 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReplyPrefix := []byte{ + // 1. Packet: Encrypted Password + 0, 1, 0, 3, // [changing bytes] + } + if !bytes.HasPrefix(conn.written, expectedReplyPrefix) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchSHA256PasswordSecure(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.Passwd = "secret" + + // Hack to make the caching_sha2_password plugin believe that the connection + // is secure + mc.cfg.tls = &tls.Config{InsecureSkipVerify: true} + + // auth switch request + conn.data = []byte{38, 0, 0, 2, 254, 115, 104, 97, 50, 53, 54, 95, 112, 97, + 115, 115, 119, 111, 114, 100, 0, 78, 82, 62, 40, 100, 1, 59, 31, 44, 69, + 33, 112, 8, 81, 51, 96, 65, 82, 16, 114, 0} + + conn.queuedReplies = [][]byte{ + // OK + {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 2 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReplyPrefix := []byte{ + // 1. Packet: Cleartext Password + 7, 0, 0, 3, 115, 101, 99, 114, 101, 116, 0, + } + if !bytes.Equal(conn.written, expectedReplyPrefix) { + t.Errorf("got unexpected data: %v", conn.written) + } +} diff --git a/benchmark_test.go b/benchmark_test.go new file mode 100644 index 0000000..2439af9 --- /dev/null +++ b/benchmark_test.go @@ -0,0 +1,374 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mariadb + +import ( + "bytes" + "context" + "database/sql" + "database/sql/driver" + "fmt" + "math" + "runtime" + "strings" + "sync" + "sync/atomic" + "testing" + "time" +) + +type TB testing.B + +func (tb *TB) check(err error) { + if err != nil { + tb.Fatal(err) + } +} + +func (tb *TB) checkDB(db *sql.DB, err error) *sql.DB { + tb.check(err) + return db +} + +func (tb *TB) checkRows(rows *sql.Rows, err error) *sql.Rows { + tb.check(err) + return rows +} + +func (tb *TB) checkStmt(stmt *sql.Stmt, err error) *sql.Stmt { + tb.check(err) + return stmt +} + +func initDB(b *testing.B, queries ...string) *sql.DB { + tb := (*TB)(b) + db := tb.checkDB(sql.Open("mysql", dsn)) + for _, query := range queries { + if _, err := db.Exec(query); err != nil { + b.Fatalf("error on %q: %v", query, err) + } + } + return db +} + +const concurrencyLevel = 10 + +func BenchmarkQuery(b *testing.B) { + tb := (*TB)(b) + b.StopTimer() + b.ReportAllocs() + db := initDB(b, + "DROP TABLE IF EXISTS foo", + "CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))", + `INSERT INTO foo VALUES (1, "one")`, + `INSERT INTO foo VALUES (2, "two")`, + ) + db.SetMaxIdleConns(concurrencyLevel) + defer db.Close() + + stmt := tb.checkStmt(db.Prepare("SELECT val FROM foo WHERE id=?")) + defer stmt.Close() + + remain := int64(b.N) + var wg sync.WaitGroup + wg.Add(concurrencyLevel) + defer wg.Wait() + b.StartTimer() + + for i := 0; i < concurrencyLevel; i++ { + go func() { + for { + if atomic.AddInt64(&remain, -1) < 0 { + wg.Done() + return + } + + var got string + tb.check(stmt.QueryRow(1).Scan(&got)) + if got != "one" { + b.Errorf("query = %q; want one", got) + wg.Done() + return + } + } + }() + } +} + +func BenchmarkExec(b *testing.B) { + tb := (*TB)(b) + b.StopTimer() + b.ReportAllocs() + db := tb.checkDB(sql.Open("mysql", dsn)) + db.SetMaxIdleConns(concurrencyLevel) + defer db.Close() + + stmt := tb.checkStmt(db.Prepare("DO 1")) + defer stmt.Close() + + remain := int64(b.N) + var wg sync.WaitGroup + wg.Add(concurrencyLevel) + defer wg.Wait() + b.StartTimer() + + for i := 0; i < concurrencyLevel; i++ { + go func() { + for { + if atomic.AddInt64(&remain, -1) < 0 { + wg.Done() + return + } + + if _, err := stmt.Exec(); err != nil { + b.Logf("stmt.Exec failed: %v", err) + b.Fail() + } + } + }() + } +} + +// data, but no db writes +var roundtripSample []byte + +func initRoundtripBenchmarks() ([]byte, int, int) { + if roundtripSample == nil { + roundtripSample = []byte(strings.Repeat("0123456789abcdef", 1024*1024)) + } + return roundtripSample, 16, len(roundtripSample) +} + +func BenchmarkRoundtripTxt(b *testing.B) { + b.StopTimer() + sample, min, max := initRoundtripBenchmarks() + sampleString := string(sample) + b.ReportAllocs() + tb := (*TB)(b) + db := tb.checkDB(sql.Open("mysql", dsn)) + defer db.Close() + b.StartTimer() + var result string + for i := 0; i < b.N; i++ { + length := min + i + if length > max { + length = max + } + test := sampleString[0:length] + rows := tb.checkRows(db.Query(`SELECT "` + test + `"`)) + if !rows.Next() { + rows.Close() + b.Fatalf("crashed") + } + err := rows.Scan(&result) + if err != nil { + rows.Close() + b.Fatalf("crashed") + } + if result != test { + rows.Close() + b.Errorf("mismatch") + } + rows.Close() + } +} + +func BenchmarkRoundtripBin(b *testing.B) { + b.StopTimer() + sample, min, max := initRoundtripBenchmarks() + b.ReportAllocs() + tb := (*TB)(b) + db := tb.checkDB(sql.Open("mysql", dsn)) + defer db.Close() + stmt := tb.checkStmt(db.Prepare("SELECT ?")) + defer stmt.Close() + b.StartTimer() + var result sql.RawBytes + for i := 0; i < b.N; i++ { + length := min + i + if length > max { + length = max + } + test := sample[0:length] + rows := tb.checkRows(stmt.Query(test)) + if !rows.Next() { + rows.Close() + b.Fatalf("crashed") + } + err := rows.Scan(&result) + if err != nil { + rows.Close() + b.Fatalf("crashed") + } + if !bytes.Equal(result, test) { + rows.Close() + b.Errorf("mismatch") + } + rows.Close() + } +} + +func BenchmarkInterpolation(b *testing.B) { + mc := &mysqlConn{ + cfg: &Config{ + InterpolateParams: true, + Loc: time.UTC, + }, + maxAllowedPacket: maxPacketSize, + maxWriteSize: maxPacketSize - 1, + buf: newBuffer(nil), + } + + args := []driver.Value{ + int64(42424242), + float64(math.Pi), + false, + time.Unix(1423411542, 807015000), + []byte("bytes containing special chars ' \" \a \x00"), + "string containing special chars ' \" \a \x00", + } + q := "SELECT ?, ?, ?, ?, ?, ?" + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := mc.interpolateParams(q, args) + if err != nil { + b.Fatal(err) + } + } +} + +func benchmarkQueryContext(b *testing.B, db *sql.DB, p int) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + db.SetMaxIdleConns(p * runtime.GOMAXPROCS(0)) + + tb := (*TB)(b) + stmt := tb.checkStmt(db.PrepareContext(ctx, "SELECT val FROM foo WHERE id=?")) + defer stmt.Close() + + b.SetParallelism(p) + b.ReportAllocs() + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + var got string + for pb.Next() { + tb.check(stmt.QueryRow(1).Scan(&got)) + if got != "one" { + b.Fatalf("query = %q; want one", got) + } + } + }) +} + +func BenchmarkQueryContext(b *testing.B) { + db := initDB(b, + "DROP TABLE IF EXISTS foo", + "CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))", + `INSERT INTO foo VALUES (1, "one")`, + `INSERT INTO foo VALUES (2, "two")`, + ) + defer db.Close() + for _, p := range []int{1, 2, 3, 4} { + b.Run(fmt.Sprintf("%d", p), func(b *testing.B) { + benchmarkQueryContext(b, db, p) + }) + } +} + +func benchmarkExecContext(b *testing.B, db *sql.DB, p int) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + db.SetMaxIdleConns(p * runtime.GOMAXPROCS(0)) + + tb := (*TB)(b) + stmt := tb.checkStmt(db.PrepareContext(ctx, "DO 1")) + defer stmt.Close() + + b.SetParallelism(p) + b.ReportAllocs() + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if _, err := stmt.ExecContext(ctx); err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkExecContext(b *testing.B) { + db := initDB(b, + "DROP TABLE IF EXISTS foo", + "CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))", + `INSERT INTO foo VALUES (1, "one")`, + `INSERT INTO foo VALUES (2, "two")`, + ) + defer db.Close() + for _, p := range []int{1, 2, 3, 4} { + b.Run(fmt.Sprintf("%d", p), func(b *testing.B) { + benchmarkQueryContext(b, db, p) + }) + } +} + +// BenchmarkQueryRawBytes benchmarks fetching 100 blobs using sql.RawBytes. +// "size=" means size of each blobs. +func BenchmarkQueryRawBytes(b *testing.B) { + var sizes []int = []int{100, 1000, 2000, 4000, 8000, 12000, 16000, 32000, 64000, 256000} + db := initDB(b, + "DROP TABLE IF EXISTS bench_rawbytes", + "CREATE TABLE bench_rawbytes (id INT PRIMARY KEY, val LONGBLOB)", + ) + defer db.Close() + + blob := make([]byte, sizes[len(sizes)-1]) + for i := range blob { + blob[i] = 42 + } + for i := 0; i < 100; i++ { + _, err := db.Exec("INSERT INTO bench_rawbytes VALUES (?, ?)", i, blob) + if err != nil { + b.Fatal(err) + } + } + + for _, s := range sizes { + b.Run(fmt.Sprintf("size=%v", s), func(b *testing.B) { + db.SetMaxIdleConns(0) + db.SetMaxIdleConns(1) + b.ReportAllocs() + b.ResetTimer() + + for j := 0; j < b.N; j++ { + rows, err := db.Query("SELECT LEFT(val, ?) as v FROM bench_rawbytes", s) + if err != nil { + b.Fatal(err) + } + nrows := 0 + for rows.Next() { + var buf sql.RawBytes + err := rows.Scan(&buf) + if err != nil { + b.Fatal(err) + } + if len(buf) != s { + b.Fatalf("size mismatch: expected %v, got %v", s, len(buf)) + } + nrows++ + } + rows.Close() + if nrows != 100 { + b.Fatalf("numbers of rows mismatch: expected %v, got %v", 100, nrows) + } + } + }) + } +} diff --git a/buffer.go b/buffer.go new file mode 100644 index 0000000..a883551 --- /dev/null +++ b/buffer.go @@ -0,0 +1,182 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mariadb + +import ( + "io" + "net" + "time" +) + +const defaultBufSize = 4096 +const maxCachedBufSize = 256 * 1024 + +// A buffer which is used for both reading and writing. +// This is possible since communication on each connection is synchronous. +// In other words, we can't write and read simultaneously on the same connection. +// The buffer is similar to bufio.Reader / Writer but zero-copy-ish +// Also highly optimized for this particular use case. +// This buffer is backed by two byte slices in a double-buffering scheme +type buffer struct { + buf []byte // buf is a byte buffer who's length and capacity are equal. + nc net.Conn + idx int + length int + timeout time.Duration + dbuf [2][]byte // dbuf is an array with the two byte slices that back this buffer + flipcnt uint // flipccnt is the current buffer counter for double-buffering +} + +// newBuffer allocates and returns a new buffer. +func newBuffer(nc net.Conn) buffer { + fg := make([]byte, defaultBufSize) + return buffer{ + buf: fg, + nc: nc, + dbuf: [2][]byte{fg, nil}, + } +} + +// flip replaces the active buffer with the background buffer +// this is a delayed flip that simply increases the buffer counter; +// the actual flip will be performed the next time we call `buffer.fill` +func (b *buffer) flip() { + b.flipcnt += 1 +} + +// fill reads into the buffer until at least _need_ bytes are in it +func (b *buffer) fill(need int) error { + n := b.length + // fill data into its double-buffering target: if we've called + // flip on this buffer, we'll be copying to the background buffer, + // and then filling it with network data; otherwise we'll just move + // the contents of the current buffer to the front before filling it + dest := b.dbuf[b.flipcnt&1] + + // grow buffer if necessary to fit the whole packet. + if need > len(dest) { + // Round up to the next multiple of the default size + dest = make([]byte, ((need/defaultBufSize)+1)*defaultBufSize) + + // if the allocated buffer is not too large, move it to backing storage + // to prevent extra allocations on applications that perform large reads + if len(dest) <= maxCachedBufSize { + b.dbuf[b.flipcnt&1] = dest + } + } + + // if we're filling the fg buffer, move the existing data to the start of it. + // if we're filling the bg buffer, copy over the data + if n > 0 { + copy(dest[:n], b.buf[b.idx:]) + } + + b.buf = dest + b.idx = 0 + + for { + if b.timeout > 0 { + if err := b.nc.SetReadDeadline(time.Now().Add(b.timeout)); err != nil { + return err + } + } + + nn, err := b.nc.Read(b.buf[n:]) + n += nn + + switch err { + case nil: + if n < need { + continue + } + b.length = n + return nil + + case io.EOF: + if n >= need { + b.length = n + return nil + } + return io.ErrUnexpectedEOF + + default: + return err + } + } +} + +// returns next N bytes from buffer. +// The returned slice is only guaranteed to be valid until the next read +func (b *buffer) readNext(need int) ([]byte, error) { + if b.length < need { + // refill + if err := b.fill(need); err != nil { + return nil, err + } + } + + offset := b.idx + b.idx += need + b.length -= need + return b.buf[offset:b.idx], nil +} + +// takeBuffer returns a buffer with the requested size. +// If possible, a slice from the existing buffer is returned. +// Otherwise a bigger buffer is made. +// Only one buffer (total) can be used at a time. +func (b *buffer) takeBuffer(length int) ([]byte, error) { + if b.length > 0 { + return nil, ErrBusyBuffer + } + + // test (cheap) general case first + if length <= cap(b.buf) { + return b.buf[:length], nil + } + + if length < maxPacketSize { + b.buf = make([]byte, length) + return b.buf, nil + } + + // buffer is larger than we want to store. + return make([]byte, length), nil +} + +// takeSmallBuffer is shortcut which can be used if length is +// known to be smaller than defaultBufSize. +// Only one buffer (total) can be used at a time. +func (b *buffer) takeSmallBuffer(length int) ([]byte, error) { + if b.length > 0 { + return nil, ErrBusyBuffer + } + return b.buf[:length], nil +} + +// takeCompleteBuffer returns the complete existing buffer. +// This can be used if the necessary buffer size is unknown. +// cap and len of the returned buffer will be equal. +// Only one buffer (total) can be used at a time. +func (b *buffer) takeCompleteBuffer() ([]byte, error) { + if b.length > 0 { + return nil, ErrBusyBuffer + } + return b.buf, nil +} + +// store stores buf, an updated buffer, if its suitable to do so. +func (b *buffer) store(buf []byte) error { + if b.length > 0 { + return ErrBusyBuffer + } else if cap(buf) <= maxPacketSize && cap(buf) > cap(b.buf) { + b.buf = buf[:cap(buf)] + } + return nil +} diff --git a/collations.go b/collations.go new file mode 100644 index 0000000..650c203 --- /dev/null +++ b/collations.go @@ -0,0 +1,265 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2014 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mariadb + +const defaultCollation = "utf8mb4_general_ci" +const binaryCollation = "binary" + +// A list of available collations mapped to the internal ID. +// To update this map use the following MySQL query: +// SELECT COLLATION_NAME, ID FROM information_schema.COLLATIONS WHERE ID<256 ORDER BY ID +// +// Handshake packet have only 1 byte for collation_id. So we can't use collations with ID > 255. +// +// ucs2, utf16, and utf32 can't be used for connection charset. +// https://dev.mysql.com/doc/refman/5.7/en/charset-connection.html#charset-connection-impermissible-client-charset +// They are commented out to reduce this map. +var collations = map[string]byte{ + "big5_chinese_ci": 1, + "latin2_czech_cs": 2, + "dec8_swedish_ci": 3, + "cp850_general_ci": 4, + "latin1_german1_ci": 5, + "hp8_english_ci": 6, + "koi8r_general_ci": 7, + "latin1_swedish_ci": 8, + "latin2_general_ci": 9, + "swe7_swedish_ci": 10, + "ascii_general_ci": 11, + "ujis_japanese_ci": 12, + "sjis_japanese_ci": 13, + "cp1251_bulgarian_ci": 14, + "latin1_danish_ci": 15, + "hebrew_general_ci": 16, + "tis620_thai_ci": 18, + "euckr_korean_ci": 19, + "latin7_estonian_cs": 20, + "latin2_hungarian_ci": 21, + "koi8u_general_ci": 22, + "cp1251_ukrainian_ci": 23, + "gb2312_chinese_ci": 24, + "greek_general_ci": 25, + "cp1250_general_ci": 26, + "latin2_croatian_ci": 27, + "gbk_chinese_ci": 28, + "cp1257_lithuanian_ci": 29, + "latin5_turkish_ci": 30, + "latin1_german2_ci": 31, + "armscii8_general_ci": 32, + "utf8_general_ci": 33, + "cp1250_czech_cs": 34, + //"ucs2_general_ci": 35, + "cp866_general_ci": 36, + "keybcs2_general_ci": 37, + "macce_general_ci": 38, + "macroman_general_ci": 39, + "cp852_general_ci": 40, + "latin7_general_ci": 41, + "latin7_general_cs": 42, + "macce_bin": 43, + "cp1250_croatian_ci": 44, + "utf8mb4_general_ci": 45, + "utf8mb4_bin": 46, + "latin1_bin": 47, + "latin1_general_ci": 48, + "latin1_general_cs": 49, + "cp1251_bin": 50, + "cp1251_general_ci": 51, + "cp1251_general_cs": 52, + "macroman_bin": 53, + //"utf16_general_ci": 54, + //"utf16_bin": 55, + //"utf16le_general_ci": 56, + "cp1256_general_ci": 57, + "cp1257_bin": 58, + "cp1257_general_ci": 59, + //"utf32_general_ci": 60, + //"utf32_bin": 61, + //"utf16le_bin": 62, + "binary": 63, + "armscii8_bin": 64, + "ascii_bin": 65, + "cp1250_bin": 66, + "cp1256_bin": 67, + "cp866_bin": 68, + "dec8_bin": 69, + "greek_bin": 70, + "hebrew_bin": 71, + "hp8_bin": 72, + "keybcs2_bin": 73, + "koi8r_bin": 74, + "koi8u_bin": 75, + "utf8_tolower_ci": 76, + "latin2_bin": 77, + "latin5_bin": 78, + "latin7_bin": 79, + "cp850_bin": 80, + "cp852_bin": 81, + "swe7_bin": 82, + "utf8_bin": 83, + "big5_bin": 84, + "euckr_bin": 85, + "gb2312_bin": 86, + "gbk_bin": 87, + "sjis_bin": 88, + "tis620_bin": 89, + //"ucs2_bin": 90, + "ujis_bin": 91, + "geostd8_general_ci": 92, + "geostd8_bin": 93, + "latin1_spanish_ci": 94, + "cp932_japanese_ci": 95, + "cp932_bin": 96, + "eucjpms_japanese_ci": 97, + "eucjpms_bin": 98, + "cp1250_polish_ci": 99, + //"utf16_unicode_ci": 101, + //"utf16_icelandic_ci": 102, + //"utf16_latvian_ci": 103, + //"utf16_romanian_ci": 104, + //"utf16_slovenian_ci": 105, + //"utf16_polish_ci": 106, + //"utf16_estonian_ci": 107, + //"utf16_spanish_ci": 108, + //"utf16_swedish_ci": 109, + //"utf16_turkish_ci": 110, + //"utf16_czech_ci": 111, + //"utf16_danish_ci": 112, + //"utf16_lithuanian_ci": 113, + //"utf16_slovak_ci": 114, + //"utf16_spanish2_ci": 115, + //"utf16_roman_ci": 116, + //"utf16_persian_ci": 117, + //"utf16_esperanto_ci": 118, + //"utf16_hungarian_ci": 119, + //"utf16_sinhala_ci": 120, + //"utf16_german2_ci": 121, + //"utf16_croatian_ci": 122, + //"utf16_unicode_520_ci": 123, + //"utf16_vietnamese_ci": 124, + //"ucs2_unicode_ci": 128, + //"ucs2_icelandic_ci": 129, + //"ucs2_latvian_ci": 130, + //"ucs2_romanian_ci": 131, + //"ucs2_slovenian_ci": 132, + //"ucs2_polish_ci": 133, + //"ucs2_estonian_ci": 134, + //"ucs2_spanish_ci": 135, + //"ucs2_swedish_ci": 136, + //"ucs2_turkish_ci": 137, + //"ucs2_czech_ci": 138, + //"ucs2_danish_ci": 139, + //"ucs2_lithuanian_ci": 140, + //"ucs2_slovak_ci": 141, + //"ucs2_spanish2_ci": 142, + //"ucs2_roman_ci": 143, + //"ucs2_persian_ci": 144, + //"ucs2_esperanto_ci": 145, + //"ucs2_hungarian_ci": 146, + //"ucs2_sinhala_ci": 147, + //"ucs2_german2_ci": 148, + //"ucs2_croatian_ci": 149, + //"ucs2_unicode_520_ci": 150, + //"ucs2_vietnamese_ci": 151, + //"ucs2_general_mysql500_ci": 159, + //"utf32_unicode_ci": 160, + //"utf32_icelandic_ci": 161, + //"utf32_latvian_ci": 162, + //"utf32_romanian_ci": 163, + //"utf32_slovenian_ci": 164, + //"utf32_polish_ci": 165, + //"utf32_estonian_ci": 166, + //"utf32_spanish_ci": 167, + //"utf32_swedish_ci": 168, + //"utf32_turkish_ci": 169, + //"utf32_czech_ci": 170, + //"utf32_danish_ci": 171, + //"utf32_lithuanian_ci": 172, + //"utf32_slovak_ci": 173, + //"utf32_spanish2_ci": 174, + //"utf32_roman_ci": 175, + //"utf32_persian_ci": 176, + //"utf32_esperanto_ci": 177, + //"utf32_hungarian_ci": 178, + //"utf32_sinhala_ci": 179, + //"utf32_german2_ci": 180, + //"utf32_croatian_ci": 181, + //"utf32_unicode_520_ci": 182, + //"utf32_vietnamese_ci": 183, + "utf8_unicode_ci": 192, + "utf8_icelandic_ci": 193, + "utf8_latvian_ci": 194, + "utf8_romanian_ci": 195, + "utf8_slovenian_ci": 196, + "utf8_polish_ci": 197, + "utf8_estonian_ci": 198, + "utf8_spanish_ci": 199, + "utf8_swedish_ci": 200, + "utf8_turkish_ci": 201, + "utf8_czech_ci": 202, + "utf8_danish_ci": 203, + "utf8_lithuanian_ci": 204, + "utf8_slovak_ci": 205, + "utf8_spanish2_ci": 206, + "utf8_roman_ci": 207, + "utf8_persian_ci": 208, + "utf8_esperanto_ci": 209, + "utf8_hungarian_ci": 210, + "utf8_sinhala_ci": 211, + "utf8_german2_ci": 212, + "utf8_croatian_ci": 213, + "utf8_unicode_520_ci": 214, + "utf8_vietnamese_ci": 215, + "utf8_general_mysql500_ci": 223, + "utf8mb4_unicode_ci": 224, + "utf8mb4_icelandic_ci": 225, + "utf8mb4_latvian_ci": 226, + "utf8mb4_romanian_ci": 227, + "utf8mb4_slovenian_ci": 228, + "utf8mb4_polish_ci": 229, + "utf8mb4_estonian_ci": 230, + "utf8mb4_spanish_ci": 231, + "utf8mb4_swedish_ci": 232, + "utf8mb4_turkish_ci": 233, + "utf8mb4_czech_ci": 234, + "utf8mb4_danish_ci": 235, + "utf8mb4_lithuanian_ci": 236, + "utf8mb4_slovak_ci": 237, + "utf8mb4_spanish2_ci": 238, + "utf8mb4_roman_ci": 239, + "utf8mb4_persian_ci": 240, + "utf8mb4_esperanto_ci": 241, + "utf8mb4_hungarian_ci": 242, + "utf8mb4_sinhala_ci": 243, + "utf8mb4_german2_ci": 244, + "utf8mb4_croatian_ci": 245, + "utf8mb4_unicode_520_ci": 246, + "utf8mb4_vietnamese_ci": 247, + "gb18030_chinese_ci": 248, + "gb18030_bin": 249, + "gb18030_unicode_520_ci": 250, + "utf8mb4_0900_ai_ci": 255, +} + +// A denylist of collations which is unsafe to interpolate parameters. +// These multibyte encodings may contains 0x5c (`\`) in their trailing bytes. +var unsafeCollations = map[string]bool{ + "big5_chinese_ci": true, + "sjis_japanese_ci": true, + "gbk_chinese_ci": true, + "big5_bin": true, + "gb2312_bin": true, + "gbk_bin": true, + "sjis_bin": true, + "cp932_japanese_ci": true, + "cp932_bin": true, + "gb18030_chinese_ci": true, + "gb18030_bin": true, + "gb18030_unicode_520_ci": true, +} diff --git a/conncheck.go b/conncheck.go new file mode 100644 index 0000000..912d909 --- /dev/null +++ b/conncheck.go @@ -0,0 +1,54 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2019 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build linux darwin dragonfly freebsd netbsd openbsd solaris illumos + +package mariadb + +import ( + "errors" + "io" + "net" + "syscall" +) + +var errUnexpectedRead = errors.New("unexpected read from socket") + +func connCheck(conn net.Conn) error { + var sysErr error + + sysConn, ok := conn.(syscall.Conn) + if !ok { + return nil + } + rawConn, err := sysConn.SyscallConn() + if err != nil { + return err + } + + err = rawConn.Read(func(fd uintptr) bool { + var buf [1]byte + n, err := syscall.Read(int(fd), buf[:]) + switch { + case n == 0 && err == nil: + sysErr = io.EOF + case n > 0: + sysErr = errUnexpectedRead + case err == syscall.EAGAIN || err == syscall.EWOULDBLOCK: + sysErr = nil + default: + sysErr = err + } + return true + }) + if err != nil { + return err + } + + return sysErr +} diff --git a/conncheck_dummy.go b/conncheck_dummy.go new file mode 100644 index 0000000..7ad97ed --- /dev/null +++ b/conncheck_dummy.go @@ -0,0 +1,17 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2019 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build !linux,!darwin,!dragonfly,!freebsd,!netbsd,!openbsd,!solaris,!illumos + +package mariadb + +import "net" + +func connCheck(conn net.Conn) error { + return nil +} diff --git a/conncheck_test.go b/conncheck_test.go new file mode 100644 index 0000000..3000b38 --- /dev/null +++ b/conncheck_test.go @@ -0,0 +1,38 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build linux darwin dragonfly freebsd netbsd openbsd solaris illumos + +package mariadb + +import ( + "testing" + "time" +) + +func TestStaleConnectionChecks(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("SET @@SESSION.wait_timeout = 2") + + if err := dbt.db.Ping(); err != nil { + dbt.Fatal(err) + } + + // wait for MySQL to close our connection + time.Sleep(3 * time.Second) + + tx, err := dbt.db.Begin() + if err != nil { + dbt.Fatal(err) + } + + if err := tx.Rollback(); err != nil { + dbt.Fatal(err) + } + }) +} diff --git a/connection.go b/connection.go new file mode 100644 index 0000000..ac84bca --- /dev/null +++ b/connection.go @@ -0,0 +1,650 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mariadb + +import ( + "context" + "database/sql" + "database/sql/driver" + "encoding/json" + "io" + "net" + "strconv" + "strings" + "time" +) + +type mysqlConn struct { + buf buffer + netConn net.Conn + rawConn net.Conn // underlying connection when netConn is TLS connection. + affectedRows uint64 + insertId uint64 + cfg *Config + maxAllowedPacket int + maxWriteSize int + writeTimeout time.Duration + flags clientFlag + status statusFlag + sequence uint8 + parseTime bool + reset bool // set when the Go SQL package calls ResetSession + + // for context support (Go 1.8+) + watching bool + watcher chan<- context.Context + closech chan struct{} + finished chan<- struct{} + canceled atomicError // set non-nil if conn is canceled + closed atomicBool // set when conn is closed, before closech is closed +} + +// Handles parameters set in DSN after the connection is established +func (mc *mysqlConn) handleParams() (err error) { + var cmdSet strings.Builder + for param, val := range mc.cfg.Params { + switch param { + // Charset: character_set_connection, character_set_client, character_set_results + case "charset": + charsets := strings.Split(val, ",") + for i := range charsets { + // ignore errors here - a charset may not exist + err = mc.exec("SET NAMES " + charsets[i]) + if err == nil { + break + } + } + if err != nil { + return + } + + // Other system vars accumulated in a single SET command + default: + if cmdSet.Len() == 0 { + // Heuristic: 29 chars for each other key=value to reduce reallocations + cmdSet.Grow(4 + len(param) + 1 + len(val) + 30*(len(mc.cfg.Params)-1)) + cmdSet.WriteString("SET ") + } else { + cmdSet.WriteByte(',') + } + cmdSet.WriteString(param) + cmdSet.WriteByte('=') + cmdSet.WriteString(val) + } + } + + if cmdSet.Len() > 0 { + err = mc.exec(cmdSet.String()) + if err != nil { + return + } + } + + return +} + +func (mc *mysqlConn) markBadConn(err error) error { + if mc == nil { + return err + } + if err != errBadConnNoWrite { + return err + } + return driver.ErrBadConn +} + +func (mc *mysqlConn) Begin() (driver.Tx, error) { + return mc.begin(false) +} + +func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) { + if mc.closed.IsSet() { + errLog.Print(ErrInvalidConn) + return nil, driver.ErrBadConn + } + var q string + if readOnly { + q = "START TRANSACTION READ ONLY" + } else { + q = "START TRANSACTION" + } + err := mc.exec(q) + if err == nil { + return &mysqlTx{mc}, err + } + return nil, mc.markBadConn(err) +} + +func (mc *mysqlConn) Close() (err error) { + // Makes Close idempotent + if !mc.closed.IsSet() { + err = mc.writeCommandPacket(comQuit) + } + + mc.cleanup() + + return +} + +// Closes the network connection and unsets internal variables. Do not call this +// function after successfully authentication, call Close instead. This function +// is called before auth or on auth failure because MySQL will have already +// closed the network connection. +func (mc *mysqlConn) cleanup() { + if !mc.closed.TrySet(true) { + return + } + + // Makes cleanup idempotent + close(mc.closech) + if mc.netConn == nil { + return + } + if err := mc.netConn.Close(); err != nil { + errLog.Print(err) + } +} + +func (mc *mysqlConn) error() error { + if mc.closed.IsSet() { + if err := mc.canceled.Value(); err != nil { + return err + } + return ErrInvalidConn + } + return nil +} + +func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { + if mc.closed.IsSet() { + errLog.Print(ErrInvalidConn) + return nil, driver.ErrBadConn + } + // Send command + err := mc.writeCommandPacketStr(comStmtPrepare, query) + if err != nil { + // STMT_PREPARE is safe to retry. So we can return ErrBadConn here. + errLog.Print(err) + return nil, driver.ErrBadConn + } + + stmt := &mysqlStmt{ + mc: mc, + } + + // Read Result + columnCount, err := stmt.readPrepareResultPacket() + if err == nil { + if stmt.paramCount > 0 { + if err = mc.readUntilEOF(); err != nil { + return nil, err + } + } + + if columnCount > 0 { + err = mc.readUntilEOF() + } + } + + return stmt, err +} + +func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) { + // Number of ? should be same to len(args) + if strings.Count(query, "?") != len(args) { + return "", driver.ErrSkip + } + + buf, err := mc.buf.takeCompleteBuffer() + if err != nil { + // can not take the buffer. Something must be wrong with the connection + errLog.Print(err) + return "", ErrInvalidConn + } + buf = buf[:0] + argPos := 0 + + for i := 0; i < len(query); i++ { + q := strings.IndexByte(query[i:], '?') + if q == -1 { + buf = append(buf, query[i:]...) + break + } + buf = append(buf, query[i:i+q]...) + i += q + + arg := args[argPos] + argPos++ + + if arg == nil { + buf = append(buf, "NULL"...) + continue + } + + switch v := arg.(type) { + case int64: + buf = strconv.AppendInt(buf, v, 10) + case uint64: + // Handle uint64 explicitly because our custom ConvertValue emits unsigned values + buf = strconv.AppendUint(buf, v, 10) + case float64: + buf = strconv.AppendFloat(buf, v, 'g', -1, 64) + case bool: + if v { + buf = append(buf, '1') + } else { + buf = append(buf, '0') + } + case time.Time: + if v.IsZero() { + buf = append(buf, "'0000-00-00'"...) + } else { + buf = append(buf, '\'') + buf, err = appendDateTime(buf, v.In(mc.cfg.Loc)) + if err != nil { + return "", err + } + buf = append(buf, '\'') + } + case json.RawMessage: + buf = append(buf, '\'') + if mc.status&statusNoBackslashEscapes == 0 { + buf = escapeBytesBackslash(buf, v) + } else { + buf = escapeBytesQuotes(buf, v) + } + buf = append(buf, '\'') + case []byte: + if v == nil { + buf = append(buf, "NULL"...) + } else { + buf = append(buf, "_binary'"...) + if mc.status&statusNoBackslashEscapes == 0 { + buf = escapeBytesBackslash(buf, v) + } else { + buf = escapeBytesQuotes(buf, v) + } + buf = append(buf, '\'') + } + case string: + buf = append(buf, '\'') + if mc.status&statusNoBackslashEscapes == 0 { + buf = escapeStringBackslash(buf, v) + } else { + buf = escapeStringQuotes(buf, v) + } + buf = append(buf, '\'') + default: + return "", driver.ErrSkip + } + + if len(buf)+4 > mc.maxAllowedPacket { + return "", driver.ErrSkip + } + } + if argPos != len(args) { + return "", driver.ErrSkip + } + return string(buf), nil +} + +func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { + if mc.closed.IsSet() { + errLog.Print(ErrInvalidConn) + return nil, driver.ErrBadConn + } + if len(args) != 0 { + if !mc.cfg.InterpolateParams { + return nil, driver.ErrSkip + } + // try to interpolate the parameters to save extra roundtrips for preparing and closing a statement + prepared, err := mc.interpolateParams(query, args) + if err != nil { + return nil, err + } + query = prepared + } + mc.affectedRows = 0 + mc.insertId = 0 + + err := mc.exec(query) + if err == nil { + return &mysqlResult{ + affectedRows: int64(mc.affectedRows), + insertId: int64(mc.insertId), + }, err + } + return nil, mc.markBadConn(err) +} + +// Internal function to execute commands +func (mc *mysqlConn) exec(query string) error { + // Send command + if err := mc.writeCommandPacketStr(comQuery, query); err != nil { + return mc.markBadConn(err) + } + + // Read Result + resLen, err := mc.readResultSetHeaderPacket() + if err != nil { + return err + } + + if resLen > 0 { + // columns + if err := mc.readUntilEOF(); err != nil { + return err + } + + // rows + if err := mc.readUntilEOF(); err != nil { + return err + } + } + + return mc.discardResults() +} + +func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) { + return mc.query(query, args) +} + +func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) { + if mc.closed.IsSet() { + errLog.Print(ErrInvalidConn) + return nil, driver.ErrBadConn + } + if len(args) != 0 { + if !mc.cfg.InterpolateParams { + return nil, driver.ErrSkip + } + // try client-side prepare to reduce roundtrip + prepared, err := mc.interpolateParams(query, args) + if err != nil { + return nil, err + } + query = prepared + } + // Send command + err := mc.writeCommandPacketStr(comQuery, query) + if err == nil { + // Read Result + var resLen int + resLen, err = mc.readResultSetHeaderPacket() + if err == nil { + rows := new(textRows) + rows.mc = mc + + if resLen == 0 { + rows.rs.done = true + + switch err := rows.NextResultSet(); err { + case nil, io.EOF: + return rows, nil + default: + return nil, err + } + } + + // Columns + rows.rs.columns, err = mc.readColumns(resLen) + return rows, err + } + } + return nil, mc.markBadConn(err) +} + +// Gets the value of the given MySQL System Variable +// The returned byte slice is only valid until the next read +func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { + // Send command + if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil { + return nil, err + } + + // Read Result + resLen, err := mc.readResultSetHeaderPacket() + if err == nil { + rows := new(textRows) + rows.mc = mc + rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}} + + if resLen > 0 { + // Columns + if err := mc.readUntilEOF(); err != nil { + return nil, err + } + } + + dest := make([]driver.Value, resLen) + if err = rows.readRow(dest); err == nil { + return dest[0].([]byte), mc.readUntilEOF() + } + } + return nil, err +} + +// finish is called when the query has canceled. +func (mc *mysqlConn) cancel(err error) { + mc.canceled.Set(err) + mc.cleanup() +} + +// finish is called when the query has succeeded. +func (mc *mysqlConn) finish() { + if !mc.watching || mc.finished == nil { + return + } + select { + case mc.finished <- struct{}{}: + mc.watching = false + case <-mc.closech: + } +} + +// Ping implements driver.Pinger interface +func (mc *mysqlConn) Ping(ctx context.Context) (err error) { + if mc.closed.IsSet() { + errLog.Print(ErrInvalidConn) + return driver.ErrBadConn + } + + if err = mc.watchCancel(ctx); err != nil { + return + } + defer mc.finish() + + if err = mc.writeCommandPacket(comPing); err != nil { + return mc.markBadConn(err) + } + + return mc.readResultOK() +} + +// BeginTx implements driver.ConnBeginTx interface +func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + if mc.closed.IsSet() { + return nil, driver.ErrBadConn + } + + if err := mc.watchCancel(ctx); err != nil { + return nil, err + } + defer mc.finish() + + if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault { + level, err := mapIsolationLevel(opts.Isolation) + if err != nil { + return nil, err + } + err = mc.exec("SET TRANSACTION ISOLATION LEVEL " + level) + if err != nil { + return nil, err + } + } + + return mc.begin(opts.ReadOnly) +} + +func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + dargs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + + if err := mc.watchCancel(ctx); err != nil { + return nil, err + } + + rows, err := mc.query(query, dargs) + if err != nil { + mc.finish() + return nil, err + } + rows.finish = mc.finish + return rows, err +} + +func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + dargs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + + if err := mc.watchCancel(ctx); err != nil { + return nil, err + } + defer mc.finish() + + return mc.Exec(query, dargs) +} + +func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + if err := mc.watchCancel(ctx); err != nil { + return nil, err + } + + stmt, err := mc.Prepare(query) + mc.finish() + if err != nil { + return nil, err + } + + select { + default: + case <-ctx.Done(): + stmt.Close() + return nil, ctx.Err() + } + return stmt, nil +} + +func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + dargs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + + if err := stmt.mc.watchCancel(ctx); err != nil { + return nil, err + } + + rows, err := stmt.query(dargs) + if err != nil { + stmt.mc.finish() + return nil, err + } + rows.finish = stmt.mc.finish + return rows, err +} + +func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + dargs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + + if err := stmt.mc.watchCancel(ctx); err != nil { + return nil, err + } + defer stmt.mc.finish() + + return stmt.Exec(dargs) +} + +func (mc *mysqlConn) watchCancel(ctx context.Context) error { + if mc.watching { + // Reach here if canceled, + // so the connection is already invalid + mc.cleanup() + return nil + } + // When ctx is already cancelled, don't watch it. + if err := ctx.Err(); err != nil { + return err + } + // When ctx is not cancellable, don't watch it. + if ctx.Done() == nil { + return nil + } + // When watcher is not alive, can't watch it. + if mc.watcher == nil { + return nil + } + + mc.watching = true + mc.watcher <- ctx + return nil +} + +func (mc *mysqlConn) startWatcher() { + watcher := make(chan context.Context, 1) + mc.watcher = watcher + finished := make(chan struct{}) + mc.finished = finished + go func() { + for { + var ctx context.Context + select { + case ctx = <-watcher: + case <-mc.closech: + return + } + + select { + case <-ctx.Done(): + mc.cancel(ctx.Err()) + case <-finished: + case <-mc.closech: + return + } + } + }() +} + +func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) { + nv.Value, err = converter{}.ConvertValue(nv.Value) + return +} + +// ResetSession implements driver.SessionResetter. +// (From Go 1.10) +func (mc *mysqlConn) ResetSession(ctx context.Context) error { + if mc.closed.IsSet() { + return driver.ErrBadConn + } + mc.reset = true + return nil +} + +// IsValid implements driver.Validator interface +// (From Go 1.15) +func (mc *mysqlConn) IsValid() bool { + return !mc.closed.IsSet() +} diff --git a/connection_test.go b/connection_test.go new file mode 100644 index 0000000..589b97e --- /dev/null +++ b/connection_test.go @@ -0,0 +1,203 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mariadb + +import ( + "context" + "database/sql/driver" + "encoding/json" + "errors" + "net" + "testing" +) + +func TestInterpolateParams(t *testing.T) { + mc := &mysqlConn{ + buf: newBuffer(nil), + maxAllowedPacket: maxPacketSize, + cfg: &Config{ + InterpolateParams: true, + }, + } + + q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42), "gopher"}) + if err != nil { + t.Errorf("Expected err=nil, got %#v", err) + return + } + expected := `SELECT 42+'gopher'` + if q != expected { + t.Errorf("Expected: %q\nGot: %q", expected, q) + } +} + +func TestInterpolateParamsJSONRawMessage(t *testing.T) { + mc := &mysqlConn{ + buf: newBuffer(nil), + maxAllowedPacket: maxPacketSize, + cfg: &Config{ + InterpolateParams: true, + }, + } + + buf, err := json.Marshal(struct { + Value int `json:"value"` + }{Value: 42}) + if err != nil { + t.Errorf("Expected err=nil, got %#v", err) + return + } + q, err := mc.interpolateParams("SELECT ?", []driver.Value{json.RawMessage(buf)}) + if err != nil { + t.Errorf("Expected err=nil, got %#v", err) + return + } + expected := `SELECT '{\"value\":42}'` + if q != expected { + t.Errorf("Expected: %q\nGot: %q", expected, q) + } +} + +func TestInterpolateParamsTooManyPlaceholders(t *testing.T) { + mc := &mysqlConn{ + buf: newBuffer(nil), + maxAllowedPacket: maxPacketSize, + cfg: &Config{ + InterpolateParams: true, + }, + } + + q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42)}) + if err != driver.ErrSkip { + t.Errorf("Expected err=driver.ErrSkip, got err=%#v, q=%#v", err, q) + } +} + +// We don't support placeholder in string literal for now. +// https://github.com/go-sql-driver/mysql/pull/490 +func TestInterpolateParamsPlaceholderInString(t *testing.T) { + mc := &mysqlConn{ + buf: newBuffer(nil), + maxAllowedPacket: maxPacketSize, + cfg: &Config{ + InterpolateParams: true, + }, + } + + q, err := mc.interpolateParams("SELECT 'abc?xyz',?", []driver.Value{int64(42)}) + // When InterpolateParams support string literal, this should return `"SELECT 'abc?xyz', 42` + if err != driver.ErrSkip { + t.Errorf("Expected err=driver.ErrSkip, got err=%#v, q=%#v", err, q) + } +} + +func TestInterpolateParamsUint64(t *testing.T) { + mc := &mysqlConn{ + buf: newBuffer(nil), + maxAllowedPacket: maxPacketSize, + cfg: &Config{ + InterpolateParams: true, + }, + } + + q, err := mc.interpolateParams("SELECT ?", []driver.Value{uint64(42)}) + if err != nil { + t.Errorf("Expected err=nil, got err=%#v, q=%#v", err, q) + } + if q != "SELECT 42" { + t.Errorf("Expected uint64 interpolation to work, got q=%#v", q) + } +} + +func TestCheckNamedValue(t *testing.T) { + value := driver.NamedValue{Value: ^uint64(0)} + x := &mysqlConn{} + err := x.CheckNamedValue(&value) + + if err != nil { + t.Fatal("uint64 high-bit not convertible", err) + } + + if value.Value != ^uint64(0) { + t.Fatalf("uint64 high-bit converted, got %#v %T", value.Value, value.Value) + } +} + +// TestCleanCancel tests passed context is cancelled at start. +// No packet should be sent. Connection should keep current status. +func TestCleanCancel(t *testing.T) { + mc := &mysqlConn{ + closech: make(chan struct{}), + } + mc.startWatcher() + defer mc.cleanup() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + for i := 0; i < 3; i++ { // Repeat same behavior + err := mc.Ping(ctx) + if err != context.Canceled { + t.Errorf("expected context.Canceled, got %#v", err) + } + + if mc.closed.IsSet() { + t.Error("expected mc is not closed, closed actually") + } + + if mc.watching { + t.Error("expected watching is false, but true") + } + } +} + +func TestPingMarkBadConnection(t *testing.T) { + nc := badConnection{err: errors.New("boom")} + ms := &mysqlConn{ + netConn: nc, + buf: newBuffer(nc), + maxAllowedPacket: defaultMaxAllowedPacket, + } + + err := ms.Ping(context.Background()) + + if err != driver.ErrBadConn { + t.Errorf("expected driver.ErrBadConn, got %#v", err) + } +} + +func TestPingErrInvalidConn(t *testing.T) { + nc := badConnection{err: errors.New("failed to write"), n: 10} + ms := &mysqlConn{ + netConn: nc, + buf: newBuffer(nc), + maxAllowedPacket: defaultMaxAllowedPacket, + closech: make(chan struct{}), + } + + err := ms.Ping(context.Background()) + + if err != ErrInvalidConn { + t.Errorf("expected ErrInvalidConn, got %#v", err) + } +} + +type badConnection struct { + n int + err error + net.Conn +} + +func (bc badConnection) Write(b []byte) (n int, err error) { + return bc.n, bc.err +} + +func (bc badConnection) Close() error { + return nil +} diff --git a/connector.go b/connector.go new file mode 100644 index 0000000..7a4f426 --- /dev/null +++ b/connector.go @@ -0,0 +1,146 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mariadb + +import ( + "context" + "database/sql/driver" + "net" +) + +type connector struct { + cfg *Config // immutable private copy. +} + +// Connect implements driver.Connector interface. +// Connect returns a connection to the database. +func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { + var err error + + // New mysqlConn + mc := &mysqlConn{ + maxAllowedPacket: maxPacketSize, + maxWriteSize: maxPacketSize - 1, + closech: make(chan struct{}), + cfg: c.cfg, + } + mc.parseTime = mc.cfg.ParseTime + + // Connect to Server + dialsLock.RLock() + dial, ok := dials[mc.cfg.Net] + dialsLock.RUnlock() + if ok { + dctx := ctx + if mc.cfg.Timeout > 0 { + var cancel context.CancelFunc + dctx, cancel = context.WithTimeout(ctx, c.cfg.Timeout) + defer cancel() + } + mc.netConn, err = dial(dctx, mc.cfg.Addr) + } else { + nd := net.Dialer{Timeout: mc.cfg.Timeout} + mc.netConn, err = nd.DialContext(ctx, mc.cfg.Net, mc.cfg.Addr) + } + + if err != nil { + return nil, err + } + + // Enable TCP Keepalives on TCP connections + if tc, ok := mc.netConn.(*net.TCPConn); ok { + if err := tc.SetKeepAlive(true); err != nil { + // Don't send COM_QUIT before handshake. + mc.netConn.Close() + mc.netConn = nil + return nil, err + } + } + + // Call startWatcher for context support (From Go 1.8) + mc.startWatcher() + if err := mc.watchCancel(ctx); err != nil { + mc.cleanup() + return nil, err + } + defer mc.finish() + + mc.buf = newBuffer(mc.netConn) + + // Set I/O timeouts + mc.buf.timeout = mc.cfg.ReadTimeout + mc.writeTimeout = mc.cfg.WriteTimeout + + // Reading Handshake Initialization Packet + authData, plugin, err := mc.readHandshakePacket() + if err != nil { + mc.cleanup() + return nil, err + } + + if plugin == "" { + plugin = defaultAuthPlugin + } + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + // try the default auth plugin, if using the requested plugin failed + errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error()) + plugin = defaultAuthPlugin + authResp, err = mc.auth(authData, plugin) + if err != nil { + mc.cleanup() + return nil, err + } + } + if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil { + mc.cleanup() + return nil, err + } + + // Handle response to auth packet, switch methods if possible + if err = mc.handleAuthResult(authData, plugin); err != nil { + // Authentication failed and MySQL has already closed the connection + // (https://dev.mysql.com/doc/internals/en/authentication-fails.html). + // Do not send COM_QUIT, just cleanup and return the error. + mc.cleanup() + return nil, err + } + + if mc.cfg.MaxAllowedPacket > 0 { + mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket + } else { + // Get max allowed packet size + maxap, err := mc.getSystemVar("max_allowed_packet") + if err != nil { + mc.Close() + return nil, err + } + mc.maxAllowedPacket = stringToInt(maxap) - 1 + } + if mc.maxAllowedPacket < maxPacketSize { + mc.maxWriteSize = mc.maxAllowedPacket + } + + // Handle DSN Params + err = mc.handleParams() + if err != nil { + mc.Close() + return nil, err + } + + return mc, nil +} + +// Driver implements driver.Connector interface. +// Driver returns &MySQLDriver{}. +func (c *connector) Driver() driver.Driver { + return &MySQLDriver{} +} diff --git a/connector_test.go b/connector_test.go new file mode 100644 index 0000000..360e7c5 --- /dev/null +++ b/connector_test.go @@ -0,0 +1,30 @@ +package mariadb + +import ( + "context" + "net" + "testing" + "time" +) + +func TestConnectorReturnsTimeout(t *testing.T) { + connector := &connector{&Config{ + Net: "tcp", + Addr: "1.1.1.1:1234", + Timeout: 10 * time.Millisecond, + }} + + _, err := connector.Connect(context.Background()) + if err == nil { + t.Fatal("error expected") + } + + if nerr, ok := err.(*net.OpError); ok { + expected := "dial tcp 1.1.1.1:1234: i/o timeout" + if nerr.Error() != expected { + t.Fatalf("expected %q, got %q", expected, nerr.Error()) + } + } else { + t.Fatalf("expected %T, got %T", nerr, err) + } +} diff --git a/const.go b/const.go new file mode 100644 index 0000000..a3c5823 --- /dev/null +++ b/const.go @@ -0,0 +1,174 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mariadb + +const ( + defaultAuthPlugin = "mysql_native_password" + defaultMaxAllowedPacket = 4 << 20 // 4 MiB + minProtocolVersion = 10 + maxPacketSize = 1<<24 - 1 + timeFormat = "2006-01-02 15:04:05.999999" +) + +// MySQL constants documentation: +// http://dev.mysql.com/doc/internals/en/client-server-protocol.html + +const ( + iOK byte = 0x00 + iAuthMoreData byte = 0x01 + iLocalInFile byte = 0xfb + iEOF byte = 0xfe + iERR byte = 0xff +) + +// https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags +type clientFlag uint32 + +const ( + clientLongPassword clientFlag = 1 << iota + clientFoundRows + clientLongFlag + clientConnectWithDB + clientNoSchema + clientCompress + clientODBC + clientLocalFiles + clientIgnoreSpace + clientProtocol41 + clientInteractive + clientSSL + clientIgnoreSIGPIPE + clientTransactions + clientReserved + clientSecureConn + clientMultiStatements + clientMultiResults + clientPSMultiResults + clientPluginAuth + clientConnectAttrs + clientPluginAuthLenEncClientData + clientCanHandleExpiredPasswords + clientSessionTrack + clientDeprecateEOF +) + +const ( + comQuit byte = iota + 1 + comInitDB + comQuery + comFieldList + comCreateDB + comDropDB + comRefresh + comShutdown + comStatistics + comProcessInfo + comConnect + comProcessKill + comDebug + comPing + comTime + comDelayedInsert + comChangeUser + comBinlogDump + comTableDump + comConnectOut + comRegisterSlave + comStmtPrepare + comStmtExecute + comStmtSendLongData + comStmtClose + comStmtReset + comSetOption + comStmtFetch +) + +// https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnType +type fieldType byte + +const ( + fieldTypeDecimal fieldType = iota + fieldTypeTiny + fieldTypeShort + fieldTypeLong + fieldTypeFloat + fieldTypeDouble + fieldTypeNULL + fieldTypeTimestamp + fieldTypeLongLong + fieldTypeInt24 + fieldTypeDate + fieldTypeTime + fieldTypeDateTime + fieldTypeYear + fieldTypeNewDate + fieldTypeVarChar + fieldTypeBit +) +const ( + fieldTypeJSON fieldType = iota + 0xf5 + fieldTypeNewDecimal + fieldTypeEnum + fieldTypeSet + fieldTypeTinyBLOB + fieldTypeMediumBLOB + fieldTypeLongBLOB + fieldTypeBLOB + fieldTypeVarString + fieldTypeString + fieldTypeGeometry +) + +type fieldFlag uint16 + +const ( + flagNotNULL fieldFlag = 1 << iota + flagPriKey + flagUniqueKey + flagMultipleKey + flagBLOB + flagUnsigned + flagZeroFill + flagBinary + flagEnum + flagAutoIncrement + flagTimestamp + flagSet + flagUnknown1 + flagUnknown2 + flagUnknown3 + flagUnknown4 +) + +// http://dev.mysql.com/doc/internals/en/status-flags.html +type statusFlag uint16 + +const ( + statusInTrans statusFlag = 1 << iota + statusInAutocommit + statusReserved // Not in documentation + statusMoreResultsExists + statusNoGoodIndexUsed + statusNoIndexUsed + statusCursorExists + statusLastRowSent + statusDbDropped + statusNoBackslashEscapes + statusMetadataChanged + statusQueryWasSlow + statusPsOutParams + statusInTransReadonly + statusSessionStateChanged +) + +const ( + cachingSha2PasswordRequestPublicKey = 2 + cachingSha2PasswordFastAuthSuccess = 3 + cachingSha2PasswordPerformFullAuthentication = 4 +) diff --git a/driver.go b/driver.go new file mode 100644 index 0000000..7321c14 --- /dev/null +++ b/driver.go @@ -0,0 +1,107 @@ +// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// Package mysql provides a MySQL driver for Go's database/sql package. +// +// The driver should be used via the database/sql package: +// +// import "database/sql" +// import _ "github.com/go-sql-driver/mysql" +// +// db, err := sql.Open("mysql", "user:password@/dbname") +// +// See https://github.com/go-sql-driver/mysql#usage for details +package mariadb + +import ( + "context" + "database/sql" + "database/sql/driver" + "net" + "sync" +) + +// MySQLDriver is exported to make the driver directly accessible. +// In general the driver is used via the database/sql package. +type MySQLDriver struct{} + +// DialFunc is a function which can be used to establish the network connection. +// Custom dial functions must be registered with RegisterDial +// +// Deprecated: users should register a DialContextFunc instead +type DialFunc func(addr string) (net.Conn, error) + +// DialContextFunc is a function which can be used to establish the network connection. +// Custom dial functions must be registered with RegisterDialContext +type DialContextFunc func(ctx context.Context, addr string) (net.Conn, error) + +var ( + dialsLock sync.RWMutex + dials map[string]DialContextFunc +) + +// RegisterDialContext registers a custom dial function. It can then be used by the +// network address mynet(addr), where mynet is the registered new network. +// The current context for the connection and its address is passed to the dial function. +func RegisterDialContext(net string, dial DialContextFunc) { + dialsLock.Lock() + defer dialsLock.Unlock() + if dials == nil { + dials = make(map[string]DialContextFunc) + } + dials[net] = dial +} + +// RegisterDial registers a custom dial function. It can then be used by the +// network address mynet(addr), where mynet is the registered new network. +// addr is passed as a parameter to the dial function. +// +// Deprecated: users should call RegisterDialContext instead +func RegisterDial(network string, dial DialFunc) { + RegisterDialContext(network, func(_ context.Context, addr string) (net.Conn, error) { + return dial(addr) + }) +} + +// Open new Connection. +// See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how +// the DSN string is formatted +func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { + cfg, err := ParseDSN(dsn) + if err != nil { + return nil, err + } + c := &connector{ + cfg: cfg, + } + return c.Connect(context.Background()) +} + +func init() { + sql.Register("mysql", &MySQLDriver{}) +} + +// NewConnector returns new driver.Connector. +func NewConnector(cfg *Config) (driver.Connector, error) { + cfg = cfg.Clone() + // normalize the contents of cfg so calls to NewConnector have the same + // behavior as MySQLDriver.OpenConnector + if err := cfg.normalize(); err != nil { + return nil, err + } + return &connector{cfg: cfg}, nil +} + +// OpenConnector implements driver.DriverContext. +func (d MySQLDriver) OpenConnector(dsn string) (driver.Connector, error) { + cfg, err := ParseDSN(dsn) + if err != nil { + return nil, err + } + return &connector{ + cfg: cfg, + }, nil +} diff --git a/driver_test.go b/driver_test.go new file mode 100644 index 0000000..28ddda9 --- /dev/null +++ b/driver_test.go @@ -0,0 +1,3211 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mariadb + +import ( + "bytes" + "context" + "crypto/tls" + "database/sql" + "database/sql/driver" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "log" + "math" + "net" + "net/url" + "os" + "reflect" + "runtime" + "strings" + "sync" + "sync/atomic" + "testing" + "time" +) + +// Ensure that all the driver interfaces are implemented +var ( + _ driver.Rows = &binaryRows{} + _ driver.Rows = &textRows{} +) + +var ( + user string + pass string + prot string + addr string + dbname string + dsn string + netAddr string + available bool +) + +var ( + tDate = time.Date(2012, 6, 14, 0, 0, 0, 0, time.UTC) + sDate = "2012-06-14" + tDateTime = time.Date(2011, 11, 20, 21, 27, 37, 0, time.UTC) + sDateTime = "2011-11-20 21:27:37" + tDate0 = time.Time{} + sDate0 = "0000-00-00" + sDateTime0 = "0000-00-00 00:00:00" +) + +// See https://github.com/go-sql-driver/mysql/wiki/Testing +func init() { + // get environment variables + env := func(key, defaultValue string) string { + if value := os.Getenv(key); value != "" { + return value + } + return defaultValue + } + user = env("MYSQL_TEST_USER", "root") + pass = env("MYSQL_TEST_PASS", "") + prot = env("MYSQL_TEST_PROT", "tcp") + addr = env("MYSQL_TEST_ADDR", "localhost:3306") + dbname = env("MYSQL_TEST_DBNAME", "gotest") + netAddr = fmt.Sprintf("%s(%s)", prot, addr) + dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s", user, pass, netAddr, dbname) + c, err := net.Dial(prot, addr) + if err == nil { + available = true + c.Close() + } +} + +type DBTest struct { + *testing.T + db *sql.DB +} + +type netErrorMock struct { + temporary bool + timeout bool +} + +func (e netErrorMock) Temporary() bool { + return e.temporary +} + +func (e netErrorMock) Timeout() bool { + return e.timeout +} + +func (e netErrorMock) Error() string { + return fmt.Sprintf("mock net error. Temporary: %v, Timeout %v", e.temporary, e.timeout) +} + +func runTestsWithMultiStatement(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + dsn += "&multiStatements=true" + var db *sql.DB + if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation { + db, err = sql.Open("mysql", dsn) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db.Close() + } + + dbt := &DBTest{t, db} + for _, test := range tests { + test(dbt) + dbt.db.Exec("DROP TABLE IF EXISTS test") + } +} + +func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + db, err := sql.Open("mysql", dsn) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db.Close() + + db.Exec("DROP TABLE IF EXISTS test") + + dsn2 := dsn + "&interpolateParams=true" + var db2 *sql.DB + if _, err := ParseDSN(dsn2); err != errInvalidDSNUnsafeCollation { + db2, err = sql.Open("mysql", dsn2) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db2.Close() + } + + dsn3 := dsn + "&multiStatements=true" + var db3 *sql.DB + if _, err := ParseDSN(dsn3); err != errInvalidDSNUnsafeCollation { + db3, err = sql.Open("mysql", dsn3) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db3.Close() + } + + dbt := &DBTest{t, db} + dbt2 := &DBTest{t, db2} + dbt3 := &DBTest{t, db3} + for _, test := range tests { + test(dbt) + dbt.db.Exec("DROP TABLE IF EXISTS test") + if db2 != nil { + test(dbt2) + dbt2.db.Exec("DROP TABLE IF EXISTS test") + } + if db3 != nil { + test(dbt3) + dbt3.db.Exec("DROP TABLE IF EXISTS test") + } + } +} + +func (dbt *DBTest) fail(method, query string, err error) { + if len(query) > 300 { + query = "[query too large to print]" + } + dbt.Fatalf("error on %s %s: %s", method, query, err.Error()) +} + +func (dbt *DBTest) mustExec(query string, args ...interface{}) (res sql.Result) { + res, err := dbt.db.Exec(query, args...) + if err != nil { + dbt.fail("exec", query, err) + } + return res +} + +func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows) { + rows, err := dbt.db.Query(query, args...) + if err != nil { + dbt.fail("query", query, err) + } + return rows +} + +func maybeSkip(t *testing.T, err error, skipErrno uint16) { + mySQLErr, ok := err.(*MySQLError) + if !ok { + return + } + + if mySQLErr.Number == skipErrno { + t.Skipf("skipping test for error: %v", err) + } +} + +func TestEmptyQuery(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + // just a comment, no query + rows := dbt.mustQuery("--") + defer rows.Close() + // will hang before #255 + if rows.Next() { + dbt.Errorf("next on rows must be false") + } + }) +} + +func TestCRUD(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + // Create Table + dbt.mustExec("CREATE TABLE test (value BOOL)") + + // Test for unexpected data + var out bool + rows := dbt.mustQuery("SELECT * FROM test") + if rows.Next() { + dbt.Error("unexpected data in empty table") + } + rows.Close() + + // Create Data + res := dbt.mustExec("INSERT INTO test VALUES (1)") + count, err := res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 1 { + dbt.Fatalf("expected 1 affected row, got %d", count) + } + + id, err := res.LastInsertId() + if err != nil { + dbt.Fatalf("res.LastInsertId() returned error: %s", err.Error()) + } + if id != 0 { + dbt.Fatalf("expected InsertId 0, got %d", id) + } + + // Read + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if true != out { + dbt.Errorf("true != %t", out) + } + + if rows.Next() { + dbt.Error("unexpected data") + } + } else { + dbt.Error("no data") + } + rows.Close() + + // Update + res = dbt.mustExec("UPDATE test SET value = ? WHERE value = ?", false, true) + count, err = res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 1 { + dbt.Fatalf("expected 1 affected row, got %d", count) + } + + // Check Update + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if false != out { + dbt.Errorf("false != %t", out) + } + + if rows.Next() { + dbt.Error("unexpected data") + } + } else { + dbt.Error("no data") + } + rows.Close() + + // Delete + res = dbt.mustExec("DELETE FROM test WHERE value = ?", false) + count, err = res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 1 { + dbt.Fatalf("expected 1 affected row, got %d", count) + } + + // Check for unexpected rows + res = dbt.mustExec("DELETE FROM test") + count, err = res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 0 { + dbt.Fatalf("expected 0 affected row, got %d", count) + } + }) +} + +func TestMultiQuery(t *testing.T) { + runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) { + // Create Table + dbt.mustExec("CREATE TABLE `test` (`id` int(11) NOT NULL, `value` int(11) NOT NULL) ") + + // Create Data + res := dbt.mustExec("INSERT INTO test VALUES (1, 1)") + count, err := res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 1 { + dbt.Fatalf("expected 1 affected row, got %d", count) + } + + // Update + res = dbt.mustExec("UPDATE test SET value = 3 WHERE id = 1; UPDATE test SET value = 4 WHERE id = 1; UPDATE test SET value = 5 WHERE id = 1;") + count, err = res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 1 { + dbt.Fatalf("expected 1 affected row, got %d", count) + } + + // Read + var out int + rows := dbt.mustQuery("SELECT value FROM test WHERE id=1;") + if rows.Next() { + rows.Scan(&out) + if 5 != out { + dbt.Errorf("5 != %d", out) + } + + if rows.Next() { + dbt.Error("unexpected data") + } + } else { + dbt.Error("no data") + } + rows.Close() + + }) +} + +func TestInt(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + types := [5]string{"TINYINT", "SMALLINT", "MEDIUMINT", "INT", "BIGINT"} + in := int64(42) + var out int64 + var rows *sql.Rows + + // SIGNED + for _, v := range types { + dbt.mustExec("CREATE TABLE test (value " + v + ")") + + dbt.mustExec("INSERT INTO test VALUES (?)", in) + + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if in != out { + dbt.Errorf("%s: %d != %d", v, in, out) + } + } else { + dbt.Errorf("%s: no data", v) + } + rows.Close() + + dbt.mustExec("DROP TABLE IF EXISTS test") + } + + // UNSIGNED ZEROFILL + for _, v := range types { + dbt.mustExec("CREATE TABLE test (value " + v + " ZEROFILL)") + + dbt.mustExec("INSERT INTO test VALUES (?)", in) + + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if in != out { + dbt.Errorf("%s ZEROFILL: %d != %d", v, in, out) + } + } else { + dbt.Errorf("%s ZEROFILL: no data", v) + } + rows.Close() + + dbt.mustExec("DROP TABLE IF EXISTS test") + } + }) +} + +func TestFloat32(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + types := [2]string{"FLOAT", "DOUBLE"} + in := float32(42.23) + var out float32 + var rows *sql.Rows + for _, v := range types { + dbt.mustExec("CREATE TABLE test (value " + v + ")") + dbt.mustExec("INSERT INTO test VALUES (?)", in) + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if in != out { + dbt.Errorf("%s: %g != %g", v, in, out) + } + } else { + dbt.Errorf("%s: no data", v) + } + rows.Close() + dbt.mustExec("DROP TABLE IF EXISTS test") + } + }) +} + +func TestFloat64(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + types := [2]string{"FLOAT", "DOUBLE"} + var expected float64 = 42.23 + var out float64 + var rows *sql.Rows + for _, v := range types { + dbt.mustExec("CREATE TABLE test (value " + v + ")") + dbt.mustExec("INSERT INTO test VALUES (42.23)") + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if expected != out { + dbt.Errorf("%s: %g != %g", v, expected, out) + } + } else { + dbt.Errorf("%s: no data", v) + } + rows.Close() + dbt.mustExec("DROP TABLE IF EXISTS test") + } + }) +} + +func TestFloat64Placeholder(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + types := [2]string{"FLOAT", "DOUBLE"} + var expected float64 = 42.23 + var out float64 + var rows *sql.Rows + for _, v := range types { + dbt.mustExec("CREATE TABLE test (id int, value " + v + ")") + dbt.mustExec("INSERT INTO test VALUES (1, 42.23)") + rows = dbt.mustQuery("SELECT value FROM test WHERE id = ?", 1) + if rows.Next() { + rows.Scan(&out) + if expected != out { + dbt.Errorf("%s: %g != %g", v, expected, out) + } + } else { + dbt.Errorf("%s: no data", v) + } + rows.Close() + dbt.mustExec("DROP TABLE IF EXISTS test") + } + }) +} + +func TestString(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + types := [6]string{"CHAR(255)", "VARCHAR(255)", "TINYTEXT", "TEXT", "MEDIUMTEXT", "LONGTEXT"} + in := "κόσμε üöäßñóùéàâÿœ'îë Árvíztűrő いろはにほへとちりぬるを イロハニホヘト דג סקרן чащах น่าฟังเอย" + var out string + var rows *sql.Rows + + for _, v := range types { + dbt.mustExec("CREATE TABLE test (value " + v + ") CHARACTER SET utf8") + + dbt.mustExec("INSERT INTO test VALUES (?)", in) + + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if in != out { + dbt.Errorf("%s: %s != %s", v, in, out) + } + } else { + dbt.Errorf("%s: no data", v) + } + rows.Close() + + dbt.mustExec("DROP TABLE IF EXISTS test") + } + + // BLOB + dbt.mustExec("CREATE TABLE test (id int, value BLOB) CHARACTER SET utf8") + + id := 2 + in = "Lorem ipsum dolor sit amet, consetetur sadipscing elitr, " + + "sed diam nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, " + + "sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. " + + "Stet clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet. " + + "Lorem ipsum dolor sit amet, consetetur sadipscing elitr, " + + "sed diam nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, " + + "sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. " + + "Stet clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet." + dbt.mustExec("INSERT INTO test VALUES (?, ?)", id, in) + + err := dbt.db.QueryRow("SELECT value FROM test WHERE id = ?", id).Scan(&out) + if err != nil { + dbt.Fatalf("Error on BLOB-Query: %s", err.Error()) + } else if out != in { + dbt.Errorf("BLOB: %s != %s", in, out) + } + }) +} + +func TestRawBytes(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + v1 := []byte("aaa") + v2 := []byte("bbb") + rows := dbt.mustQuery("SELECT ?, ?", v1, v2) + defer rows.Close() + if rows.Next() { + var o1, o2 sql.RawBytes + if err := rows.Scan(&o1, &o2); err != nil { + dbt.Errorf("Got error: %v", err) + } + if !bytes.Equal(v1, o1) { + dbt.Errorf("expected %v, got %v", v1, o1) + } + if !bytes.Equal(v2, o2) { + dbt.Errorf("expected %v, got %v", v2, o2) + } + // https://github.com/go-sql-driver/mysql/issues/765 + // Appending to RawBytes shouldn't overwrite next RawBytes. + o1 = append(o1, "xyzzy"...) + if !bytes.Equal(v2, o2) { + dbt.Errorf("expected %v, got %v", v2, o2) + } + } else { + dbt.Errorf("no data") + } + }) +} + +func TestRawMessage(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + v1 := json.RawMessage("{}") + v2 := json.RawMessage("[]") + rows := dbt.mustQuery("SELECT ?, ?", v1, v2) + defer rows.Close() + if rows.Next() { + var o1, o2 json.RawMessage + if err := rows.Scan(&o1, &o2); err != nil { + dbt.Errorf("Got error: %v", err) + } + if !bytes.Equal(v1, o1) { + dbt.Errorf("expected %v, got %v", v1, o1) + } + if !bytes.Equal(v2, o2) { + dbt.Errorf("expected %v, got %v", v2, o2) + } + } else { + dbt.Errorf("no data") + } + }) +} + +type testValuer struct { + value string +} + +func (tv testValuer) Value() (driver.Value, error) { + return tv.value, nil +} + +func TestValuer(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + in := testValuer{"a_value"} + var out string + var rows *sql.Rows + + dbt.mustExec("CREATE TABLE test (value VARCHAR(255)) CHARACTER SET utf8") + dbt.mustExec("INSERT INTO test VALUES (?)", in) + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if in.value != out { + dbt.Errorf("Valuer: %v != %s", in, out) + } + } else { + dbt.Errorf("Valuer: no data") + } + rows.Close() + + dbt.mustExec("DROP TABLE IF EXISTS test") + }) +} + +type testValuerWithValidation struct { + value string +} + +func (tv testValuerWithValidation) Value() (driver.Value, error) { + if len(tv.value) == 0 { + return nil, fmt.Errorf("Invalid string valuer. Value must not be empty") + } + + return tv.value, nil +} + +func TestValuerWithValidation(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + in := testValuerWithValidation{"a_value"} + var out string + var rows *sql.Rows + + dbt.mustExec("CREATE TABLE testValuer (value VARCHAR(255)) CHARACTER SET utf8") + dbt.mustExec("INSERT INTO testValuer VALUES (?)", in) + + rows = dbt.mustQuery("SELECT value FROM testValuer") + defer rows.Close() + + if rows.Next() { + rows.Scan(&out) + if in.value != out { + dbt.Errorf("Valuer: %v != %s", in, out) + } + } else { + dbt.Errorf("Valuer: no data") + } + + if _, err := dbt.db.Exec("INSERT INTO testValuer VALUES (?)", testValuerWithValidation{""}); err == nil { + dbt.Errorf("Failed to check valuer error") + } + + if _, err := dbt.db.Exec("INSERT INTO testValuer VALUES (?)", nil); err != nil { + dbt.Errorf("Failed to check nil") + } + + if _, err := dbt.db.Exec("INSERT INTO testValuer VALUES (?)", map[string]bool{}); err == nil { + dbt.Errorf("Failed to check not valuer") + } + + dbt.mustExec("DROP TABLE IF EXISTS testValuer") + }) +} + +type timeTests struct { + dbtype string + tlayout string + tests []timeTest +} + +type timeTest struct { + s string // leading "!": do not use t as value in queries + t time.Time +} + +type timeMode byte + +func (t timeMode) String() string { + switch t { + case binaryString: + return "binary:string" + case binaryTime: + return "binary:time.Time" + case textString: + return "text:string" + } + panic("unsupported timeMode") +} + +func (t timeMode) Binary() bool { + switch t { + case binaryString, binaryTime: + return true + } + return false +} + +const ( + binaryString timeMode = iota + binaryTime + textString +) + +func (t timeTest) genQuery(dbtype string, mode timeMode) string { + var inner string + if mode.Binary() { + inner = "?" + } else { + inner = `"%s"` + } + return `SELECT cast(` + inner + ` as ` + dbtype + `)` +} + +func (t timeTest) run(dbt *DBTest, dbtype, tlayout string, mode timeMode) { + var rows *sql.Rows + query := t.genQuery(dbtype, mode) + switch mode { + case binaryString: + rows = dbt.mustQuery(query, t.s) + case binaryTime: + rows = dbt.mustQuery(query, t.t) + case textString: + query = fmt.Sprintf(query, t.s) + rows = dbt.mustQuery(query) + default: + panic("unsupported mode") + } + defer rows.Close() + var err error + if !rows.Next() { + err = rows.Err() + if err == nil { + err = fmt.Errorf("no data") + } + dbt.Errorf("%s [%s]: %s", dbtype, mode, err) + return + } + var dst interface{} + err = rows.Scan(&dst) + if err != nil { + dbt.Errorf("%s [%s]: %s", dbtype, mode, err) + return + } + switch val := dst.(type) { + case []uint8: + str := string(val) + if str == t.s { + return + } + if mode.Binary() && dbtype == "DATETIME" && len(str) == 26 && str[:19] == t.s { + // a fix mainly for TravisCI: + // accept full microsecond resolution in result for DATETIME columns + // where the binary protocol was used + return + } + dbt.Errorf("%s [%s] to string: expected %q, got %q", + dbtype, mode, + t.s, str, + ) + case time.Time: + if val == t.t { + return + } + dbt.Errorf("%s [%s] to string: expected %q, got %q", + dbtype, mode, + t.s, val.Format(tlayout), + ) + default: + fmt.Printf("%#v\n", []interface{}{dbtype, tlayout, mode, t.s, t.t}) + dbt.Errorf("%s [%s]: unhandled type %T (is '%v')", + dbtype, mode, + val, val, + ) + } +} + +func TestDateTime(t *testing.T) { + afterTime := func(t time.Time, d string) time.Time { + dur, err := time.ParseDuration(d) + if err != nil { + panic(err) + } + return t.Add(dur) + } + // NOTE: MySQL rounds DATETIME(x) up - but that's not included in the tests + format := "2006-01-02 15:04:05.999999" + t0 := time.Time{} + tstr0 := "0000-00-00 00:00:00.000000" + testcases := []timeTests{ + {"DATE", format[:10], []timeTest{ + {t: time.Date(2011, 11, 20, 0, 0, 0, 0, time.UTC)}, + {t: t0, s: tstr0[:10]}, + }}, + {"DATETIME", format[:19], []timeTest{ + {t: time.Date(2011, 11, 20, 21, 27, 37, 0, time.UTC)}, + {t: t0, s: tstr0[:19]}, + }}, + {"DATETIME(0)", format[:21], []timeTest{ + {t: time.Date(2011, 11, 20, 21, 27, 37, 0, time.UTC)}, + {t: t0, s: tstr0[:19]}, + }}, + {"DATETIME(1)", format[:21], []timeTest{ + {t: time.Date(2011, 11, 20, 21, 27, 37, 100000000, time.UTC)}, + {t: t0, s: tstr0[:21]}, + }}, + {"DATETIME(6)", format, []timeTest{ + {t: time.Date(2011, 11, 20, 21, 27, 37, 123456000, time.UTC)}, + {t: t0, s: tstr0}, + }}, + {"TIME", format[11:19], []timeTest{ + {t: afterTime(t0, "12345s")}, + {s: "!-12:34:56"}, + {s: "!-838:59:59"}, + {s: "!838:59:59"}, + {t: t0, s: tstr0[11:19]}, + }}, + {"TIME(0)", format[11:19], []timeTest{ + {t: afterTime(t0, "12345s")}, + {s: "!-12:34:56"}, + {s: "!-838:59:59"}, + {s: "!838:59:59"}, + {t: t0, s: tstr0[11:19]}, + }}, + {"TIME(1)", format[11:21], []timeTest{ + {t: afterTime(t0, "12345600ms")}, + {s: "!-12:34:56.7"}, + {s: "!-838:59:58.9"}, + {s: "!838:59:58.9"}, + {t: t0, s: tstr0[11:21]}, + }}, + {"TIME(6)", format[11:], []timeTest{ + {t: afterTime(t0, "1234567890123000ns")}, + {s: "!-12:34:56.789012"}, + {s: "!-838:59:58.999999"}, + {s: "!838:59:58.999999"}, + {t: t0, s: tstr0[11:]}, + }}, + } + dsns := []string{ + dsn + "&parseTime=true", + dsn + "&parseTime=false", + } + for _, testdsn := range dsns { + runTests(t, testdsn, func(dbt *DBTest) { + microsecsSupported := false + zeroDateSupported := false + var rows *sql.Rows + var err error + rows, err = dbt.db.Query(`SELECT cast("00:00:00.1" as TIME(1)) = "00:00:00.1"`) + if err == nil { + rows.Scan(µsecsSupported) + rows.Close() + } + rows, err = dbt.db.Query(`SELECT cast("0000-00-00" as DATE) = "0000-00-00"`) + if err == nil { + rows.Scan(&zeroDateSupported) + rows.Close() + } + for _, setups := range testcases { + if t := setups.dbtype; !microsecsSupported && t[len(t)-1:] == ")" { + // skip fractional second tests if unsupported by server + continue + } + for _, setup := range setups.tests { + allowBinTime := true + if setup.s == "" { + // fill time string wherever Go can reliable produce it + setup.s = setup.t.Format(setups.tlayout) + } else if setup.s[0] == '!' { + // skip tests using setup.t as source in queries + allowBinTime = false + // fix setup.s - remove the "!" + setup.s = setup.s[1:] + } + if !zeroDateSupported && setup.s == tstr0[:len(setup.s)] { + // skip disallowed 0000-00-00 date + continue + } + setup.run(dbt, setups.dbtype, setups.tlayout, textString) + setup.run(dbt, setups.dbtype, setups.tlayout, binaryString) + if allowBinTime { + setup.run(dbt, setups.dbtype, setups.tlayout, binaryTime) + } + } + } + }) + } +} + +func TestTimestampMicros(t *testing.T) { + format := "2006-01-02 15:04:05.999999" + f0 := format[:19] + f1 := format[:21] + f6 := format[:26] + runTests(t, dsn, func(dbt *DBTest) { + // check if microseconds are supported. + // Do not use timestamp(x) for that check - before 5.5.6, x would mean display width + // and not precision. + // Se last paragraph at http://dev.mysql.com/doc/refman/5.6/en/fractional-seconds.html + microsecsSupported := false + if rows, err := dbt.db.Query(`SELECT cast("00:00:00.1" as TIME(1)) = "00:00:00.1"`); err == nil { + rows.Scan(µsecsSupported) + rows.Close() + } + if !microsecsSupported { + // skip test + return + } + _, err := dbt.db.Exec(` + CREATE TABLE test ( + value0 TIMESTAMP NOT NULL DEFAULT '` + f0 + `', + value1 TIMESTAMP(1) NOT NULL DEFAULT '` + f1 + `', + value6 TIMESTAMP(6) NOT NULL DEFAULT '` + f6 + `' + )`, + ) + if err != nil { + dbt.Error(err) + } + defer dbt.mustExec("DROP TABLE IF EXISTS test") + dbt.mustExec("INSERT INTO test SET value0=?, value1=?, value6=?", f0, f1, f6) + var res0, res1, res6 string + rows := dbt.mustQuery("SELECT * FROM test") + defer rows.Close() + if !rows.Next() { + dbt.Errorf("test contained no selectable values") + } + err = rows.Scan(&res0, &res1, &res6) + if err != nil { + dbt.Error(err) + } + if res0 != f0 { + dbt.Errorf("expected %q, got %q", f0, res0) + } + if res1 != f1 { + dbt.Errorf("expected %q, got %q", f1, res1) + } + if res6 != f6 { + dbt.Errorf("expected %q, got %q", f6, res6) + } + }) +} + +func TestNULL(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + nullStmt, err := dbt.db.Prepare("SELECT NULL") + if err != nil { + dbt.Fatal(err) + } + defer nullStmt.Close() + + nonNullStmt, err := dbt.db.Prepare("SELECT 1") + if err != nil { + dbt.Fatal(err) + } + defer nonNullStmt.Close() + + // NullBool + var nb sql.NullBool + // Invalid + if err = nullStmt.QueryRow().Scan(&nb); err != nil { + dbt.Fatal(err) + } + if nb.Valid { + dbt.Error("valid NullBool which should be invalid") + } + // Valid + if err = nonNullStmt.QueryRow().Scan(&nb); err != nil { + dbt.Fatal(err) + } + if !nb.Valid { + dbt.Error("invalid NullBool which should be valid") + } else if nb.Bool != true { + dbt.Errorf("Unexpected NullBool value: %t (should be true)", nb.Bool) + } + + // NullFloat64 + var nf sql.NullFloat64 + // Invalid + if err = nullStmt.QueryRow().Scan(&nf); err != nil { + dbt.Fatal(err) + } + if nf.Valid { + dbt.Error("valid NullFloat64 which should be invalid") + } + // Valid + if err = nonNullStmt.QueryRow().Scan(&nf); err != nil { + dbt.Fatal(err) + } + if !nf.Valid { + dbt.Error("invalid NullFloat64 which should be valid") + } else if nf.Float64 != float64(1) { + dbt.Errorf("unexpected NullFloat64 value: %f (should be 1.0)", nf.Float64) + } + + // NullInt64 + var ni sql.NullInt64 + // Invalid + if err = nullStmt.QueryRow().Scan(&ni); err != nil { + dbt.Fatal(err) + } + if ni.Valid { + dbt.Error("valid NullInt64 which should be invalid") + } + // Valid + if err = nonNullStmt.QueryRow().Scan(&ni); err != nil { + dbt.Fatal(err) + } + if !ni.Valid { + dbt.Error("invalid NullInt64 which should be valid") + } else if ni.Int64 != int64(1) { + dbt.Errorf("unexpected NullInt64 value: %d (should be 1)", ni.Int64) + } + + // NullString + var ns sql.NullString + // Invalid + if err = nullStmt.QueryRow().Scan(&ns); err != nil { + dbt.Fatal(err) + } + if ns.Valid { + dbt.Error("valid NullString which should be invalid") + } + // Valid + if err = nonNullStmt.QueryRow().Scan(&ns); err != nil { + dbt.Fatal(err) + } + if !ns.Valid { + dbt.Error("invalid NullString which should be valid") + } else if ns.String != `1` { + dbt.Error("unexpected NullString value:" + ns.String + " (should be `1`)") + } + + // nil-bytes + var b []byte + // Read nil + if err = nullStmt.QueryRow().Scan(&b); err != nil { + dbt.Fatal(err) + } + if b != nil { + dbt.Error("non-nil []byte which should be nil") + } + // Read non-nil + if err = nonNullStmt.QueryRow().Scan(&b); err != nil { + dbt.Fatal(err) + } + if b == nil { + dbt.Error("nil []byte which should be non-nil") + } + // Insert nil + b = nil + success := false + if err = dbt.db.QueryRow("SELECT ? IS NULL", b).Scan(&success); err != nil { + dbt.Fatal(err) + } + if !success { + dbt.Error("inserting []byte(nil) as NULL failed") + } + // Check input==output with input==nil + b = nil + if err = dbt.db.QueryRow("SELECT ?", b).Scan(&b); err != nil { + dbt.Fatal(err) + } + if b != nil { + dbt.Error("non-nil echo from nil input") + } + // Check input==output with input!=nil + b = []byte("") + if err = dbt.db.QueryRow("SELECT ?", b).Scan(&b); err != nil { + dbt.Fatal(err) + } + if b == nil { + dbt.Error("nil echo from non-nil input") + } + + // Insert NULL + dbt.mustExec("CREATE TABLE test (dummmy1 int, value int, dummy2 int)") + + dbt.mustExec("INSERT INTO test VALUES (?, ?, ?)", 1, nil, 2) + + var out interface{} + rows := dbt.mustQuery("SELECT * FROM test") + defer rows.Close() + if rows.Next() { + rows.Scan(&out) + if out != nil { + dbt.Errorf("%v != nil", out) + } + } else { + dbt.Error("no data") + } + }) +} + +func TestUint64(t *testing.T) { + const ( + u0 = uint64(0) + uall = ^u0 + uhigh = uall >> 1 + utop = ^uhigh + s0 = int64(0) + sall = ^s0 + shigh = int64(uhigh) + stop = ^shigh + ) + runTests(t, dsn, func(dbt *DBTest) { + stmt, err := dbt.db.Prepare(`SELECT ?, ?, ? ,?, ?, ?, ?, ?`) + if err != nil { + dbt.Fatal(err) + } + defer stmt.Close() + row := stmt.QueryRow( + u0, uhigh, utop, uall, + s0, shigh, stop, sall, + ) + + var ua, ub, uc, ud uint64 + var sa, sb, sc, sd int64 + + err = row.Scan(&ua, &ub, &uc, &ud, &sa, &sb, &sc, &sd) + if err != nil { + dbt.Fatal(err) + } + switch { + case ua != u0, + ub != uhigh, + uc != utop, + ud != uall, + sa != s0, + sb != shigh, + sc != stop, + sd != sall: + dbt.Fatal("unexpected result value") + } + }) +} + +func TestLongData(t *testing.T) { + runTests(t, dsn+"&maxAllowedPacket=0", func(dbt *DBTest) { + var maxAllowedPacketSize int + err := dbt.db.QueryRow("select @@max_allowed_packet").Scan(&maxAllowedPacketSize) + if err != nil { + dbt.Fatal(err) + } + maxAllowedPacketSize-- + + // don't get too ambitious + if maxAllowedPacketSize > 1<<25 { + maxAllowedPacketSize = 1 << 25 + } + + dbt.mustExec("CREATE TABLE test (value LONGBLOB)") + + in := strings.Repeat(`a`, maxAllowedPacketSize+1) + var out string + var rows *sql.Rows + + // Long text data + const nonDataQueryLen = 28 // length query w/o value + inS := in[:maxAllowedPacketSize-nonDataQueryLen] + dbt.mustExec("INSERT INTO test VALUES('" + inS + "')") + rows = dbt.mustQuery("SELECT value FROM test") + defer rows.Close() + if rows.Next() { + rows.Scan(&out) + if inS != out { + dbt.Fatalf("LONGBLOB: length in: %d, length out: %d", len(inS), len(out)) + } + if rows.Next() { + dbt.Error("LONGBLOB: unexpexted row") + } + } else { + dbt.Fatalf("LONGBLOB: no data") + } + + // Empty table + dbt.mustExec("TRUNCATE TABLE test") + + // Long binary data + dbt.mustExec("INSERT INTO test VALUES(?)", in) + rows = dbt.mustQuery("SELECT value FROM test WHERE 1=?", 1) + defer rows.Close() + if rows.Next() { + rows.Scan(&out) + if in != out { + dbt.Fatalf("LONGBLOB: length in: %d, length out: %d", len(in), len(out)) + } + if rows.Next() { + dbt.Error("LONGBLOB: unexpexted row") + } + } else { + if err = rows.Err(); err != nil { + dbt.Fatalf("LONGBLOB: no data (err: %s)", err.Error()) + } else { + dbt.Fatal("LONGBLOB: no data (err: <nil>)") + } + } + }) +} + +func TestLoadData(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + verifyLoadDataResult := func() { + rows, err := dbt.db.Query("SELECT * FROM test") + if err != nil { + dbt.Fatal(err.Error()) + } + + i := 0 + values := [4]string{ + "a string", + "a string containing a \t", + "a string containing a \n", + "a string containing both \t\n", + } + + var id int + var value string + + for rows.Next() { + i++ + err = rows.Scan(&id, &value) + if err != nil { + dbt.Fatal(err.Error()) + } + if i != id { + dbt.Fatalf("%d != %d", i, id) + } + if values[i-1] != value { + dbt.Fatalf("%q != %q", values[i-1], value) + } + } + err = rows.Err() + if err != nil { + dbt.Fatal(err.Error()) + } + + if i != 4 { + dbt.Fatalf("rows count mismatch. Got %d, want 4", i) + } + } + + dbt.db.Exec("DROP TABLE IF EXISTS test") + dbt.mustExec("CREATE TABLE test (id INT NOT NULL PRIMARY KEY, value TEXT NOT NULL) CHARACTER SET utf8") + + // Local File + file, err := ioutil.TempFile("", "gotest") + defer os.Remove(file.Name()) + if err != nil { + dbt.Fatal(err) + } + RegisterLocalFile(file.Name()) + + // Try first with empty file + dbt.mustExec(fmt.Sprintf("LOAD DATA LOCAL INFILE %q INTO TABLE test", file.Name())) + var count int + err = dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&count) + if err != nil { + dbt.Fatal(err.Error()) + } + if count != 0 { + dbt.Fatalf("unexpected row count: got %d, want 0", count) + } + + // Then fille File with data and try to load it + file.WriteString("1\ta string\n2\ta string containing a \\t\n3\ta string containing a \\n\n4\ta string containing both \\t\\n\n") + file.Close() + dbt.mustExec(fmt.Sprintf("LOAD DATA LOCAL INFILE %q INTO TABLE test", file.Name())) + verifyLoadDataResult() + + // Try with non-existing file + _, err = dbt.db.Exec("LOAD DATA LOCAL INFILE 'doesnotexist' INTO TABLE test") + if err == nil { + dbt.Fatal("load non-existent file didn't fail") + } else if err.Error() != "local file 'doesnotexist' is not registered" { + dbt.Fatal(err.Error()) + } + + // Empty table + dbt.mustExec("TRUNCATE TABLE test") + + // Reader + RegisterReaderHandler("test", func() io.Reader { + file, err = os.Open(file.Name()) + if err != nil { + dbt.Fatal(err) + } + return file + }) + dbt.mustExec("LOAD DATA LOCAL INFILE 'Reader::test' INTO TABLE test") + verifyLoadDataResult() + // negative test + _, err = dbt.db.Exec("LOAD DATA LOCAL INFILE 'Reader::doesnotexist' INTO TABLE test") + if err == nil { + dbt.Fatal("load non-existent Reader didn't fail") + } else if err.Error() != "Reader 'doesnotexist' is not registered" { + dbt.Fatal(err.Error()) + } + }) +} + +func TestFoundRows(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)") + dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)") + + res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0") + count, err := res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 2 { + dbt.Fatalf("Expected 2 affected rows, got %d", count) + } + res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1") + count, err = res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 2 { + dbt.Fatalf("Expected 2 affected rows, got %d", count) + } + }) + runTests(t, dsn+"&clientFoundRows=true", func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)") + dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)") + + res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0") + count, err := res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 2 { + dbt.Fatalf("Expected 2 matched rows, got %d", count) + } + res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1") + count, err = res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 3 { + dbt.Fatalf("Expected 3 matched rows, got %d", count) + } + }) +} + +func TestTLS(t *testing.T) { + tlsTestReq := func(dbt *DBTest) { + if err := dbt.db.Ping(); err != nil { + if err == ErrNoTLS { + dbt.Skip("server does not support TLS") + } else { + dbt.Fatalf("error on Ping: %s", err.Error()) + } + } + + rows := dbt.mustQuery("SHOW STATUS LIKE 'Ssl_cipher'") + defer rows.Close() + + var variable, value *sql.RawBytes + for rows.Next() { + if err := rows.Scan(&variable, &value); err != nil { + dbt.Fatal(err.Error()) + } + + if (*value == nil) || (len(*value) == 0) { + dbt.Fatalf("no Cipher") + } else { + dbt.Logf("Cipher: %s", *value) + } + } + } + tlsTestOpt := func(dbt *DBTest) { + if err := dbt.db.Ping(); err != nil { + dbt.Fatalf("error on Ping: %s", err.Error()) + } + } + + runTests(t, dsn+"&tls=preferred", tlsTestOpt) + runTests(t, dsn+"&tls=skip-verify", tlsTestReq) + + // Verify that registering / using a custom cfg works + RegisterTLSConfig("custom-skip-verify", &tls.Config{ + InsecureSkipVerify: true, + }) + runTests(t, dsn+"&tls=custom-skip-verify", tlsTestReq) +} + +func TestReuseClosedConnection(t *testing.T) { + // this test does not use sql.database, it uses the driver directly + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + md := &MySQLDriver{} + conn, err := md.Open(dsn) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + stmt, err := conn.Prepare("DO 1") + if err != nil { + t.Fatalf("error preparing statement: %s", err.Error()) + } + _, err = stmt.Exec(nil) + if err != nil { + t.Fatalf("error executing statement: %s", err.Error()) + } + err = conn.Close() + if err != nil { + t.Fatalf("error closing connection: %s", err.Error()) + } + + defer func() { + if err := recover(); err != nil { + t.Errorf("panic after reusing a closed connection: %v", err) + } + }() + _, err = stmt.Exec(nil) + if err != nil && err != driver.ErrBadConn { + t.Errorf("unexpected error '%s', expected '%s'", + err.Error(), driver.ErrBadConn.Error()) + } +} + +func TestCharset(t *testing.T) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + mustSetCharset := func(charsetParam, expected string) { + runTests(t, dsn+"&"+charsetParam, func(dbt *DBTest) { + rows := dbt.mustQuery("SELECT @@character_set_connection") + defer rows.Close() + + if !rows.Next() { + dbt.Fatalf("error getting connection charset: %s", rows.Err()) + } + + var got string + rows.Scan(&got) + + if got != expected { + dbt.Fatalf("expected connection charset %s but got %s", expected, got) + } + }) + } + + // non utf8 test + mustSetCharset("charset=ascii", "ascii") + + // when the first charset is invalid, use the second + mustSetCharset("charset=none,utf8mb4", "utf8mb4") + + // when the first charset is valid, use it + mustSetCharset("charset=ascii,utf8mb4", "ascii") + mustSetCharset("charset=utf8mb4,ascii", "utf8mb4") +} + +func TestFailingCharset(t *testing.T) { + runTests(t, dsn+"&charset=none", func(dbt *DBTest) { + // run query to really establish connection... + _, err := dbt.db.Exec("SELECT 1") + if err == nil { + dbt.db.Close() + t.Fatalf("connection must not succeed without a valid charset") + } + }) +} + +func TestCollation(t *testing.T) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + defaultCollation := "utf8mb4_general_ci" + testCollations := []string{ + "", // do not set + defaultCollation, // driver default + "latin1_general_ci", + "binary", + "utf8mb4_unicode_ci", + "cp1257_bin", + } + + for _, collation := range testCollations { + var expected, tdsn string + if collation != "" { + tdsn = dsn + "&collation=" + collation + expected = collation + } else { + tdsn = dsn + expected = defaultCollation + } + + runTests(t, tdsn, func(dbt *DBTest) { + var got string + if err := dbt.db.QueryRow("SELECT @@collation_connection").Scan(&got); err != nil { + dbt.Fatal(err) + } + + if got != expected { + dbt.Fatalf("expected connection collation %s but got %s", expected, got) + } + }) + } +} + +func TestColumnsWithAlias(t *testing.T) { + runTests(t, dsn+"&columnsWithAlias=true", func(dbt *DBTest) { + rows := dbt.mustQuery("SELECT 1 AS A") + defer rows.Close() + cols, _ := rows.Columns() + if len(cols) != 1 { + t.Fatalf("expected 1 column, got %d", len(cols)) + } + if cols[0] != "A" { + t.Fatalf("expected column name \"A\", got \"%s\"", cols[0]) + } + + rows = dbt.mustQuery("SELECT * FROM (SELECT 1 AS one) AS A") + defer rows.Close() + cols, _ = rows.Columns() + if len(cols) != 1 { + t.Fatalf("expected 1 column, got %d", len(cols)) + } + if cols[0] != "A.one" { + t.Fatalf("expected column name \"A.one\", got \"%s\"", cols[0]) + } + }) +} + +func TestRawBytesResultExceedsBuffer(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + // defaultBufSize from buffer.go + expected := strings.Repeat("abc", defaultBufSize) + + rows := dbt.mustQuery("SELECT '" + expected + "'") + defer rows.Close() + if !rows.Next() { + dbt.Error("expected result, got none") + } + var result sql.RawBytes + rows.Scan(&result) + if expected != string(result) { + dbt.Error("result did not match expected value") + } + }) +} + +func TestTimezoneConversion(t *testing.T) { + zones := []string{"UTC", "US/Central", "US/Pacific", "Local"} + + // Regression test for timezone handling + tzTest := func(dbt *DBTest) { + // Create table + dbt.mustExec("CREATE TABLE test (ts TIMESTAMP)") + + // Insert local time into database (should be converted) + usCentral, _ := time.LoadLocation("US/Central") + reftime := time.Date(2014, 05, 30, 18, 03, 17, 0, time.UTC).In(usCentral) + dbt.mustExec("INSERT INTO test VALUE (?)", reftime) + + // Retrieve time from DB + rows := dbt.mustQuery("SELECT ts FROM test") + defer rows.Close() + if !rows.Next() { + dbt.Fatal("did not get any rows out") + } + + var dbTime time.Time + err := rows.Scan(&dbTime) + if err != nil { + dbt.Fatal("Err", err) + } + + // Check that dates match + if reftime.Unix() != dbTime.Unix() { + dbt.Errorf("times do not match.\n") + dbt.Errorf(" Now(%v)=%v\n", usCentral, reftime) + dbt.Errorf(" Now(UTC)=%v\n", dbTime) + } + } + + for _, tz := range zones { + runTests(t, dsn+"&parseTime=true&loc="+url.QueryEscape(tz), tzTest) + } +} + +// Special cases + +func TestRowsClose(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + rows, err := dbt.db.Query("SELECT 1") + if err != nil { + dbt.Fatal(err) + } + + err = rows.Close() + if err != nil { + dbt.Fatal(err) + } + + if rows.Next() { + dbt.Fatal("unexpected row after rows.Close()") + } + + err = rows.Err() + if err != nil { + dbt.Fatal(err) + } + }) +} + +// dangling statements +// http://code.google.com/p/go/issues/detail?id=3865 +func TestCloseStmtBeforeRows(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + stmt, err := dbt.db.Prepare("SELECT 1") + if err != nil { + dbt.Fatal(err) + } + + rows, err := stmt.Query() + if err != nil { + stmt.Close() + dbt.Fatal(err) + } + defer rows.Close() + + err = stmt.Close() + if err != nil { + dbt.Fatal(err) + } + + if !rows.Next() { + dbt.Fatal("getting row failed") + } else { + err = rows.Err() + if err != nil { + dbt.Fatal(err) + } + + var out bool + err = rows.Scan(&out) + if err != nil { + dbt.Fatalf("error on rows.Scan(): %s", err.Error()) + } + if out != true { + dbt.Errorf("true != %t", out) + } + } + }) +} + +// It is valid to have multiple Rows for the same Stmt +// http://code.google.com/p/go/issues/detail?id=3734 +func TestStmtMultiRows(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + stmt, err := dbt.db.Prepare("SELECT 1 UNION SELECT 0") + if err != nil { + dbt.Fatal(err) + } + + rows1, err := stmt.Query() + if err != nil { + stmt.Close() + dbt.Fatal(err) + } + defer rows1.Close() + + rows2, err := stmt.Query() + if err != nil { + stmt.Close() + dbt.Fatal(err) + } + defer rows2.Close() + + var out bool + + // 1 + if !rows1.Next() { + dbt.Fatal("first rows1.Next failed") + } else { + err = rows1.Err() + if err != nil { + dbt.Fatal(err) + } + + err = rows1.Scan(&out) + if err != nil { + dbt.Fatalf("error on rows.Scan(): %s", err.Error()) + } + if out != true { + dbt.Errorf("true != %t", out) + } + } + + if !rows2.Next() { + dbt.Fatal("first rows2.Next failed") + } else { + err = rows2.Err() + if err != nil { + dbt.Fatal(err) + } + + err = rows2.Scan(&out) + if err != nil { + dbt.Fatalf("error on rows.Scan(): %s", err.Error()) + } + if out != true { + dbt.Errorf("true != %t", out) + } + } + + // 2 + if !rows1.Next() { + dbt.Fatal("second rows1.Next failed") + } else { + err = rows1.Err() + if err != nil { + dbt.Fatal(err) + } + + err = rows1.Scan(&out) + if err != nil { + dbt.Fatalf("error on rows.Scan(): %s", err.Error()) + } + if out != false { + dbt.Errorf("false != %t", out) + } + + if rows1.Next() { + dbt.Fatal("unexpected row on rows1") + } + err = rows1.Close() + if err != nil { + dbt.Fatal(err) + } + } + + if !rows2.Next() { + dbt.Fatal("second rows2.Next failed") + } else { + err = rows2.Err() + if err != nil { + dbt.Fatal(err) + } + + err = rows2.Scan(&out) + if err != nil { + dbt.Fatalf("error on rows.Scan(): %s", err.Error()) + } + if out != false { + dbt.Errorf("false != %t", out) + } + + if rows2.Next() { + dbt.Fatal("unexpected row on rows2") + } + err = rows2.Close() + if err != nil { + dbt.Fatal(err) + } + } + }) +} + +// Regression test for +// * more than 32 NULL parameters (issue 209) +// * more parameters than fit into the buffer (issue 201) +// * parameters * 64 > max_allowed_packet (issue 734) +func TestPreparedManyCols(t *testing.T) { + numParams := 65535 + runTests(t, dsn, func(dbt *DBTest) { + query := "SELECT ?" + strings.Repeat(",?", numParams-1) + stmt, err := dbt.db.Prepare(query) + if err != nil { + dbt.Fatal(err) + } + defer stmt.Close() + + // create more parameters than fit into the buffer + // which will take nil-values + params := make([]interface{}, numParams) + rows, err := stmt.Query(params...) + if err != nil { + dbt.Fatal(err) + } + rows.Close() + + // Create 0byte string which we can't send via STMT_LONG_DATA. + for i := 0; i < numParams; i++ { + params[i] = "" + } + rows, err = stmt.Query(params...) + if err != nil { + dbt.Fatal(err) + } + rows.Close() + }) +} + +func TestConcurrent(t *testing.T) { + if enabled, _ := readBool(os.Getenv("MYSQL_TEST_CONCURRENT")); !enabled { + t.Skip("MYSQL_TEST_CONCURRENT env var not set") + } + + runTests(t, dsn, func(dbt *DBTest) { + var version string + if err := dbt.db.QueryRow("SELECT @@version").Scan(&version); err != nil { + dbt.Fatalf("%s", err.Error()) + } + if strings.Contains(strings.ToLower(version), "mariadb") { + t.Skip(`TODO: "fix commands out of sync. Did you run multiple statements at once?" on MariaDB`) + } + + var max int + err := dbt.db.QueryRow("SELECT @@max_connections").Scan(&max) + if err != nil { + dbt.Fatalf("%s", err.Error()) + } + dbt.Logf("testing up to %d concurrent connections \r\n", max) + + var remaining, succeeded int32 = int32(max), 0 + + var wg sync.WaitGroup + wg.Add(max) + + var fatalError string + var once sync.Once + fatalf := func(s string, vals ...interface{}) { + once.Do(func() { + fatalError = fmt.Sprintf(s, vals...) + }) + } + + for i := 0; i < max; i++ { + go func(id int) { + defer wg.Done() + + tx, err := dbt.db.Begin() + atomic.AddInt32(&remaining, -1) + + if err != nil { + if err.Error() != "Error 1040: Too many connections" { + fatalf("error on conn %d: %s", id, err.Error()) + } + return + } + + // keep the connection busy until all connections are open + for remaining > 0 { + if _, err = tx.Exec("DO 1"); err != nil { + fatalf("error on conn %d: %s", id, err.Error()) + return + } + } + + if err = tx.Commit(); err != nil { + fatalf("error on conn %d: %s", id, err.Error()) + return + } + + // everything went fine with this connection + atomic.AddInt32(&succeeded, 1) + }(i) + } + + // wait until all conections are open + wg.Wait() + + if fatalError != "" { + dbt.Fatal(fatalError) + } + + dbt.Logf("reached %d concurrent connections\r\n", succeeded) + }) +} + +func testDialError(t *testing.T, dialErr error, expectErr error) { + RegisterDialContext("mydial", func(ctx context.Context, addr string) (net.Conn, error) { + return nil, dialErr + }) + + db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname)) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db.Close() + + _, err = db.Exec("DO 1") + if err != expectErr { + t.Fatalf("was expecting %s. Got: %s", dialErr, err) + } +} + +func TestDialUnknownError(t *testing.T) { + testErr := fmt.Errorf("test") + testDialError(t, testErr, testErr) +} + +func TestDialNonRetryableNetErr(t *testing.T) { + testErr := netErrorMock{} + testDialError(t, testErr, testErr) +} + +func TestDialTemporaryNetErr(t *testing.T) { + testErr := netErrorMock{temporary: true} + testDialError(t, testErr, testErr) +} + +// Tests custom dial functions +func TestCustomDial(t *testing.T) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + // our custom dial function which justs wraps net.Dial here + RegisterDialContext("mydial", func(ctx context.Context, addr string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, prot, addr) + }) + + db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname)) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db.Close() + + if _, err = db.Exec("DO 1"); err != nil { + t.Fatalf("connection failed: %s", err.Error()) + } +} + +func TestSQLInjection(t *testing.T) { + createTest := func(arg string) func(dbt *DBTest) { + return func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + dbt.mustExec("INSERT INTO test VALUES (?)", 1) + + var v int + // NULL can't be equal to anything, the idea here is to inject query so it returns row + // This test verifies that escapeQuotes and escapeBackslash are working properly + err := dbt.db.QueryRow("SELECT v FROM test WHERE NULL = ?", arg).Scan(&v) + if err == sql.ErrNoRows { + return // success, sql injection failed + } else if err == nil { + dbt.Errorf("sql injection successful with arg: %s", arg) + } else { + dbt.Errorf("error running query with arg: %s; err: %s", arg, err.Error()) + } + } + } + + dsns := []string{ + dsn, + dsn + "&sql_mode='NO_BACKSLASH_ESCAPES'", + } + for _, testdsn := range dsns { + runTests(t, testdsn, createTest("1 OR 1=1")) + runTests(t, testdsn, createTest("' OR '1'='1")) + } +} + +// Test if inserted data is correctly retrieved after being escaped +func TestInsertRetrieveEscapedData(t *testing.T) { + testData := func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v VARCHAR(255))") + + // All sequences that are escaped by escapeQuotes and escapeBackslash + v := "foo \x00\n\r\x1a\"'\\" + dbt.mustExec("INSERT INTO test VALUES (?)", v) + + var out string + err := dbt.db.QueryRow("SELECT v FROM test").Scan(&out) + if err != nil { + dbt.Fatalf("%s", err.Error()) + } + + if out != v { + dbt.Errorf("%q != %q", out, v) + } + } + + dsns := []string{ + dsn, + dsn + "&sql_mode='NO_BACKSLASH_ESCAPES'", + } + for _, testdsn := range dsns { + runTests(t, testdsn, testData) + } +} + +func TestUnixSocketAuthFail(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + // Save the current logger so we can restore it. + oldLogger := errLog + + // Set a new logger so we can capture its output. + buffer := bytes.NewBuffer(make([]byte, 0, 64)) + newLogger := log.New(buffer, "prefix: ", 0) + SetLogger(newLogger) + + // Restore the logger. + defer SetLogger(oldLogger) + + // Make a new DSN that uses the MySQL socket file and a bad password, which + // we can make by simply appending any character to the real password. + badPass := pass + "x" + socket := "" + if prot == "unix" { + socket = addr + } else { + // Get socket file from MySQL. + err := dbt.db.QueryRow("SELECT @@socket").Scan(&socket) + if err != nil { + t.Fatalf("error on SELECT @@socket: %s", err.Error()) + } + } + t.Logf("socket: %s", socket) + badDSN := fmt.Sprintf("%s:%s@unix(%s)/%s?timeout=30s", user, badPass, socket, dbname) + db, err := sql.Open("mysql", badDSN) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db.Close() + + // Connect to MySQL for real. This will cause an auth failure. + err = db.Ping() + if err == nil { + t.Error("expected Ping() to return an error") + } + + // The driver should not log anything. + if actual := buffer.String(); actual != "" { + t.Errorf("expected no output, got %q", actual) + } + }) +} + +// See Issue #422 +func TestInterruptBySignal(t *testing.T) { + runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) { + dbt.mustExec(` + DROP PROCEDURE IF EXISTS test_signal; + CREATE PROCEDURE test_signal(ret INT) + BEGIN + SELECT ret; + SIGNAL SQLSTATE + '45001' + SET + MESSAGE_TEXT = "an error", + MYSQL_ERRNO = 45001; + END + `) + defer dbt.mustExec("DROP PROCEDURE test_signal") + + var val int + + // text protocol + rows, err := dbt.db.Query("CALL test_signal(42)") + if err != nil { + dbt.Fatalf("error on text query: %s", err.Error()) + } + for rows.Next() { + if err := rows.Scan(&val); err != nil { + dbt.Error(err) + } else if val != 42 { + dbt.Errorf("expected val to be 42") + } + } + rows.Close() + + // binary protocol + rows, err = dbt.db.Query("CALL test_signal(?)", 42) + if err != nil { + dbt.Fatalf("error on binary query: %s", err.Error()) + } + for rows.Next() { + if err := rows.Scan(&val); err != nil { + dbt.Error(err) + } else if val != 42 { + dbt.Errorf("expected val to be 42") + } + } + rows.Close() + }) +} + +func TestColumnsReusesSlice(t *testing.T) { + rows := mysqlRows{ + rs: resultSet{ + columns: []mysqlField{ + { + tableName: "test", + name: "A", + }, + { + tableName: "test", + name: "B", + }, + }, + }, + } + + allocs := testing.AllocsPerRun(1, func() { + cols := rows.Columns() + + if len(cols) != 2 { + t.Fatalf("expected 2 columns, got %d", len(cols)) + } + }) + + if allocs != 0 { + t.Fatalf("expected 0 allocations, got %d", int(allocs)) + } + + if rows.rs.columnNames == nil { + t.Fatalf("expected columnNames to be set, got nil") + } +} + +func TestRejectReadOnly(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + // Create Table + dbt.mustExec("CREATE TABLE test (value BOOL)") + // Set the session to read-only. We didn't set the `rejectReadOnly` + // option, so any writes after this should fail. + _, err := dbt.db.Exec("SET SESSION TRANSACTION READ ONLY") + // Error 1193: Unknown system variable 'TRANSACTION' => skip test, + // MySQL server version is too old + maybeSkip(t, err, 1193) + if _, err := dbt.db.Exec("DROP TABLE test"); err == nil { + t.Fatalf("writing to DB in read-only session without " + + "rejectReadOnly did not error") + } + // Set the session back to read-write so runTests() can properly clean + // up the table `test`. + dbt.mustExec("SET SESSION TRANSACTION READ WRITE") + }) + + // Enable the `rejectReadOnly` option. + runTests(t, dsn+"&rejectReadOnly=true", func(dbt *DBTest) { + // Create Table + dbt.mustExec("CREATE TABLE test (value BOOL)") + // Set the session to read only. Any writes after this should error on + // a driver.ErrBadConn, and cause `database/sql` to initiate a new + // connection. + dbt.mustExec("SET SESSION TRANSACTION READ ONLY") + // This would error, but `database/sql` should automatically retry on a + // new connection which is not read-only, and eventually succeed. + dbt.mustExec("DROP TABLE test") + }) +} + +func TestPing(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + if err := dbt.db.Ping(); err != nil { + dbt.fail("Ping", "Ping", err) + } + }) +} + +// See Issue #799 +func TestEmptyPassword(t *testing.T) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + dsn := fmt.Sprintf("%s:%s@%s/%s?timeout=30s", user, "", netAddr, dbname) + db, err := sql.Open("mysql", dsn) + if err == nil { + defer db.Close() + err = db.Ping() + } + + if pass == "" { + if err != nil { + t.Fatal(err.Error()) + } + } else { + if err == nil { + t.Fatal("expected authentication error") + } + if !strings.HasPrefix(err.Error(), "Error 1045") { + t.Fatal(err.Error()) + } + } +} + +// static interface implementation checks of mysqlConn +var ( + _ driver.ConnBeginTx = &mysqlConn{} + _ driver.ConnPrepareContext = &mysqlConn{} + _ driver.ExecerContext = &mysqlConn{} + _ driver.Pinger = &mysqlConn{} + _ driver.QueryerContext = &mysqlConn{} +) + +// static interface implementation checks of mysqlStmt +var ( + _ driver.StmtExecContext = &mysqlStmt{} + _ driver.StmtQueryContext = &mysqlStmt{} +) + +// Ensure that all the driver interfaces are implemented +var ( + // _ driver.RowsColumnTypeLength = &binaryRows{} + // _ driver.RowsColumnTypeLength = &textRows{} + _ driver.RowsColumnTypeDatabaseTypeName = &binaryRows{} + _ driver.RowsColumnTypeDatabaseTypeName = &textRows{} + _ driver.RowsColumnTypeNullable = &binaryRows{} + _ driver.RowsColumnTypeNullable = &textRows{} + _ driver.RowsColumnTypePrecisionScale = &binaryRows{} + _ driver.RowsColumnTypePrecisionScale = &textRows{} + _ driver.RowsColumnTypeScanType = &binaryRows{} + _ driver.RowsColumnTypeScanType = &textRows{} + _ driver.RowsNextResultSet = &binaryRows{} + _ driver.RowsNextResultSet = &textRows{} +) + +func TestMultiResultSet(t *testing.T) { + type result struct { + values [][]int + columns []string + } + + // checkRows is a helper test function to validate rows containing 3 result + // sets with specific values and columns. The basic query would look like this: + // + // SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4; + // SELECT 0 UNION SELECT 1; + // SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6; + // + // to distinguish test cases the first string argument is put in front of + // every error or fatal message. + checkRows := func(desc string, rows *sql.Rows, dbt *DBTest) { + expected := []result{ + { + values: [][]int{{1, 2}, {3, 4}}, + columns: []string{"col1", "col2"}, + }, + { + values: [][]int{{1, 2, 3}, {4, 5, 6}}, + columns: []string{"col1", "col2", "col3"}, + }, + } + + var res1 result + for rows.Next() { + var res [2]int + if err := rows.Scan(&res[0], &res[1]); err != nil { + dbt.Fatal(err) + } + res1.values = append(res1.values, res[:]) + } + + cols, err := rows.Columns() + if err != nil { + dbt.Fatal(desc, err) + } + res1.columns = cols + + if !reflect.DeepEqual(expected[0], res1) { + dbt.Error(desc, "want =", expected[0], "got =", res1) + } + + if !rows.NextResultSet() { + dbt.Fatal(desc, "expected next result set") + } + + // ignoring one result set + + if !rows.NextResultSet() { + dbt.Fatal(desc, "expected next result set") + } + + var res2 result + cols, err = rows.Columns() + if err != nil { + dbt.Fatal(desc, err) + } + res2.columns = cols + + for rows.Next() { + var res [3]int + if err := rows.Scan(&res[0], &res[1], &res[2]); err != nil { + dbt.Fatal(desc, err) + } + res2.values = append(res2.values, res[:]) + } + + if !reflect.DeepEqual(expected[1], res2) { + dbt.Error(desc, "want =", expected[1], "got =", res2) + } + + if rows.NextResultSet() { + dbt.Error(desc, "unexpected next result set") + } + + if err := rows.Err(); err != nil { + dbt.Error(desc, err) + } + } + + runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) { + rows := dbt.mustQuery(`DO 1; + SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4; + DO 1; + SELECT 0 UNION SELECT 1; + SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6;`) + defer rows.Close() + checkRows("query: ", rows, dbt) + }) + + runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) { + queries := []string{ + ` + DROP PROCEDURE IF EXISTS test_mrss; + CREATE PROCEDURE test_mrss() + BEGIN + DO 1; + SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4; + DO 1; + SELECT 0 UNION SELECT 1; + SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6; + END + `, + ` + DROP PROCEDURE IF EXISTS test_mrss; + CREATE PROCEDURE test_mrss() + BEGIN + SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4; + SELECT 0 UNION SELECT 1; + SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6; + END + `, + } + + defer dbt.mustExec("DROP PROCEDURE IF EXISTS test_mrss") + + for i, query := range queries { + dbt.mustExec(query) + + stmt, err := dbt.db.Prepare("CALL test_mrss()") + if err != nil { + dbt.Fatalf("%v (i=%d)", err, i) + } + defer stmt.Close() + + for j := 0; j < 2; j++ { + rows, err := stmt.Query() + if err != nil { + dbt.Fatalf("%v (i=%d) (j=%d)", err, i, j) + } + checkRows(fmt.Sprintf("prepared stmt query (i=%d) (j=%d): ", i, j), rows, dbt) + } + } + }) +} + +func TestMultiResultSetNoSelect(t *testing.T) { + runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) { + rows := dbt.mustQuery("DO 1; DO 2;") + defer rows.Close() + + if rows.Next() { + dbt.Error("unexpected row") + } + + if rows.NextResultSet() { + dbt.Error("unexpected next result set") + } + + if err := rows.Err(); err != nil { + dbt.Error("expected nil; got ", err) + } + }) +} + +// tests if rows are set in a proper state if some results were ignored before +// calling rows.NextResultSet. +func TestSkipResults(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + rows := dbt.mustQuery("SELECT 1, 2") + defer rows.Close() + + if !rows.Next() { + dbt.Error("expected row") + } + + if rows.NextResultSet() { + dbt.Error("unexpected next result set") + } + + if err := rows.Err(); err != nil { + dbt.Error("expected nil; got ", err) + } + }) +} + +func TestPingContext(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if err := dbt.db.PingContext(ctx); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + }) +} + +func TestContextCancelExec(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + ctx, cancel := context.WithCancel(context.Background()) + + // Delay execution for just a bit until db.ExecContext has begun. + defer time.AfterFunc(250*time.Millisecond, cancel).Stop() + + // This query will be canceled. + startTime := time.Now() + if _, err := dbt.db.ExecContext(ctx, "INSERT INTO test VALUES (SLEEP(1))"); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + if d := time.Since(startTime); d > 500*time.Millisecond { + dbt.Errorf("too long execution time: %s", d) + } + + // Wait for the INSERT query to be done. + time.Sleep(time.Second) + + // Check how many times the query is executed. + var v int + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + dbt.Fatalf("%s", err.Error()) + } + if v != 1 { // TODO: need to kill the query, and v should be 0. + dbt.Skipf("[WARN] expected val to be 1, got %d", v) + } + + // Context is already canceled, so error should come before execution. + if _, err := dbt.db.ExecContext(ctx, "INSERT INTO test VALUES (1)"); err == nil { + dbt.Error("expected error") + } else if err.Error() != "context canceled" { + dbt.Fatalf("unexpected error: %s", err) + } + + // The second insert query will fail, so the table has no changes. + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + dbt.Fatalf("%s", err.Error()) + } + if v != 1 { + dbt.Skipf("[WARN] expected val to be 1, got %d", v) + } + }) +} + +func TestContextCancelQuery(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + ctx, cancel := context.WithCancel(context.Background()) + + // Delay execution for just a bit until db.ExecContext has begun. + defer time.AfterFunc(250*time.Millisecond, cancel).Stop() + + // This query will be canceled. + startTime := time.Now() + if _, err := dbt.db.QueryContext(ctx, "INSERT INTO test VALUES (SLEEP(1))"); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + if d := time.Since(startTime); d > 500*time.Millisecond { + dbt.Errorf("too long execution time: %s", d) + } + + // Wait for the INSERT query to be done. + time.Sleep(time.Second) + + // Check how many times the query is executed. + var v int + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + dbt.Fatalf("%s", err.Error()) + } + if v != 1 { // TODO: need to kill the query, and v should be 0. + dbt.Skipf("[WARN] expected val to be 1, got %d", v) + } + + // Context is already canceled, so error should come before execution. + if _, err := dbt.db.QueryContext(ctx, "INSERT INTO test VALUES (1)"); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + + // The second insert query will fail, so the table has no changes. + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + dbt.Fatalf("%s", err.Error()) + } + if v != 1 { + dbt.Skipf("[WARN] expected val to be 1, got %d", v) + } + }) +} + +func TestContextCancelQueryRow(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + dbt.mustExec("INSERT INTO test VALUES (1), (2), (3)") + ctx, cancel := context.WithCancel(context.Background()) + + rows, err := dbt.db.QueryContext(ctx, "SELECT v FROM test") + if err != nil { + dbt.Fatalf("%s", err.Error()) + } + + // the first row will be succeed. + var v int + if !rows.Next() { + dbt.Fatalf("unexpected end") + } + if err := rows.Scan(&v); err != nil { + dbt.Fatalf("%s", err.Error()) + } + + cancel() + // make sure the driver receives the cancel request. + time.Sleep(100 * time.Millisecond) + + if rows.Next() { + dbt.Errorf("expected end, but not") + } + if err := rows.Err(); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + }) +} + +func TestContextCancelPrepare(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if _, err := dbt.db.PrepareContext(ctx, "SELECT 1"); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + }) +} + +func TestContextCancelStmtExec(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + ctx, cancel := context.WithCancel(context.Background()) + stmt, err := dbt.db.PrepareContext(ctx, "INSERT INTO test VALUES (SLEEP(1))") + if err != nil { + dbt.Fatalf("unexpected error: %v", err) + } + + // Delay execution for just a bit until db.ExecContext has begun. + defer time.AfterFunc(250*time.Millisecond, cancel).Stop() + + // This query will be canceled. + startTime := time.Now() + if _, err := stmt.ExecContext(ctx); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + if d := time.Since(startTime); d > 500*time.Millisecond { + dbt.Errorf("too long execution time: %s", d) + } + + // Wait for the INSERT query to be done. + time.Sleep(time.Second) + + // Check how many times the query is executed. + var v int + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + dbt.Fatalf("%s", err.Error()) + } + if v != 1 { // TODO: need to kill the query, and v should be 0. + dbt.Skipf("[WARN] expected val to be 1, got %d", v) + } + }) +} + +func TestContextCancelStmtQuery(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + ctx, cancel := context.WithCancel(context.Background()) + stmt, err := dbt.db.PrepareContext(ctx, "INSERT INTO test VALUES (SLEEP(1))") + if err != nil { + dbt.Fatalf("unexpected error: %v", err) + } + + // Delay execution for just a bit until db.ExecContext has begun. + defer time.AfterFunc(250*time.Millisecond, cancel).Stop() + + // This query will be canceled. + startTime := time.Now() + if _, err := stmt.QueryContext(ctx); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + if d := time.Since(startTime); d > 500*time.Millisecond { + dbt.Errorf("too long execution time: %s", d) + } + + // Wait for the INSERT query has done. + time.Sleep(time.Second) + + // Check how many times the query is executed. + var v int + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + dbt.Fatalf("%s", err.Error()) + } + if v != 1 { // TODO: need to kill the query, and v should be 0. + dbt.Skipf("[WARN] expected val to be 1, got %d", v) + } + }) +} + +func TestContextCancelBegin(t *testing.T) { + if runtime.GOOS == "windows" || runtime.GOOS == "darwin" { + t.Skip(`FIXME: it sometime fails with "expected driver.ErrBadConn, got sql: connection is already closed" on windows and macOS`) + } + + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + ctx, cancel := context.WithCancel(context.Background()) + conn, err := dbt.db.Conn(ctx) + if err != nil { + dbt.Fatal(err) + } + defer conn.Close() + tx, err := conn.BeginTx(ctx, nil) + if err != nil { + dbt.Fatal(err) + } + + // Delay execution for just a bit until db.ExecContext has begun. + defer time.AfterFunc(100*time.Millisecond, cancel).Stop() + + // This query will be canceled. + startTime := time.Now() + if _, err := tx.ExecContext(ctx, "INSERT INTO test VALUES (SLEEP(1))"); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + if d := time.Since(startTime); d > 500*time.Millisecond { + dbt.Errorf("too long execution time: %s", d) + } + + // Transaction is canceled, so expect an error. + switch err := tx.Commit(); err { + case sql.ErrTxDone: + // because the transaction has already been rollbacked. + // the database/sql package watches ctx + // and rollbacks when ctx is canceled. + case context.Canceled: + // the database/sql package rollbacks on another goroutine, + // so the transaction may not be rollbacked depending on goroutine scheduling. + default: + dbt.Errorf("expected sql.ErrTxDone or context.Canceled, got %v", err) + } + + // The connection is now in an inoperable state - so performing other + // operations should fail with ErrBadConn + // Important to exercise isolation level too - it runs SET TRANSACTION ISOLATION + // LEVEL XXX first, which needs to return ErrBadConn if the connection's context + // is cancelled + _, err = conn.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelReadCommitted}) + if err != driver.ErrBadConn { + dbt.Errorf("expected driver.ErrBadConn, got %v", err) + } + + // cannot begin a transaction (on a different conn) with a canceled context + if _, err := dbt.db.BeginTx(ctx, nil); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + }) +} + +func TestContextBeginIsolationLevel(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + tx1, err := dbt.db.BeginTx(ctx, &sql.TxOptions{ + Isolation: sql.LevelRepeatableRead, + }) + if err != nil { + dbt.Fatal(err) + } + + tx2, err := dbt.db.BeginTx(ctx, &sql.TxOptions{ + Isolation: sql.LevelReadCommitted, + }) + if err != nil { + dbt.Fatal(err) + } + + _, err = tx1.ExecContext(ctx, "INSERT INTO test VALUES (1)") + if err != nil { + dbt.Fatal(err) + } + + var v int + row := tx2.QueryRowContext(ctx, "SELECT COUNT(*) FROM test") + if err := row.Scan(&v); err != nil { + dbt.Fatal(err) + } + // Because writer transaction wasn't commited yet, it should be available + if v != 0 { + dbt.Errorf("expected val to be 0, got %d", v) + } + + err = tx1.Commit() + if err != nil { + dbt.Fatal(err) + } + + row = tx2.QueryRowContext(ctx, "SELECT COUNT(*) FROM test") + if err := row.Scan(&v); err != nil { + dbt.Fatal(err) + } + // Data written by writer transaction is already commited, it should be selectable + if v != 1 { + dbt.Errorf("expected val to be 1, got %d", v) + } + tx2.Commit() + }) +} + +func TestContextBeginReadOnly(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + tx, err := dbt.db.BeginTx(ctx, &sql.TxOptions{ + ReadOnly: true, + }) + if _, ok := err.(*MySQLError); ok { + dbt.Skip("It seems that your MySQL does not support READ ONLY transactions") + return + } else if err != nil { + dbt.Fatal(err) + } + + // INSERT queries fail in a READ ONLY transaction. + _, err = tx.ExecContext(ctx, "INSERT INTO test VALUES (1)") + if _, ok := err.(*MySQLError); !ok { + dbt.Errorf("expected MySQLError, got %v", err) + } + + // SELECT queries can be executed. + var v int + row := tx.QueryRowContext(ctx, "SELECT COUNT(*) FROM test") + if err := row.Scan(&v); err != nil { + dbt.Fatal(err) + } + if v != 0 { + dbt.Errorf("expected val to be 0, got %d", v) + } + + if err := tx.Commit(); err != nil { + dbt.Fatal(err) + } + }) +} + +func TestRowsColumnTypes(t *testing.T) { + niNULL := sql.NullInt64{Int64: 0, Valid: false} + ni0 := sql.NullInt64{Int64: 0, Valid: true} + ni1 := sql.NullInt64{Int64: 1, Valid: true} + ni42 := sql.NullInt64{Int64: 42, Valid: true} + nfNULL := sql.NullFloat64{Float64: 0.0, Valid: false} + nf0 := sql.NullFloat64{Float64: 0.0, Valid: true} + nf1337 := sql.NullFloat64{Float64: 13.37, Valid: true} + nt0 := sql.NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 0, time.UTC), Valid: true} + nt1 := sql.NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 100000000, time.UTC), Valid: true} + nt2 := sql.NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 110000000, time.UTC), Valid: true} + nt6 := sql.NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 111111000, time.UTC), Valid: true} + nd1 := sql.NullTime{Time: time.Date(2006, 01, 02, 0, 0, 0, 0, time.UTC), Valid: true} + nd2 := sql.NullTime{Time: time.Date(2006, 03, 04, 0, 0, 0, 0, time.UTC), Valid: true} + ndNULL := sql.NullTime{Time: time.Time{}, Valid: false} + rbNULL := sql.RawBytes(nil) + rb0 := sql.RawBytes("0") + rb42 := sql.RawBytes("42") + rbTest := sql.RawBytes("Test") + rb0pad4 := sql.RawBytes("0\x00\x00\x00") // BINARY right-pads values with 0x00 + rbx0 := sql.RawBytes("\x00") + rbx42 := sql.RawBytes("\x42") + + var columns = []struct { + name string + fieldType string // type used when creating table schema + databaseTypeName string // actual type used by MySQL + scanType reflect.Type + nullable bool + precision int64 // 0 if not ok + scale int64 + valuesIn [3]string + valuesOut [3]interface{} + }{ + {"bit8null", "BIT(8)", "BIT", scanTypeRawBytes, true, 0, 0, [3]string{"0x0", "NULL", "0x42"}, [3]interface{}{rbx0, rbNULL, rbx42}}, + {"boolnull", "BOOL", "TINYINT", scanTypeNullInt, true, 0, 0, [3]string{"NULL", "true", "0"}, [3]interface{}{niNULL, ni1, ni0}}, + {"bool", "BOOL NOT NULL", "TINYINT", scanTypeInt8, false, 0, 0, [3]string{"1", "0", "FALSE"}, [3]interface{}{int8(1), int8(0), int8(0)}}, + {"intnull", "INTEGER", "INT", scanTypeNullInt, true, 0, 0, [3]string{"0", "NULL", "42"}, [3]interface{}{ni0, niNULL, ni42}}, + {"smallint", "SMALLINT NOT NULL", "SMALLINT", scanTypeInt16, false, 0, 0, [3]string{"0", "-32768", "32767"}, [3]interface{}{int16(0), int16(-32768), int16(32767)}}, + {"smallintnull", "SMALLINT", "SMALLINT", scanTypeNullInt, true, 0, 0, [3]string{"0", "NULL", "42"}, [3]interface{}{ni0, niNULL, ni42}}, + {"int3null", "INT(3)", "INT", scanTypeNullInt, true, 0, 0, [3]string{"0", "NULL", "42"}, [3]interface{}{ni0, niNULL, ni42}}, + {"int7", "INT(7) NOT NULL", "INT", scanTypeInt32, false, 0, 0, [3]string{"0", "-1337", "42"}, [3]interface{}{int32(0), int32(-1337), int32(42)}}, + {"mediumintnull", "MEDIUMINT", "MEDIUMINT", scanTypeNullInt, true, 0, 0, [3]string{"0", "42", "NULL"}, [3]interface{}{ni0, ni42, niNULL}}, + {"bigint", "BIGINT NOT NULL", "BIGINT", scanTypeInt64, false, 0, 0, [3]string{"0", "65535", "-42"}, [3]interface{}{int64(0), int64(65535), int64(-42)}}, + {"bigintnull", "BIGINT", "BIGINT", scanTypeNullInt, true, 0, 0, [3]string{"NULL", "1", "42"}, [3]interface{}{niNULL, ni1, ni42}}, + {"tinyuint", "TINYINT UNSIGNED NOT NULL", "UNSIGNED TINYINT", scanTypeUint8, false, 0, 0, [3]string{"0", "255", "42"}, [3]interface{}{uint8(0), uint8(255), uint8(42)}}, + {"smalluint", "SMALLINT UNSIGNED NOT NULL", "UNSIGNED SMALLINT", scanTypeUint16, false, 0, 0, [3]string{"0", "65535", "42"}, [3]interface{}{uint16(0), uint16(65535), uint16(42)}}, + {"biguint", "BIGINT UNSIGNED NOT NULL", "UNSIGNED BIGINT", scanTypeUint64, false, 0, 0, [3]string{"0", "65535", "42"}, [3]interface{}{uint64(0), uint64(65535), uint64(42)}}, + {"uint13", "INT(13) UNSIGNED NOT NULL", "UNSIGNED INT", scanTypeUint32, false, 0, 0, [3]string{"0", "1337", "42"}, [3]interface{}{uint32(0), uint32(1337), uint32(42)}}, + {"float", "FLOAT NOT NULL", "FLOAT", scanTypeFloat32, false, math.MaxInt64, math.MaxInt64, [3]string{"0", "42", "13.37"}, [3]interface{}{float32(0), float32(42), float32(13.37)}}, + {"floatnull", "FLOAT", "FLOAT", scanTypeNullFloat, true, math.MaxInt64, math.MaxInt64, [3]string{"0", "NULL", "13.37"}, [3]interface{}{nf0, nfNULL, nf1337}}, + {"float74null", "FLOAT(7,4)", "FLOAT", scanTypeNullFloat, true, math.MaxInt64, 4, [3]string{"0", "NULL", "13.37"}, [3]interface{}{nf0, nfNULL, nf1337}}, + {"double", "DOUBLE NOT NULL", "DOUBLE", scanTypeFloat64, false, math.MaxInt64, math.MaxInt64, [3]string{"0", "42", "13.37"}, [3]interface{}{float64(0), float64(42), float64(13.37)}}, + {"doublenull", "DOUBLE", "DOUBLE", scanTypeNullFloat, true, math.MaxInt64, math.MaxInt64, [3]string{"0", "NULL", "13.37"}, [3]interface{}{nf0, nfNULL, nf1337}}, + {"decimal1", "DECIMAL(10,6) NOT NULL", "DECIMAL", scanTypeRawBytes, false, 10, 6, [3]string{"0", "13.37", "1234.123456"}, [3]interface{}{sql.RawBytes("0.000000"), sql.RawBytes("13.370000"), sql.RawBytes("1234.123456")}}, + {"decimal1null", "DECIMAL(10,6)", "DECIMAL", scanTypeRawBytes, true, 10, 6, [3]string{"0", "NULL", "1234.123456"}, [3]interface{}{sql.RawBytes("0.000000"), rbNULL, sql.RawBytes("1234.123456")}}, + {"decimal2", "DECIMAL(8,4) NOT NULL", "DECIMAL", scanTypeRawBytes, false, 8, 4, [3]string{"0", "13.37", "1234.123456"}, [3]interface{}{sql.RawBytes("0.0000"), sql.RawBytes("13.3700"), sql.RawBytes("1234.1235")}}, + {"decimal2null", "DECIMAL(8,4)", "DECIMAL", scanTypeRawBytes, true, 8, 4, [3]string{"0", "NULL", "1234.123456"}, [3]interface{}{sql.RawBytes("0.0000"), rbNULL, sql.RawBytes("1234.1235")}}, + {"decimal3", "DECIMAL(5,0) NOT NULL", "DECIMAL", scanTypeRawBytes, false, 5, 0, [3]string{"0", "13.37", "-12345.123456"}, [3]interface{}{rb0, sql.RawBytes("13"), sql.RawBytes("-12345")}}, + {"decimal3null", "DECIMAL(5,0)", "DECIMAL", scanTypeRawBytes, true, 5, 0, [3]string{"0", "NULL", "-12345.123456"}, [3]interface{}{rb0, rbNULL, sql.RawBytes("-12345")}}, + {"char25null", "CHAR(25)", "CHAR", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, + {"varchar42", "VARCHAR(42) NOT NULL", "VARCHAR", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, + {"binary4null", "BINARY(4)", "BINARY", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0pad4, rbNULL, rbTest}}, + {"varbinary42", "VARBINARY(42) NOT NULL", "VARBINARY", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, + {"tinyblobnull", "TINYBLOB", "BLOB", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, + {"tinytextnull", "TINYTEXT", "TEXT", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, + {"blobnull", "BLOB", "BLOB", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, + {"textnull", "TEXT", "TEXT", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, + {"mediumblob", "MEDIUMBLOB NOT NULL", "BLOB", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, + {"mediumtext", "MEDIUMTEXT NOT NULL", "TEXT", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, + {"longblob", "LONGBLOB NOT NULL", "BLOB", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, + {"longtext", "LONGTEXT NOT NULL", "TEXT", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, + {"datetime", "DATETIME", "DATETIME", scanTypeNullTime, true, 0, 0, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt0, nt0}}, + {"datetime2", "DATETIME(2)", "DATETIME", scanTypeNullTime, true, 2, 2, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt1, nt2}}, + {"datetime6", "DATETIME(6)", "DATETIME", scanTypeNullTime, true, 6, 6, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt1, nt6}}, + {"date", "DATE", "DATE", scanTypeNullTime, true, 0, 0, [3]string{"'2006-01-02'", "NULL", "'2006-03-04'"}, [3]interface{}{nd1, ndNULL, nd2}}, + {"year", "YEAR NOT NULL", "YEAR", scanTypeUint16, false, 0, 0, [3]string{"2006", "2000", "1994"}, [3]interface{}{uint16(2006), uint16(2000), uint16(1994)}}, + } + + schema := "" + values1 := "" + values2 := "" + values3 := "" + for _, column := range columns { + schema += fmt.Sprintf("`%s` %s, ", column.name, column.fieldType) + values1 += column.valuesIn[0] + ", " + values2 += column.valuesIn[1] + ", " + values3 += column.valuesIn[2] + ", " + } + schema = schema[:len(schema)-2] + values1 = values1[:len(values1)-2] + values2 = values2[:len(values2)-2] + values3 = values3[:len(values3)-2] + + runTests(t, dsn+"&parseTime=true", func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (" + schema + ")") + dbt.mustExec("INSERT INTO test VALUES (" + values1 + "), (" + values2 + "), (" + values3 + ")") + + rows, err := dbt.db.Query("SELECT * FROM test") + if err != nil { + t.Fatalf("Query: %v", err) + } + + tt, err := rows.ColumnTypes() + if err != nil { + t.Fatalf("ColumnTypes: %v", err) + } + + if len(tt) != len(columns) { + t.Fatalf("unexpected number of columns: expected %d, got %d", len(columns), len(tt)) + } + + types := make([]reflect.Type, len(tt)) + for i, tp := range tt { + column := columns[i] + + // Name + name := tp.Name() + if name != column.name { + t.Errorf("column name mismatch %s != %s", name, column.name) + continue + } + + // DatabaseTypeName + databaseTypeName := tp.DatabaseTypeName() + if databaseTypeName != column.databaseTypeName { + t.Errorf("databasetypename name mismatch for column %q: %s != %s", name, databaseTypeName, column.databaseTypeName) + continue + } + + // ScanType + scanType := tp.ScanType() + if scanType != column.scanType { + if scanType == nil { + t.Errorf("scantype is null for column %q", name) + } else { + t.Errorf("scantype mismatch for column %q: %s != %s", name, scanType.Name(), column.scanType.Name()) + } + continue + } + types[i] = scanType + + // Nullable + nullable, ok := tp.Nullable() + if !ok { + t.Errorf("nullable not ok %q", name) + continue + } + if nullable != column.nullable { + t.Errorf("nullable mismatch for column %q: %t != %t", name, nullable, column.nullable) + } + + // Length + // length, ok := tp.Length() + // if length != column.length { + // if !ok { + // t.Errorf("length not ok for column %q", name) + // } else { + // t.Errorf("length mismatch for column %q: %d != %d", name, length, column.length) + // } + // continue + // } + + // Precision and Scale + precision, scale, ok := tp.DecimalSize() + if precision != column.precision { + if !ok { + t.Errorf("precision not ok for column %q", name) + } else { + t.Errorf("precision mismatch for column %q: %d != %d", name, precision, column.precision) + } + continue + } + if scale != column.scale { + if !ok { + t.Errorf("scale not ok for column %q", name) + } else { + t.Errorf("scale mismatch for column %q: %d != %d", name, scale, column.scale) + } + continue + } + } + + values := make([]interface{}, len(tt)) + for i := range values { + values[i] = reflect.New(types[i]).Interface() + } + i := 0 + for rows.Next() { + err = rows.Scan(values...) + if err != nil { + t.Fatalf("failed to scan values in %v", err) + } + for j := range values { + value := reflect.ValueOf(values[j]).Elem().Interface() + if !reflect.DeepEqual(value, columns[j].valuesOut[i]) { + if columns[j].scanType == scanTypeRawBytes { + t.Errorf("row %d, column %d: %v != %v", i, j, string(value.(sql.RawBytes)), string(columns[j].valuesOut[i].(sql.RawBytes))) + } else { + t.Errorf("row %d, column %d: %v != %v", i, j, value, columns[j].valuesOut[i]) + } + } + } + i++ + } + if i != 3 { + t.Errorf("expected 3 rows, got %d", i) + } + + if err := rows.Close(); err != nil { + t.Errorf("error closing rows: %s", err) + } + }) +} + +func TestValuerWithValueReceiverGivenNilValue(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (value VARCHAR(255))") + dbt.db.Exec("INSERT INTO test VALUES (?)", (*testValuer)(nil)) + // This test will panic on the INSERT if ConvertValue() does not check for typed nil before calling Value() + }) +} + +// TestRawBytesAreNotModified checks for a race condition that arises when a query context +// is canceled while a user is calling rows.Scan. This is a more stringent test than the one +// proposed in https://github.com/golang/go/issues/23519. Here we're explicitly using +// `sql.RawBytes` to check the contents of our internal buffers are not modified after an implicit +// call to `Rows.Close`, so Context cancellation should **not** invalidate the backing buffers. +func TestRawBytesAreNotModified(t *testing.T) { + const blob = "abcdefghijklmnop" + const contextRaceIterations = 20 + const blobSize = defaultBufSize * 3 / 4 // Second row overwrites first row. + const insertRows = 4 + + var sqlBlobs = [2]string{ + strings.Repeat(blob, blobSize/len(blob)), + strings.Repeat(strings.ToUpper(blob), blobSize/len(blob)), + } + + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (id int, value BLOB) CHARACTER SET utf8") + for i := 0; i < insertRows; i++ { + dbt.mustExec("INSERT INTO test VALUES (?, ?)", i+1, sqlBlobs[i&1]) + } + + for i := 0; i < contextRaceIterations; i++ { + func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + rows, err := dbt.db.QueryContext(ctx, `SELECT id, value FROM test`) + if err != nil { + t.Fatal(err) + } + + var b int + var raw sql.RawBytes + for rows.Next() { + if err := rows.Scan(&b, &raw); err != nil { + t.Fatal(err) + } + + before := string(raw) + // Ensure cancelling the query does not corrupt the contents of `raw` + cancel() + time.Sleep(time.Microsecond * 100) + after := string(raw) + + if before != after { + t.Fatalf("the backing storage for sql.RawBytes has been modified (i=%v)", i) + } + } + rows.Close() + }() + } + }) +} + +var _ driver.DriverContext = &MySQLDriver{} + +type dialCtxKey struct{} + +func TestConnectorObeysDialTimeouts(t *testing.T) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + RegisterDialContext("dialctxtest", func(ctx context.Context, addr string) (net.Conn, error) { + var d net.Dialer + if !ctx.Value(dialCtxKey{}).(bool) { + return nil, fmt.Errorf("test error: query context is not propagated to our dialer") + } + return d.DialContext(ctx, prot, addr) + }) + + db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@dialctxtest(%s)/%s?timeout=30s", user, pass, addr, dbname)) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db.Close() + + ctx := context.WithValue(context.Background(), dialCtxKey{}, true) + + _, err = db.ExecContext(ctx, "DO 1") + if err != nil { + t.Fatal(err) + } +} + +func configForTests(t *testing.T) *Config { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + mycnf := NewConfig() + mycnf.User = user + mycnf.Passwd = pass + mycnf.Addr = addr + mycnf.Net = prot + mycnf.DBName = dbname + return mycnf +} + +func TestNewConnector(t *testing.T) { + mycnf := configForTests(t) + conn, err := NewConnector(mycnf) + if err != nil { + t.Fatal(err) + } + + db := sql.OpenDB(conn) + defer db.Close() + + if err := db.Ping(); err != nil { + t.Fatal(err) + } +} + +type slowConnection struct { + net.Conn + slowdown time.Duration +} + +func (sc *slowConnection) Read(b []byte) (int, error) { + time.Sleep(sc.slowdown) + return sc.Conn.Read(b) +} + +type connectorHijack struct { + driver.Connector + connErr error +} + +func (cw *connectorHijack) Connect(ctx context.Context) (driver.Conn, error) { + var conn driver.Conn + conn, cw.connErr = cw.Connector.Connect(ctx) + return conn, cw.connErr +} + +func TestConnectorTimeoutsDuringOpen(t *testing.T) { + RegisterDialContext("slowconn", func(ctx context.Context, addr string) (net.Conn, error) { + var d net.Dialer + conn, err := d.DialContext(ctx, prot, addr) + if err != nil { + return nil, err + } + return &slowConnection{Conn: conn, slowdown: 100 * time.Millisecond}, nil + }) + + mycnf := configForTests(t) + mycnf.Net = "slowconn" + + conn, err := NewConnector(mycnf) + if err != nil { + t.Fatal(err) + } + + hijack := &connectorHijack{Connector: conn} + + db := sql.OpenDB(hijack) + defer db.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + _, err = db.ExecContext(ctx, "DO 1") + if err != context.DeadlineExceeded { + t.Fatalf("ExecContext should have timed out") + } + if hijack.connErr != context.DeadlineExceeded { + t.Fatalf("(*Connector).Connect should have timed out") + } +} + +// A connection which can only be closed. +type dummyConnection struct { + net.Conn + closed bool +} + +func (d *dummyConnection) Close() error { + d.closed = true + return nil +} + +func TestConnectorTimeoutsWatchCancel(t *testing.T) { + var ( + cancel func() // Used to cancel the context just after connecting. + created *dummyConnection // The created connection. + ) + + RegisterDialContext("TestConnectorTimeoutsWatchCancel", func(ctx context.Context, addr string) (net.Conn, error) { + // Canceling at this time triggers the watchCancel error branch in Connect(). + cancel() + created = &dummyConnection{} + return created, nil + }) + + mycnf := NewConfig() + mycnf.User = "root" + mycnf.Addr = "foo" + mycnf.Net = "TestConnectorTimeoutsWatchCancel" + + conn, err := NewConnector(mycnf) + if err != nil { + t.Fatal(err) + } + + db := sql.OpenDB(conn) + defer db.Close() + + var ctx context.Context + ctx, cancel = context.WithCancel(context.Background()) + defer cancel() + + if _, err := db.Conn(ctx); err != context.Canceled { + t.Errorf("got %v, want context.Canceled", err) + } + + if created == nil { + t.Fatal("no connection created") + } + if !created.closed { + t.Errorf("connection not closed") + } +} @@ -0,0 +1,559 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mariadb + +import ( + "bytes" + "crypto/rsa" + "crypto/tls" + "errors" + "fmt" + "math/big" + "net" + "net/url" + "sort" + "strconv" + "strings" + "time" +) + +var ( + errInvalidDSNUnescaped = errors.New("invalid DSN: did you forget to escape a param value?") + errInvalidDSNAddr = errors.New("invalid DSN: network address not terminated (missing closing brace)") + errInvalidDSNNoSlash = errors.New("invalid DSN: missing the slash separating the database name") + errInvalidDSNUnsafeCollation = errors.New("invalid DSN: interpolateParams can not be used with unsafe collations") +) + +// Config is a configuration parsed from a DSN string. +// If a new Config is created instead of being parsed from a DSN string, +// the NewConfig function should be used, which sets default values. +type Config struct { + User string // Username + Passwd string // Password (requires User) + Net string // Network type + Addr string // Network address (requires Net) + DBName string // Database name + Params map[string]string // Connection parameters + Collation string // Connection collation + Loc *time.Location // Location for time.Time values + MaxAllowedPacket int // Max packet size allowed + ServerPubKey string // Server public key name + pubKey *rsa.PublicKey // Server public key + TLSConfig string // TLS configuration name + tls *tls.Config // TLS configuration + Timeout time.Duration // Dial timeout + ReadTimeout time.Duration // I/O read timeout + WriteTimeout time.Duration // I/O write timeout + + AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE + AllowCleartextPasswords bool // Allows the cleartext client side plugin + AllowNativePasswords bool // Allows the native password authentication method + AllowOldPasswords bool // Allows the old insecure password method + CheckConnLiveness bool // Check connections for liveness before using them + ClientFoundRows bool // Return number of matching rows instead of rows changed + ColumnsWithAlias bool // Prepend table alias to column names + InterpolateParams bool // Interpolate placeholders into query string + MultiStatements bool // Allow multiple statements in one query + ParseTime bool // Parse time values to time.Time + RejectReadOnly bool // Reject read-only connections +} + +// NewConfig creates a new Config and sets default values. +func NewConfig() *Config { + return &Config{ + Collation: defaultCollation, + Loc: time.UTC, + MaxAllowedPacket: defaultMaxAllowedPacket, + AllowNativePasswords: true, + CheckConnLiveness: true, + } +} + +func (cfg *Config) Clone() *Config { + cp := *cfg + if cp.tls != nil { + cp.tls = cfg.tls.Clone() + } + if len(cp.Params) > 0 { + cp.Params = make(map[string]string, len(cfg.Params)) + for k, v := range cfg.Params { + cp.Params[k] = v + } + } + if cfg.pubKey != nil { + cp.pubKey = &rsa.PublicKey{ + N: new(big.Int).Set(cfg.pubKey.N), + E: cfg.pubKey.E, + } + } + return &cp +} + +func (cfg *Config) normalize() error { + if cfg.InterpolateParams && unsafeCollations[cfg.Collation] { + return errInvalidDSNUnsafeCollation + } + + // Set default network if empty + if cfg.Net == "" { + cfg.Net = "tcp" + } + + // Set default address if empty + if cfg.Addr == "" { + switch cfg.Net { + case "tcp": + cfg.Addr = "127.0.0.1:3306" + case "unix": + cfg.Addr = "/tmp/mysql.sock" + default: + return errors.New("default addr for network '" + cfg.Net + "' unknown") + } + } else if cfg.Net == "tcp" { + cfg.Addr = ensureHavePort(cfg.Addr) + } + + switch cfg.TLSConfig { + case "false", "": + // don't set anything + case "true": + cfg.tls = &tls.Config{} + case "skip-verify", "preferred": + cfg.tls = &tls.Config{InsecureSkipVerify: true} + default: + cfg.tls = getTLSConfigClone(cfg.TLSConfig) + if cfg.tls == nil { + return errors.New("invalid value / unknown config name: " + cfg.TLSConfig) + } + } + + if cfg.tls != nil && cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify { + host, _, err := net.SplitHostPort(cfg.Addr) + if err == nil { + cfg.tls.ServerName = host + } + } + + if cfg.ServerPubKey != "" { + cfg.pubKey = getServerPubKey(cfg.ServerPubKey) + if cfg.pubKey == nil { + return errors.New("invalid value / unknown server pub key name: " + cfg.ServerPubKey) + } + } + + return nil +} + +func writeDSNParam(buf *bytes.Buffer, hasParam *bool, name, value string) { + buf.Grow(1 + len(name) + 1 + len(value)) + if !*hasParam { + *hasParam = true + buf.WriteByte('?') + } else { + buf.WriteByte('&') + } + buf.WriteString(name) + buf.WriteByte('=') + buf.WriteString(value) +} + +// FormatDSN formats the given Config into a DSN string which can be passed to +// the driver. +func (cfg *Config) FormatDSN() string { + var buf bytes.Buffer + + // [username[:password]@] + if len(cfg.User) > 0 { + buf.WriteString(cfg.User) + if len(cfg.Passwd) > 0 { + buf.WriteByte(':') + buf.WriteString(cfg.Passwd) + } + buf.WriteByte('@') + } + + // [protocol[(address)]] + if len(cfg.Net) > 0 { + buf.WriteString(cfg.Net) + if len(cfg.Addr) > 0 { + buf.WriteByte('(') + buf.WriteString(cfg.Addr) + buf.WriteByte(')') + } + } + + // /dbname + buf.WriteByte('/') + buf.WriteString(cfg.DBName) + + // [?param1=value1&...¶mN=valueN] + hasParam := false + + if cfg.AllowAllFiles { + hasParam = true + buf.WriteString("?allowAllFiles=true") + } + + if cfg.AllowCleartextPasswords { + writeDSNParam(&buf, &hasParam, "allowCleartextPasswords", "true") + } + + if !cfg.AllowNativePasswords { + writeDSNParam(&buf, &hasParam, "allowNativePasswords", "false") + } + + if cfg.AllowOldPasswords { + writeDSNParam(&buf, &hasParam, "allowOldPasswords", "true") + } + + if !cfg.CheckConnLiveness { + writeDSNParam(&buf, &hasParam, "checkConnLiveness", "false") + } + + if cfg.ClientFoundRows { + writeDSNParam(&buf, &hasParam, "clientFoundRows", "true") + } + + if col := cfg.Collation; col != defaultCollation && len(col) > 0 { + writeDSNParam(&buf, &hasParam, "collation", col) + } + + if cfg.ColumnsWithAlias { + writeDSNParam(&buf, &hasParam, "columnsWithAlias", "true") + } + + if cfg.InterpolateParams { + writeDSNParam(&buf, &hasParam, "interpolateParams", "true") + } + + if cfg.Loc != time.UTC && cfg.Loc != nil { + writeDSNParam(&buf, &hasParam, "loc", url.QueryEscape(cfg.Loc.String())) + } + + if cfg.MultiStatements { + writeDSNParam(&buf, &hasParam, "multiStatements", "true") + } + + if cfg.ParseTime { + writeDSNParam(&buf, &hasParam, "parseTime", "true") + } + + if cfg.ReadTimeout > 0 { + writeDSNParam(&buf, &hasParam, "readTimeout", cfg.ReadTimeout.String()) + } + + if cfg.RejectReadOnly { + writeDSNParam(&buf, &hasParam, "rejectReadOnly", "true") + } + + if len(cfg.ServerPubKey) > 0 { + writeDSNParam(&buf, &hasParam, "serverPubKey", url.QueryEscape(cfg.ServerPubKey)) + } + + if cfg.Timeout > 0 { + writeDSNParam(&buf, &hasParam, "timeout", cfg.Timeout.String()) + } + + if len(cfg.TLSConfig) > 0 { + writeDSNParam(&buf, &hasParam, "tls", url.QueryEscape(cfg.TLSConfig)) + } + + if cfg.WriteTimeout > 0 { + writeDSNParam(&buf, &hasParam, "writeTimeout", cfg.WriteTimeout.String()) + } + + if cfg.MaxAllowedPacket != defaultMaxAllowedPacket { + writeDSNParam(&buf, &hasParam, "maxAllowedPacket", strconv.Itoa(cfg.MaxAllowedPacket)) + } + + // other params + if cfg.Params != nil { + var params []string + for param := range cfg.Params { + params = append(params, param) + } + sort.Strings(params) + for _, param := range params { + writeDSNParam(&buf, &hasParam, param, url.QueryEscape(cfg.Params[param])) + } + } + + return buf.String() +} + +// ParseDSN parses the DSN string to a Config +func ParseDSN(dsn string) (cfg *Config, err error) { + // New config with some default values + cfg = NewConfig() + + // [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN] + // Find the last '/' (since the password or the net addr might contain a '/') + foundSlash := false + for i := len(dsn) - 1; i >= 0; i-- { + if dsn[i] == '/' { + foundSlash = true + var j, k int + + // left part is empty if i <= 0 + if i > 0 { + // [username[:password]@][protocol[(address)]] + // Find the last '@' in dsn[:i] + for j = i; j >= 0; j-- { + if dsn[j] == '@' { + // username[:password] + // Find the first ':' in dsn[:j] + for k = 0; k < j; k++ { + if dsn[k] == ':' { + cfg.Passwd = dsn[k+1 : j] + break + } + } + cfg.User = dsn[:k] + + break + } + } + + // [protocol[(address)]] + // Find the first '(' in dsn[j+1:i] + for k = j + 1; k < i; k++ { + if dsn[k] == '(' { + // dsn[i-1] must be == ')' if an address is specified + if dsn[i-1] != ')' { + if strings.ContainsRune(dsn[k+1:i], ')') { + return nil, errInvalidDSNUnescaped + } + return nil, errInvalidDSNAddr + } + cfg.Addr = dsn[k+1 : i-1] + break + } + } + cfg.Net = dsn[j+1 : k] + } + + // dbname[?param1=value1&...¶mN=valueN] + // Find the first '?' in dsn[i+1:] + for j = i + 1; j < len(dsn); j++ { + if dsn[j] == '?' { + if err = parseDSNParams(cfg, dsn[j+1:]); err != nil { + return + } + break + } + } + cfg.DBName = dsn[i+1 : j] + + break + } + } + + if !foundSlash && len(dsn) > 0 { + return nil, errInvalidDSNNoSlash + } + + if err = cfg.normalize(); err != nil { + return nil, err + } + return +} + +// parseDSNParams parses the DSN "query string" +// Values must be url.QueryEscape'ed +func parseDSNParams(cfg *Config, params string) (err error) { + for _, v := range strings.Split(params, "&") { + param := strings.SplitN(v, "=", 2) + if len(param) != 2 { + continue + } + + // cfg params + switch value := param[1]; param[0] { + // Disable INFILE allowlist / enable all files + case "allowAllFiles": + var isBool bool + cfg.AllowAllFiles, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + + // Use cleartext authentication mode (MySQL 5.5.10+) + case "allowCleartextPasswords": + var isBool bool + cfg.AllowCleartextPasswords, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + + // Use native password authentication + case "allowNativePasswords": + var isBool bool + cfg.AllowNativePasswords, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + + // Use old authentication mode (pre MySQL 4.1) + case "allowOldPasswords": + var isBool bool + cfg.AllowOldPasswords, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + + // Check connections for Liveness before using them + case "checkConnLiveness": + var isBool bool + cfg.CheckConnLiveness, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + + // Switch "rowsAffected" mode + case "clientFoundRows": + var isBool bool + cfg.ClientFoundRows, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + + // Collation + case "collation": + cfg.Collation = value + + case "columnsWithAlias": + var isBool bool + cfg.ColumnsWithAlias, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + + // Compression + case "compress": + return errors.New("compression not implemented yet") + + // Enable client side placeholder substitution + case "interpolateParams": + var isBool bool + cfg.InterpolateParams, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + + // Time Location + case "loc": + if value, err = url.QueryUnescape(value); err != nil { + return + } + cfg.Loc, err = time.LoadLocation(value) + if err != nil { + return + } + + // multiple statements in one query + case "multiStatements": + var isBool bool + cfg.MultiStatements, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + + // time.Time parsing + case "parseTime": + var isBool bool + cfg.ParseTime, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + + // I/O read Timeout + case "readTimeout": + cfg.ReadTimeout, err = time.ParseDuration(value) + if err != nil { + return + } + + // Reject read-only connections + case "rejectReadOnly": + var isBool bool + cfg.RejectReadOnly, isBool = readBool(value) + if !isBool { + return errors.New("invalid bool value: " + value) + } + + // Server public key + case "serverPubKey": + name, err := url.QueryUnescape(value) + if err != nil { + return fmt.Errorf("invalid value for server pub key name: %v", err) + } + cfg.ServerPubKey = name + + // Strict mode + case "strict": + panic("strict mode has been removed. See https://github.com/go-sql-driver/mysql/wiki/strict-mode") + + // Dial Timeout + case "timeout": + cfg.Timeout, err = time.ParseDuration(value) + if err != nil { + return + } + + // TLS-Encryption + case "tls": + boolValue, isBool := readBool(value) + if isBool { + if boolValue { + cfg.TLSConfig = "true" + } else { + cfg.TLSConfig = "false" + } + } else if vl := strings.ToLower(value); vl == "skip-verify" || vl == "preferred" { + cfg.TLSConfig = vl + } else { + name, err := url.QueryUnescape(value) + if err != nil { + return fmt.Errorf("invalid value for TLS config name: %v", err) + } + cfg.TLSConfig = name + } + + // I/O write Timeout + case "writeTimeout": + cfg.WriteTimeout, err = time.ParseDuration(value) + if err != nil { + return + } + case "maxAllowedPacket": + cfg.MaxAllowedPacket, err = strconv.Atoi(value) + if err != nil { + return + } + default: + // lazy init + if cfg.Params == nil { + cfg.Params = make(map[string]string) + } + + if cfg.Params[param[0]], err = url.QueryUnescape(value); err != nil { + return + } + } + } + + return +} + +func ensureHavePort(addr string) string { + if _, _, err := net.SplitHostPort(addr); err != nil { + return net.JoinHostPort(addr, "3306") + } + return addr +} diff --git a/dsn_test.go b/dsn_test.go new file mode 100644 index 0000000..18cda58 --- /dev/null +++ b/dsn_test.go @@ -0,0 +1,416 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mariadb + +import ( + "crypto/tls" + "fmt" + "net/url" + "reflect" + "testing" + "time" +) + +var testDSNs = []struct { + in string + out *Config +}{{ + "username:password@protocol(address)/dbname?param=value", + &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, +}, { + "username:password@protocol(address)/dbname?param=value&columnsWithAlias=true", + &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true, ColumnsWithAlias: true}, +}, { + "username:password@protocol(address)/dbname?param=value&columnsWithAlias=true&multiStatements=true", + &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true, ColumnsWithAlias: true, MultiStatements: true}, +}, { + "user@unix(/path/to/socket)/dbname?charset=utf8", + &Config{User: "user", Net: "unix", Addr: "/path/to/socket", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, +}, { + "user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true, TLSConfig: "true"}, +}, { + "user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8mb4,utf8"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true, TLSConfig: "skip-verify"}, +}, { + "user:password@/dbname?loc=UTC&timeout=30s&readTimeout=1s&writeTimeout=1s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci&maxAllowedPacket=16777216&tls=false&allowCleartextPasswords=true&parseTime=true&rejectReadOnly=true", + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_unicode_ci", Loc: time.UTC, TLSConfig: "false", AllowCleartextPasswords: true, AllowNativePasswords: true, Timeout: 30 * time.Second, ReadTimeout: time.Second, WriteTimeout: time.Second, AllowAllFiles: true, AllowOldPasswords: true, CheckConnLiveness: true, ClientFoundRows: true, MaxAllowedPacket: 16777216, ParseTime: true, RejectReadOnly: true}, +}, { + "user:password@/dbname?allowNativePasswords=false&checkConnLiveness=false&maxAllowedPacket=0", + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: 0, AllowNativePasswords: false, CheckConnLiveness: false}, +}, { + "user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", + &Config{User: "user", Passwd: "p@ss(word)", Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:80", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.Local, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, +}, { + "/dbname", + &Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, +}, { + "@/", + &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, +}, { + "/", + &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, +}, { + "", + &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, +}, { + "user:p@/ssword@/", + &Config{User: "user", Passwd: "p@/ssword", Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, +}, { + "unix/?arg=%2Fsome%2Fpath.ext", + &Config{Net: "unix", Addr: "/tmp/mysql.sock", Params: map[string]string{"arg": "/some/path.ext"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, +}, { + "tcp(127.0.0.1)/dbname", + &Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, +}, { + "tcp(de:ad:be:ef::ca:fe)/dbname", + &Config{Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, CheckConnLiveness: true}, +}, +} + +func TestDSNParser(t *testing.T) { + for i, tst := range testDSNs { + cfg, err := ParseDSN(tst.in) + if err != nil { + t.Error(err.Error()) + } + + // pointer not static + cfg.tls = nil + + if !reflect.DeepEqual(cfg, tst.out) { + t.Errorf("%d. ParseDSN(%q) mismatch:\ngot %+v\nwant %+v", i, tst.in, cfg, tst.out) + } + } +} + +func TestDSNParserInvalid(t *testing.T) { + var invalidDSNs = []string{ + "@net(addr/", // no closing brace + "@tcp(/", // no closing brace + "tcp(/", // no closing brace + "(/", // no closing brace + "net(addr)//", // unescaped + "User:pass@tcp(1.2.3.4:3306)", // no trailing slash + "net()/", // unknown default addr + "user:pass@tcp(127.0.0.1:3306)/db/name", // invalid dbname + //"/dbname?arg=/some/unescaped/path", + } + + for i, tst := range invalidDSNs { + if _, err := ParseDSN(tst); err == nil { + t.Errorf("invalid DSN #%d. (%s) didn't error!", i, tst) + } + } +} + +func TestDSNReformat(t *testing.T) { + for i, tst := range testDSNs { + dsn1 := tst.in + cfg1, err := ParseDSN(dsn1) + if err != nil { + t.Error(err.Error()) + continue + } + cfg1.tls = nil // pointer not static + res1 := fmt.Sprintf("%+v", cfg1) + + dsn2 := cfg1.FormatDSN() + cfg2, err := ParseDSN(dsn2) + if err != nil { + t.Error(err.Error()) + continue + } + cfg2.tls = nil // pointer not static + res2 := fmt.Sprintf("%+v", cfg2) + + if res1 != res2 { + t.Errorf("%d. %q does not match %q", i, res2, res1) + } + } +} + +func TestDSNServerPubKey(t *testing.T) { + baseDSN := "User:password@tcp(localhost:5555)/dbname?serverPubKey=" + + RegisterServerPubKey("testKey", testPubKeyRSA) + defer DeregisterServerPubKey("testKey") + + tst := baseDSN + "testKey" + cfg, err := ParseDSN(tst) + if err != nil { + t.Error(err.Error()) + } + + if cfg.ServerPubKey != "testKey" { + t.Errorf("unexpected cfg.ServerPubKey value: %v", cfg.ServerPubKey) + } + if cfg.pubKey != testPubKeyRSA { + t.Error("pub key pointer doesn't match") + } + + // Key is missing + tst = baseDSN + "invalid_name" + cfg, err = ParseDSN(tst) + if err == nil { + t.Errorf("invalid name in DSN (%s) but did not error. Got config: %#v", tst, cfg) + } +} + +func TestDSNServerPubKeyQueryEscape(t *testing.T) { + const name = "&%!:" + dsn := "User:password@tcp(localhost:5555)/dbname?serverPubKey=" + url.QueryEscape(name) + + RegisterServerPubKey(name, testPubKeyRSA) + defer DeregisterServerPubKey(name) + + cfg, err := ParseDSN(dsn) + if err != nil { + t.Error(err.Error()) + } + + if cfg.pubKey != testPubKeyRSA { + t.Error("pub key pointer doesn't match") + } +} + +func TestDSNWithCustomTLS(t *testing.T) { + baseDSN := "User:password@tcp(localhost:5555)/dbname?tls=" + tlsCfg := tls.Config{} + + RegisterTLSConfig("utils_test", &tlsCfg) + defer DeregisterTLSConfig("utils_test") + + // Custom TLS is missing + tst := baseDSN + "invalid_tls" + cfg, err := ParseDSN(tst) + if err == nil { + t.Errorf("invalid custom TLS in DSN (%s) but did not error. Got config: %#v", tst, cfg) + } + + tst = baseDSN + "utils_test" + + // Custom TLS with a server name + name := "foohost" + tlsCfg.ServerName = name + cfg, err = ParseDSN(tst) + + if err != nil { + t.Error(err.Error()) + } else if cfg.tls.ServerName != name { + t.Errorf("did not get the correct TLS ServerName (%s) parsing DSN (%s).", name, tst) + } + + // Custom TLS without a server name + name = "localhost" + tlsCfg.ServerName = "" + cfg, err = ParseDSN(tst) + + if err != nil { + t.Error(err.Error()) + } else if cfg.tls.ServerName != name { + t.Errorf("did not get the correct ServerName (%s) parsing DSN (%s).", name, tst) + } else if tlsCfg.ServerName != "" { + t.Errorf("tlsCfg was mutated ServerName (%s) should be empty parsing DSN (%s).", name, tst) + } +} + +func TestDSNTLSConfig(t *testing.T) { + expectedServerName := "example.com" + dsn := "tcp(example.com:1234)/?tls=true" + + cfg, err := ParseDSN(dsn) + if err != nil { + t.Error(err.Error()) + } + if cfg.tls == nil { + t.Error("cfg.tls should not be nil") + } + if cfg.tls.ServerName != expectedServerName { + t.Errorf("cfg.tls.ServerName should be %q, got %q (host with port)", expectedServerName, cfg.tls.ServerName) + } + + dsn = "tcp(example.com)/?tls=true" + cfg, err = ParseDSN(dsn) + if err != nil { + t.Error(err.Error()) + } + if cfg.tls == nil { + t.Error("cfg.tls should not be nil") + } + if cfg.tls.ServerName != expectedServerName { + t.Errorf("cfg.tls.ServerName should be %q, got %q (host without port)", expectedServerName, cfg.tls.ServerName) + } +} + +func TestDSNWithCustomTLSQueryEscape(t *testing.T) { + const configKey = "&%!:" + dsn := "User:password@tcp(localhost:5555)/dbname?tls=" + url.QueryEscape(configKey) + name := "foohost" + tlsCfg := tls.Config{ServerName: name} + + RegisterTLSConfig(configKey, &tlsCfg) + defer DeregisterTLSConfig(configKey) + + cfg, err := ParseDSN(dsn) + + if err != nil { + t.Error(err.Error()) + } else if cfg.tls.ServerName != name { + t.Errorf("did not get the correct TLS ServerName (%s) parsing DSN (%s).", name, dsn) + } +} + +func TestDSNUnsafeCollation(t *testing.T) { + _, err := ParseDSN("/dbname?collation=gbk_chinese_ci&interpolateParams=true") + if err != errInvalidDSNUnsafeCollation { + t.Errorf("expected %v, got %v", errInvalidDSNUnsafeCollation, err) + } + + _, err = ParseDSN("/dbname?collation=gbk_chinese_ci&interpolateParams=false") + if err != nil { + t.Errorf("expected %v, got %v", nil, err) + } + + _, err = ParseDSN("/dbname?collation=gbk_chinese_ci") + if err != nil { + t.Errorf("expected %v, got %v", nil, err) + } + + _, err = ParseDSN("/dbname?collation=ascii_bin&interpolateParams=true") + if err != nil { + t.Errorf("expected %v, got %v", nil, err) + } + + _, err = ParseDSN("/dbname?collation=latin1_german1_ci&interpolateParams=true") + if err != nil { + t.Errorf("expected %v, got %v", nil, err) + } + + _, err = ParseDSN("/dbname?collation=utf8_general_ci&interpolateParams=true") + if err != nil { + t.Errorf("expected %v, got %v", nil, err) + } + + _, err = ParseDSN("/dbname?collation=utf8mb4_general_ci&interpolateParams=true") + if err != nil { + t.Errorf("expected %v, got %v", nil, err) + } +} + +func TestParamsAreSorted(t *testing.T) { + expected := "/dbname?interpolateParams=true&foobar=baz&quux=loo" + cfg := NewConfig() + cfg.DBName = "dbname" + cfg.InterpolateParams = true + cfg.Params = map[string]string{ + "quux": "loo", + "foobar": "baz", + } + actual := cfg.FormatDSN() + if actual != expected { + t.Errorf("generic Config.Params were not sorted: want %#v, got %#v", expected, actual) + } +} + +func TestCloneConfig(t *testing.T) { + RegisterServerPubKey("testKey", testPubKeyRSA) + defer DeregisterServerPubKey("testKey") + + expectedServerName := "example.com" + dsn := "tcp(example.com:1234)/?tls=true&foobar=baz&serverPubKey=testKey" + cfg, err := ParseDSN(dsn) + if err != nil { + t.Fatal(err.Error()) + } + + cfg2 := cfg.Clone() + if cfg == cfg2 { + t.Errorf("Config.Clone did not create a separate config struct") + } + + if cfg2.tls.ServerName != expectedServerName { + t.Errorf("cfg.tls.ServerName should be %q, got %q (host with port)", expectedServerName, cfg.tls.ServerName) + } + + cfg2.tls.ServerName = "example2.com" + if cfg.tls.ServerName == cfg2.tls.ServerName { + t.Errorf("changed cfg.tls.Server name should not propagate to original Config") + } + + if _, ok := cfg2.Params["foobar"]; !ok { + t.Errorf("cloned Config is missing custom params") + } + + delete(cfg2.Params, "foobar") + + if _, ok := cfg.Params["foobar"]; !ok { + t.Errorf("custom params in cloned Config should not propagate to original Config") + } + + if !reflect.DeepEqual(cfg.pubKey, cfg2.pubKey) { + t.Errorf("public key in Config should be identical") + } +} + +func TestNormalizeTLSConfig(t *testing.T) { + tt := []struct { + tlsConfig string + want *tls.Config + }{ + {"", nil}, + {"false", nil}, + {"true", &tls.Config{ServerName: "myserver"}}, + {"skip-verify", &tls.Config{InsecureSkipVerify: true}}, + {"preferred", &tls.Config{InsecureSkipVerify: true}}, + {"test_tls_config", &tls.Config{ServerName: "myServerName"}}, + } + + RegisterTLSConfig("test_tls_config", &tls.Config{ServerName: "myServerName"}) + defer func() { DeregisterTLSConfig("test_tls_config") }() + + for _, tc := range tt { + t.Run(tc.tlsConfig, func(t *testing.T) { + cfg := &Config{ + Addr: "myserver:3306", + TLSConfig: tc.tlsConfig, + } + + cfg.normalize() + + if cfg.tls == nil { + if tc.want != nil { + t.Fatal("wanted a tls config but got nil instead") + } + return + } + + if cfg.tls.ServerName != tc.want.ServerName { + t.Errorf("tls.ServerName doesn't match (want: '%s', got: '%s')", + tc.want.ServerName, cfg.tls.ServerName) + } + if cfg.tls.InsecureSkipVerify != tc.want.InsecureSkipVerify { + t.Errorf("tls.InsecureSkipVerify doesn't match (want: %T, got :%T)", + tc.want.InsecureSkipVerify, cfg.tls.InsecureSkipVerify) + } + }) + } +} + +func BenchmarkParseDSN(b *testing.B) { + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + for _, tst := range testDSNs { + if _, err := ParseDSN(tst.in); err != nil { + b.Error(err.Error()) + } + } + } +} diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..9074d0b --- /dev/null +++ b/errors.go @@ -0,0 +1,77 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mariadb + +import ( + "errors" + "fmt" + "log" + "os" +) + +// Various errors the driver might return. Can change between driver versions. +var ( + ErrInvalidConn = errors.New("invalid connection") + ErrMalformPkt = errors.New("malformed packet") + ErrNoTLS = errors.New("TLS requested but server does not support TLS") + ErrCleartextPassword = errors.New("this user requires clear text authentication. If you still want to use it, please add 'allowCleartextPasswords=1' to your DSN") + ErrNativePassword = errors.New("this user requires mysql native password authentication.") + ErrOldPassword = errors.New("this user requires old password authentication. If you still want to use it, please add 'allowOldPasswords=1' to your DSN. See also https://github.com/go-sql-driver/mysql/wiki/old_passwords") + ErrUnknownPlugin = errors.New("this authentication plugin is not supported") + ErrOldProtocol = errors.New("MySQL server does not support required protocol 41+") + ErrPktSync = errors.New("commands out of sync. You can't run this command now") + ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?") + ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the 'max_allowed_packet' variable on the server") + ErrBusyBuffer = errors.New("busy buffer") + + // errBadConnNoWrite is used for connection errors where nothing was sent to the database yet. + // If this happens first in a function starting a database interaction, it should be replaced by driver.ErrBadConn + // to trigger a resend. + // See https://github.com/go-sql-driver/mysql/pull/302 + errBadConnNoWrite = errors.New("bad connection") +) + +var errLog = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile)) + +// Logger is used to log critical error messages. +type Logger interface { + Print(v ...interface{}) +} + +// SetLogger is used to set the logger for critical errors. +// The initial logger is os.Stderr. +func SetLogger(logger Logger) error { + if logger == nil { + return errors.New("logger is nil") + } + errLog = logger + return nil +} + +// MySQLError is an error type which represents a single MySQL error +type MySQLError struct { + Number uint16 + SQLState [5]byte + Message string +} + +func (me *MySQLError) Error() string { + if me.SQLState != [5]byte{} { + return fmt.Sprintf("Error %d (%s): %s", me.Number, me.SQLState, me.Message) + } + + return fmt.Sprintf("Error %d: %s", me.Number, me.Message) +} + +func (me *MySQLError) Is(err error) bool { + if merr, ok := err.(*MySQLError); ok { + return merr.Number == me.Number + } + return false +} diff --git a/errors_test.go b/errors_test.go new file mode 100644 index 0000000..816a4a9 --- /dev/null +++ b/errors_test.go @@ -0,0 +1,61 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mariadb + +import ( + "bytes" + "errors" + "log" + "testing" +) + +func TestErrorsSetLogger(t *testing.T) { + previous := errLog + defer func() { + errLog = previous + }() + + // set up logger + const expected = "prefix: test\n" + buffer := bytes.NewBuffer(make([]byte, 0, 64)) + logger := log.New(buffer, "prefix: ", 0) + + // print + SetLogger(logger) + errLog.Print("test") + + // check result + if actual := buffer.String(); actual != expected { + t.Errorf("expected %q, got %q", expected, actual) + } +} + +func TestErrorsStrictIgnoreNotes(t *testing.T) { + runTests(t, dsn+"&sql_notes=false", func(dbt *DBTest) { + dbt.mustExec("DROP TABLE IF EXISTS does_not_exist") + }) +} + +func TestMySQLErrIs(t *testing.T) { + infraErr := &MySQLError{Number: 1234, Message: "the server is on fire"} + otherInfraErr := &MySQLError{Number: 1234, Message: "the datacenter is flooded"} + if !errors.Is(infraErr, otherInfraErr) { + t.Errorf("expected errors to be the same: %+v %+v", infraErr, otherInfraErr) + } + + differentCodeErr := &MySQLError{Number: 5678, Message: "the server is on fire"} + if errors.Is(infraErr, differentCodeErr) { + t.Fatalf("expected errors to be different: %+v %+v", infraErr, differentCodeErr) + } + + nonMysqlErr := errors.New("not a mysql error") + if errors.Is(infraErr, nonMysqlErr) { + t.Fatalf("expected errors to be different: %+v %+v", infraErr, nonMysqlErr) + } +} diff --git a/fields.go b/fields.go new file mode 100644 index 0000000..26328ca --- /dev/null +++ b/fields.go @@ -0,0 +1,206 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mariadb + +import ( + "database/sql" + "reflect" +) + +func (mf *mysqlField) typeDatabaseName() string { + switch mf.fieldType { + case fieldTypeBit: + return "BIT" + case fieldTypeBLOB: + if mf.charSet != collations[binaryCollation] { + return "TEXT" + } + return "BLOB" + case fieldTypeDate: + return "DATE" + case fieldTypeDateTime: + return "DATETIME" + case fieldTypeDecimal: + return "DECIMAL" + case fieldTypeDouble: + return "DOUBLE" + case fieldTypeEnum: + return "ENUM" + case fieldTypeFloat: + return "FLOAT" + case fieldTypeGeometry: + return "GEOMETRY" + case fieldTypeInt24: + return "MEDIUMINT" + case fieldTypeJSON: + return "JSON" + case fieldTypeLong: + if mf.flags&flagUnsigned != 0 { + return "UNSIGNED INT" + } + return "INT" + case fieldTypeLongBLOB: + if mf.charSet != collations[binaryCollation] { + return "LONGTEXT" + } + return "LONGBLOB" + case fieldTypeLongLong: + if mf.flags&flagUnsigned != 0 { + return "UNSIGNED BIGINT" + } + return "BIGINT" + case fieldTypeMediumBLOB: + if mf.charSet != collations[binaryCollation] { + return "MEDIUMTEXT" + } + return "MEDIUMBLOB" + case fieldTypeNewDate: + return "DATE" + case fieldTypeNewDecimal: + return "DECIMAL" + case fieldTypeNULL: + return "NULL" + case fieldTypeSet: + return "SET" + case fieldTypeShort: + if mf.flags&flagUnsigned != 0 { + return "UNSIGNED SMALLINT" + } + return "SMALLINT" + case fieldTypeString: + if mf.charSet == collations[binaryCollation] { + return "BINARY" + } + return "CHAR" + case fieldTypeTime: + return "TIME" + case fieldTypeTimestamp: + return "TIMESTAMP" + case fieldTypeTiny: + if mf.flags&flagUnsigned != 0 { + return "UNSIGNED TINYINT" + } + return "TINYINT" + case fieldTypeTinyBLOB: + if mf.charSet != collations[binaryCollation] { + return "TINYTEXT" + } + return "TINYBLOB" + case fieldTypeVarChar: + if mf.charSet == collations[binaryCollation] { + return "VARBINARY" + } + return "VARCHAR" + case fieldTypeVarString: + if mf.charSet == collations[binaryCollation] { + return "VARBINARY" + } + return "VARCHAR" + case fieldTypeYear: + return "YEAR" + default: + return "" + } +} + +var ( + scanTypeFloat32 = reflect.TypeOf(float32(0)) + scanTypeFloat64 = reflect.TypeOf(float64(0)) + scanTypeInt8 = reflect.TypeOf(int8(0)) + scanTypeInt16 = reflect.TypeOf(int16(0)) + scanTypeInt32 = reflect.TypeOf(int32(0)) + scanTypeInt64 = reflect.TypeOf(int64(0)) + scanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{}) + scanTypeNullInt = reflect.TypeOf(sql.NullInt64{}) + scanTypeNullTime = reflect.TypeOf(sql.NullTime{}) + scanTypeUint8 = reflect.TypeOf(uint8(0)) + scanTypeUint16 = reflect.TypeOf(uint16(0)) + scanTypeUint32 = reflect.TypeOf(uint32(0)) + scanTypeUint64 = reflect.TypeOf(uint64(0)) + scanTypeRawBytes = reflect.TypeOf(sql.RawBytes{}) + scanTypeUnknown = reflect.TypeOf(new(interface{})) +) + +type mysqlField struct { + tableName string + name string + length uint32 + flags fieldFlag + fieldType fieldType + decimals byte + charSet uint8 +} + +func (mf *mysqlField) scanType() reflect.Type { + switch mf.fieldType { + case fieldTypeTiny: + if mf.flags&flagNotNULL != 0 { + if mf.flags&flagUnsigned != 0 { + return scanTypeUint8 + } + return scanTypeInt8 + } + return scanTypeNullInt + + case fieldTypeShort, fieldTypeYear: + if mf.flags&flagNotNULL != 0 { + if mf.flags&flagUnsigned != 0 { + return scanTypeUint16 + } + return scanTypeInt16 + } + return scanTypeNullInt + + case fieldTypeInt24, fieldTypeLong: + if mf.flags&flagNotNULL != 0 { + if mf.flags&flagUnsigned != 0 { + return scanTypeUint32 + } + return scanTypeInt32 + } + return scanTypeNullInt + + case fieldTypeLongLong: + if mf.flags&flagNotNULL != 0 { + if mf.flags&flagUnsigned != 0 { + return scanTypeUint64 + } + return scanTypeInt64 + } + return scanTypeNullInt + + case fieldTypeFloat: + if mf.flags&flagNotNULL != 0 { + return scanTypeFloat32 + } + return scanTypeNullFloat + + case fieldTypeDouble: + if mf.flags&flagNotNULL != 0 { + return scanTypeFloat64 + } + return scanTypeNullFloat + + case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar, + fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB, + fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB, + fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON, + fieldTypeTime: + return scanTypeRawBytes + + case fieldTypeDate, fieldTypeNewDate, + fieldTypeTimestamp, fieldTypeDateTime: + // NullTime is always returned for more consistent behavior as it can + // handle both cases of parseTime regardless if the field is nullable. + return scanTypeNullTime + + default: + return scanTypeUnknown + } +} @@ -0,0 +1,24 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package. +// +// Copyright 2020 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build gofuzz + +package mariadb + +import ( + "database/sql" +) + +func Fuzz(data []byte) int { + db, err := sql.Open("mysql", string(data)) + if err != nil { + return 0 + } + db.Close() + return 1 +} @@ -0,0 +1,8 @@ +module snix.ir/mariadb + +go 1.18 + +require ( + golang.org/x/sys v0.0.0-20211019181941-9d821ace8654 // indirect + golang.org/x/tools v0.1.10 // indirect +) @@ -0,0 +1,4 @@ +golang.org/x/sys v0.0.0-20211019181941-9d821ace8654 h1:id054HUawV2/6IGm2IV8KZQjqtwAOo2CYlOToYqa0d0= +golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/tools v0.1.10 h1:QjFRCZxdOhBJ/UNgnBZLbNV13DlbnK0quyivTnXJM20= +golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E= diff --git a/infile.go b/infile.go new file mode 100644 index 0000000..3b7c5e4 --- /dev/null +++ b/infile.go @@ -0,0 +1,184 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mariadb + +import ( + "fmt" + "io" + "os" + "strings" + "sync" +) + +var ( + fileRegister map[string]bool + fileRegisterLock sync.RWMutex + readerRegister map[string]func() io.Reader + readerRegisterLock sync.RWMutex +) + +// RegisterLocalFile adds the given file to the file allowlist, +// so that it can be used by "LOAD DATA LOCAL INFILE <filepath>". +// Alternatively you can allow the use of all local files with +// the DSN parameter 'allowAllFiles=true' +// +// filePath := "/home/gopher/data.csv" +// mysql.RegisterLocalFile(filePath) +// err := db.Exec("LOAD DATA LOCAL INFILE '" + filePath + "' INTO TABLE foo") +// if err != nil { +// ... +// +func RegisterLocalFile(filePath string) { + fileRegisterLock.Lock() + // lazy map init + if fileRegister == nil { + fileRegister = make(map[string]bool) + } + + fileRegister[strings.Trim(filePath, `"`)] = true + fileRegisterLock.Unlock() +} + +// DeregisterLocalFile removes the given filepath from the allowlist. +func DeregisterLocalFile(filePath string) { + fileRegisterLock.Lock() + delete(fileRegister, strings.Trim(filePath, `"`)) + fileRegisterLock.Unlock() +} + +// RegisterReaderHandler registers a handler function which is used +// to receive a io.Reader. +// The Reader can be used by "LOAD DATA LOCAL INFILE Reader::<name>". +// If the handler returns a io.ReadCloser Close() is called when the +// request is finished. +// +// mysql.RegisterReaderHandler("data", func() io.Reader { +// var csvReader io.Reader // Some Reader that returns CSV data +// ... // Open Reader here +// return csvReader +// }) +// err := db.Exec("LOAD DATA LOCAL INFILE 'Reader::data' INTO TABLE foo") +// if err != nil { +// ... +// +func RegisterReaderHandler(name string, handler func() io.Reader) { + readerRegisterLock.Lock() + // lazy map init + if readerRegister == nil { + readerRegister = make(map[string]func() io.Reader) + } + + readerRegister[name] = handler + readerRegisterLock.Unlock() +} + +// DeregisterReaderHandler removes the ReaderHandler function with +// the given name from the registry. +func DeregisterReaderHandler(name string) { + readerRegisterLock.Lock() + delete(readerRegister, name) + readerRegisterLock.Unlock() +} + +func deferredClose(err *error, closer io.Closer) { + closeErr := closer.Close() + if *err == nil { + *err = closeErr + } +} + +const defaultPacketSize = 16 * 1024 // 16KB is small enough for disk readahead and large enough for TCP + +func (mc *mysqlConn) handleInFileRequest(name string) (err error) { + var rdr io.Reader + var data []byte + packetSize := defaultPacketSize + if mc.maxWriteSize < packetSize { + packetSize = mc.maxWriteSize + } + + if idx := strings.Index(name, "Reader::"); idx == 0 || (idx > 0 && name[idx-1] == '/') { // io.Reader + // The server might return an an absolute path. See issue #355. + name = name[idx+8:] + + readerRegisterLock.RLock() + handler, inMap := readerRegister[name] + readerRegisterLock.RUnlock() + + if inMap { + rdr = handler() + if rdr != nil { + if cl, ok := rdr.(io.Closer); ok { + defer deferredClose(&err, cl) + } + } else { + err = fmt.Errorf("Reader '%s' is <nil>", name) + } + } else { + err = fmt.Errorf("Reader '%s' is not registered", name) + } + } else { // File + name = strings.Trim(name, `"`) + fileRegisterLock.RLock() + fr := fileRegister[name] + fileRegisterLock.RUnlock() + if mc.cfg.AllowAllFiles || fr { + var file *os.File + var fi os.FileInfo + + if file, err = os.Open(name); err == nil { + defer deferredClose(&err, file) + + // get file size + if fi, err = file.Stat(); err == nil { + rdr = file + if fileSize := int(fi.Size()); fileSize < packetSize { + packetSize = fileSize + } + } + } + } else { + err = fmt.Errorf("local file '%s' is not registered", name) + } + } + + // send content packets + // if packetSize == 0, the Reader contains no data + if err == nil && packetSize > 0 { + data := make([]byte, 4+packetSize) + var n int + for err == nil { + n, err = rdr.Read(data[4:]) + if n > 0 { + if ioErr := mc.writePacket(data[:4+n]); ioErr != nil { + return ioErr + } + } + } + if err == io.EOF { + err = nil + } + } + + // send empty packet (termination) + if data == nil { + data = make([]byte, 4) + } + if ioErr := mc.writePacket(data[:4]); ioErr != nil { + return ioErr + } + + // read OK packet + if err == nil { + return mc.readResultOK() + } + + mc.readPacket() + return err +} diff --git a/nulltime.go b/nulltime.go new file mode 100644 index 0000000..1ed198d --- /dev/null +++ b/nulltime.go @@ -0,0 +1,71 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mariadb + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "time" +) + +// NullTime represents a time.Time that may be NULL. +// NullTime implements the Scanner interface so +// it can be used as a scan destination: +// +// var nt NullTime +// err := db.QueryRow("SELECT time FROM foo WHERE id=?", id).Scan(&nt) +// ... +// if nt.Valid { +// // use nt.Time +// } else { +// // NULL value +// } +// +// This NullTime implementation is not driver-specific +// +// Deprecated: NullTime doesn't honor the loc DSN parameter. +// NullTime.Scan interprets a time as UTC, not the loc DSN parameter. +// Use sql.NullTime instead. +type NullTime sql.NullTime + +// Scan implements the Scanner interface. +// The value type must be time.Time or string / []byte (formatted time-string), +// otherwise Scan fails. +func (nt *NullTime) Scan(value interface{}) (err error) { + if value == nil { + nt.Time, nt.Valid = time.Time{}, false + return + } + + switch v := value.(type) { + case time.Time: + nt.Time, nt.Valid = v, true + return + case []byte: + nt.Time, err = parseDateTime(v, time.UTC) + nt.Valid = (err == nil) + return + case string: + nt.Time, err = parseDateTime([]byte(v), time.UTC) + nt.Valid = (err == nil) + return + } + + nt.Valid = false + return fmt.Errorf("Can't convert %T to time.Time", value) +} + +// Value implements the driver Valuer interface. +func (nt NullTime) Value() (driver.Value, error) { + if !nt.Valid { + return nil, nil + } + return nt.Time, nil +} diff --git a/nulltime_test.go b/nulltime_test.go new file mode 100644 index 0000000..f0f8fd7 --- /dev/null +++ b/nulltime_test.go @@ -0,0 +1,62 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mariadb + +import ( + "database/sql" + "database/sql/driver" + "testing" + "time" +) + +var ( + // Check implementation of interfaces + _ driver.Valuer = NullTime{} + _ sql.Scanner = (*NullTime)(nil) +) + +func TestScanNullTime(t *testing.T) { + var scanTests = []struct { + in interface{} + error bool + valid bool + time time.Time + }{ + {tDate, false, true, tDate}, + {sDate, false, true, tDate}, + {[]byte(sDate), false, true, tDate}, + {tDateTime, false, true, tDateTime}, + {sDateTime, false, true, tDateTime}, + {[]byte(sDateTime), false, true, tDateTime}, + {tDate0, false, true, tDate0}, + {sDate0, false, true, tDate0}, + {[]byte(sDate0), false, true, tDate0}, + {sDateTime0, false, true, tDate0}, + {[]byte(sDateTime0), false, true, tDate0}, + {"", true, false, tDate0}, + {"1234", true, false, tDate0}, + {0, true, false, tDate0}, + } + + var nt = NullTime{} + var err error + + for _, tst := range scanTests { + err = nt.Scan(tst.in) + if (err != nil) != tst.error { + t.Errorf("%v: expected error status %t, got %t", tst.in, tst.error, (err != nil)) + } + if nt.Valid != tst.valid { + t.Errorf("%v: expected valid status %t, got %t", tst.in, tst.valid, nt.Valid) + } + if nt.Time != tst.time { + t.Errorf("%v: expected time %v, got %v", tst.in, tst.time, nt.Time) + } + } +} diff --git a/packets.go b/packets.go new file mode 100644 index 0000000..8db1764 --- /dev/null +++ b/packets.go @@ -0,0 +1,1349 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mariadb + +import ( + "bytes" + "crypto/tls" + "database/sql/driver" + "encoding/binary" + "encoding/json" + "errors" + "fmt" + "io" + "math" + "time" +) + +// Packets documentation: +// http://dev.mysql.com/doc/internals/en/client-server-protocol.html + +// Read packet to buffer 'data' +func (mc *mysqlConn) readPacket() ([]byte, error) { + var prevData []byte + for { + // read packet header + data, err := mc.buf.readNext(4) + if err != nil { + if cerr := mc.canceled.Value(); cerr != nil { + return nil, cerr + } + errLog.Print(err) + mc.Close() + return nil, ErrInvalidConn + } + + // packet length [24 bit] + pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16) + + // check packet sync [8 bit] + if data[3] != mc.sequence { + if data[3] > mc.sequence { + return nil, ErrPktSyncMul + } + return nil, ErrPktSync + } + mc.sequence++ + + // packets with length 0 terminate a previous packet which is a + // multiple of (2^24)-1 bytes long + if pktLen == 0 { + // there was no previous packet + if prevData == nil { + errLog.Print(ErrMalformPkt) + mc.Close() + return nil, ErrInvalidConn + } + + return prevData, nil + } + + // read packet body [pktLen bytes] + data, err = mc.buf.readNext(pktLen) + if err != nil { + if cerr := mc.canceled.Value(); cerr != nil { + return nil, cerr + } + errLog.Print(err) + mc.Close() + return nil, ErrInvalidConn + } + + // return data if this was the last packet + if pktLen < maxPacketSize { + // zero allocations for non-split packets + if prevData == nil { + return data, nil + } + + return append(prevData, data...), nil + } + + prevData = append(prevData, data...) + } +} + +// Write packet buffer 'data' +func (mc *mysqlConn) writePacket(data []byte) error { + pktLen := len(data) - 4 + + if pktLen > mc.maxAllowedPacket { + return ErrPktTooLarge + } + + // Perform a stale connection check. We only perform this check for + // the first query on a connection that has been checked out of the + // connection pool: a fresh connection from the pool is more likely + // to be stale, and it has not performed any previous writes that + // could cause data corruption, so it's safe to return ErrBadConn + // if the check fails. + if mc.reset { + mc.reset = false + conn := mc.netConn + if mc.rawConn != nil { + conn = mc.rawConn + } + var err error + if mc.cfg.CheckConnLiveness { + if mc.cfg.ReadTimeout != 0 { + err = conn.SetReadDeadline(time.Now().Add(mc.cfg.ReadTimeout)) + } + if err == nil { + err = connCheck(conn) + } + } + if err != nil { + errLog.Print("closing bad idle connection: ", err) + mc.Close() + return driver.ErrBadConn + } + } + + for { + var size int + if pktLen >= maxPacketSize { + data[0] = 0xff + data[1] = 0xff + data[2] = 0xff + size = maxPacketSize + } else { + data[0] = byte(pktLen) + data[1] = byte(pktLen >> 8) + data[2] = byte(pktLen >> 16) + size = pktLen + } + data[3] = mc.sequence + + // Write packet + if mc.writeTimeout > 0 { + if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil { + return err + } + } + + n, err := mc.netConn.Write(data[:4+size]) + if err == nil && n == 4+size { + mc.sequence++ + if size != maxPacketSize { + return nil + } + pktLen -= size + data = data[size:] + continue + } + + // Handle error + if err == nil { // n != len(data) + mc.cleanup() + errLog.Print(ErrMalformPkt) + } else { + if cerr := mc.canceled.Value(); cerr != nil { + return cerr + } + if n == 0 && pktLen == len(data)-4 { + // only for the first loop iteration when nothing was written yet + return errBadConnNoWrite + } + mc.cleanup() + errLog.Print(err) + } + return ErrInvalidConn + } +} + +/****************************************************************************** +* Initialization Process * +******************************************************************************/ + +// Handshake Initialization Packet +// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake +func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err error) { + data, err = mc.readPacket() + if err != nil { + // for init we can rewrite this to ErrBadConn for sql.Driver to retry, since + // in connection initialization we don't risk retrying non-idempotent actions. + if err == ErrInvalidConn { + return nil, "", driver.ErrBadConn + } + return + } + + if data[0] == iERR { + return nil, "", mc.handleErrorPacket(data) + } + + // protocol version [1 byte] + if data[0] < minProtocolVersion { + return nil, "", fmt.Errorf( + "unsupported protocol version %d. Version %d or higher is required", + data[0], + minProtocolVersion, + ) + } + + // server version [null terminated string] + // connection id [4 bytes] + pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4 + + // first part of the password cipher [8 bytes] + authData := data[pos : pos+8] + + // (filler) always 0x00 [1 byte] + pos += 8 + 1 + + // capability flags (lower 2 bytes) [2 bytes] + mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) + if mc.flags&clientProtocol41 == 0 { + return nil, "", ErrOldProtocol + } + if mc.flags&clientSSL == 0 && mc.cfg.tls != nil { + if mc.cfg.TLSConfig == "preferred" { + mc.cfg.tls = nil + } else { + return nil, "", ErrNoTLS + } + } + pos += 2 + + if len(data) > pos { + // character set [1 byte] + // status flags [2 bytes] + // capability flags (upper 2 bytes) [2 bytes] + // length of auth-plugin-data [1 byte] + // reserved (all [00]) [10 bytes] + pos += 1 + 2 + 2 + 1 + 10 + + // second part of the password cipher [mininum 13 bytes], + // where len=MAX(13, length of auth-plugin-data - 8) + // + // The web documentation is ambiguous about the length. However, + // according to mysql-5.7/sql/auth/sql_authentication.cc line 538, + // the 13th byte is "\0 byte, terminating the second part of + // a scramble". So the second part of the password cipher is + // a NULL terminated string that's at least 13 bytes with the + // last byte being NULL. + // + // The official Python library uses the fixed length 12 + // which seems to work but technically could have a hidden bug. + authData = append(authData, data[pos:pos+12]...) + pos += 13 + + // EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2) + // \NUL otherwise + if end := bytes.IndexByte(data[pos:], 0x00); end != -1 { + plugin = string(data[pos : pos+end]) + } else { + plugin = string(data[pos:]) + } + + // make a memory safe copy of the cipher slice + var b [20]byte + copy(b[:], authData) + return b[:], plugin, nil + } + + // make a memory safe copy of the cipher slice + var b [8]byte + copy(b[:], authData) + return b[:], plugin, nil +} + +// Client Authentication Packet +// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse +func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string) error { + // Adjust client flags based on server support + clientFlags := clientProtocol41 | + clientSecureConn | + clientLongPassword | + clientTransactions | + clientLocalFiles | + clientPluginAuth | + clientMultiResults | + mc.flags&clientLongFlag + + if mc.cfg.ClientFoundRows { + clientFlags |= clientFoundRows + } + + // To enable TLS / SSL + if mc.cfg.tls != nil { + clientFlags |= clientSSL + } + + if mc.cfg.MultiStatements { + clientFlags |= clientMultiStatements + } + + // encode length of the auth plugin data + var authRespLEIBuf [9]byte + authRespLen := len(authResp) + authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(authRespLen)) + if len(authRespLEI) > 1 { + // if the length can not be written in 1 byte, it must be written as a + // length encoded integer + clientFlags |= clientPluginAuthLenEncClientData + } + + pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1 + + // To specify a db name + if n := len(mc.cfg.DBName); n > 0 { + clientFlags |= clientConnectWithDB + pktLen += n + 1 + } + + // Calculate packet length and get buffer with that size + data, err := mc.buf.takeSmallBuffer(pktLen + 4) + if err != nil { + // cannot take the buffer. Something must be wrong with the connection + errLog.Print(err) + return errBadConnNoWrite + } + + // ClientFlags [32 bit] + data[4] = byte(clientFlags) + data[5] = byte(clientFlags >> 8) + data[6] = byte(clientFlags >> 16) + data[7] = byte(clientFlags >> 24) + + // MaxPacketSize [32 bit] (none) + data[8] = 0x00 + data[9] = 0x00 + data[10] = 0x00 + data[11] = 0x00 + + // Charset [1 byte] + var found bool + data[12], found = collations[mc.cfg.Collation] + if !found { + // Note possibility for false negatives: + // could be triggered although the collation is valid if the + // collations map does not contain entries the server supports. + return errors.New("unknown collation") + } + + // Filler [23 bytes] (all 0x00) + pos := 13 + for ; pos < 13+23; pos++ { + data[pos] = 0 + } + + // SSL Connection Request Packet + // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest + if mc.cfg.tls != nil { + // Send TLS / SSL request packet + if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil { + return err + } + + // Switch to TLS + tlsConn := tls.Client(mc.netConn, mc.cfg.tls) + if err := tlsConn.Handshake(); err != nil { + return err + } + mc.rawConn = mc.netConn + mc.netConn = tlsConn + mc.buf.nc = tlsConn + } + + // User [null terminated string] + if len(mc.cfg.User) > 0 { + pos += copy(data[pos:], mc.cfg.User) + } + data[pos] = 0x00 + pos++ + + // Auth Data [length encoded integer] + pos += copy(data[pos:], authRespLEI) + pos += copy(data[pos:], authResp) + + // Databasename [null terminated string] + if len(mc.cfg.DBName) > 0 { + pos += copy(data[pos:], mc.cfg.DBName) + data[pos] = 0x00 + pos++ + } + + pos += copy(data[pos:], plugin) + data[pos] = 0x00 + pos++ + + // Send Auth packet + return mc.writePacket(data[:pos]) +} + +// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse +func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error { + pktLen := 4 + len(authData) + data, err := mc.buf.takeSmallBuffer(pktLen) + if err != nil { + // cannot take the buffer. Something must be wrong with the connection + errLog.Print(err) + return errBadConnNoWrite + } + + // Add the auth data [EOF] + copy(data[4:], authData) + return mc.writePacket(data) +} + +/****************************************************************************** +* Command Packets * +******************************************************************************/ + +func (mc *mysqlConn) writeCommandPacket(command byte) error { + // Reset Packet Sequence + mc.sequence = 0 + + data, err := mc.buf.takeSmallBuffer(4 + 1) + if err != nil { + // cannot take the buffer. Something must be wrong with the connection + errLog.Print(err) + return errBadConnNoWrite + } + + // Add command byte + data[4] = command + + // Send CMD packet + return mc.writePacket(data) +} + +func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { + // Reset Packet Sequence + mc.sequence = 0 + + pktLen := 1 + len(arg) + data, err := mc.buf.takeBuffer(pktLen + 4) + if err != nil { + // cannot take the buffer. Something must be wrong with the connection + errLog.Print(err) + return errBadConnNoWrite + } + + // Add command byte + data[4] = command + + // Add arg + copy(data[5:], arg) + + // Send CMD packet + return mc.writePacket(data) +} + +func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { + // Reset Packet Sequence + mc.sequence = 0 + + data, err := mc.buf.takeSmallBuffer(4 + 1 + 4) + if err != nil { + // cannot take the buffer. Something must be wrong with the connection + errLog.Print(err) + return errBadConnNoWrite + } + + // Add command byte + data[4] = command + + // Add arg [32 bit] + data[5] = byte(arg) + data[6] = byte(arg >> 8) + data[7] = byte(arg >> 16) + data[8] = byte(arg >> 24) + + // Send CMD packet + return mc.writePacket(data) +} + +/****************************************************************************** +* Result Packets * +******************************************************************************/ + +func (mc *mysqlConn) readAuthResult() ([]byte, string, error) { + data, err := mc.readPacket() + if err != nil { + return nil, "", err + } + + // packet indicator + switch data[0] { + + case iOK: + return nil, "", mc.handleOkPacket(data) + + case iAuthMoreData: + return data[1:], "", err + + case iEOF: + if len(data) == 1 { + // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest + return nil, "mysql_old_password", nil + } + pluginEndIndex := bytes.IndexByte(data, 0x00) + if pluginEndIndex < 0 { + return nil, "", ErrMalformPkt + } + plugin := string(data[1:pluginEndIndex]) + authData := data[pluginEndIndex+1:] + return authData, plugin, nil + + default: // Error otherwise + return nil, "", mc.handleErrorPacket(data) + } +} + +// Returns error if Packet is not an 'Result OK'-Packet +func (mc *mysqlConn) readResultOK() error { + data, err := mc.readPacket() + if err != nil { + return err + } + + if data[0] == iOK { + return mc.handleOkPacket(data) + } + return mc.handleErrorPacket(data) +} + +// Result Set Header Packet +// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset +func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) { + data, err := mc.readPacket() + if err == nil { + switch data[0] { + + case iOK: + return 0, mc.handleOkPacket(data) + + case iERR: + return 0, mc.handleErrorPacket(data) + + case iLocalInFile: + return 0, mc.handleInFileRequest(string(data[1:])) + } + + // column count + num, _, n := readLengthEncodedInteger(data) + if n-len(data) == 0 { + return int(num), nil + } + + return 0, ErrMalformPkt + } + return 0, err +} + +// Error Packet +// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-ERR_Packet +func (mc *mysqlConn) handleErrorPacket(data []byte) error { + if data[0] != iERR { + return ErrMalformPkt + } + + // 0xff [1 byte] + + // Error Number [16 bit uint] + errno := binary.LittleEndian.Uint16(data[1:3]) + + // 1792: ER_CANT_EXECUTE_IN_READ_ONLY_TRANSACTION + // 1290: ER_OPTION_PREVENTS_STATEMENT (returned by Aurora during failover) + if (errno == 1792 || errno == 1290) && mc.cfg.RejectReadOnly { + // Oops; we are connected to a read-only connection, and won't be able + // to issue any write statements. Since RejectReadOnly is configured, + // we throw away this connection hoping this one would have write + // permission. This is specifically for a possible race condition + // during failover (e.g. on AWS Aurora). See README.md for more. + // + // We explicitly close the connection before returning + // driver.ErrBadConn to ensure that `database/sql` purges this + // connection and initiates a new one for next statement next time. + mc.Close() + return driver.ErrBadConn + } + + me := &MySQLError{Number: errno} + + pos := 3 + + // SQL State [optional: # + 5bytes string] + if data[3] == 0x23 { + copy(me.SQLState[:], data[4:4+5]) + pos = 9 + } + + // Error Message [string] + me.Message = string(data[pos:]) + + return me +} + +func readStatus(b []byte) statusFlag { + return statusFlag(b[0]) | statusFlag(b[1])<<8 +} + +// Ok Packet +// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet +func (mc *mysqlConn) handleOkPacket(data []byte) error { + var n, m int + + // 0x00 [1 byte] + + // Affected rows [Length Coded Binary] + mc.affectedRows, _, n = readLengthEncodedInteger(data[1:]) + + // Insert id [Length Coded Binary] + mc.insertId, _, m = readLengthEncodedInteger(data[1+n:]) + + // server_status [2 bytes] + mc.status = readStatus(data[1+n+m : 1+n+m+2]) + if mc.status&statusMoreResultsExists != 0 { + return nil + } + + // warning count [2 bytes] + + return nil +} + +// Read Packets as Field Packets until EOF-Packet or an Error appears +// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41 +func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { + columns := make([]mysqlField, count) + + for i := 0; ; i++ { + data, err := mc.readPacket() + if err != nil { + return nil, err + } + + // EOF Packet + if data[0] == iEOF && (len(data) == 5 || len(data) == 1) { + if i == count { + return columns, nil + } + return nil, fmt.Errorf("column count mismatch n:%d len:%d", count, len(columns)) + } + + // Catalog + pos, err := skipLengthEncodedString(data) + if err != nil { + return nil, err + } + + // Database [len coded string] + n, err := skipLengthEncodedString(data[pos:]) + if err != nil { + return nil, err + } + pos += n + + // Table [len coded string] + if mc.cfg.ColumnsWithAlias { + tableName, _, n, err := readLengthEncodedString(data[pos:]) + if err != nil { + return nil, err + } + pos += n + columns[i].tableName = string(tableName) + } else { + n, err = skipLengthEncodedString(data[pos:]) + if err != nil { + return nil, err + } + pos += n + } + + // Original table [len coded string] + n, err = skipLengthEncodedString(data[pos:]) + if err != nil { + return nil, err + } + pos += n + + // Name [len coded string] + name, _, n, err := readLengthEncodedString(data[pos:]) + if err != nil { + return nil, err + } + columns[i].name = string(name) + pos += n + + // Original name [len coded string] + n, err = skipLengthEncodedString(data[pos:]) + if err != nil { + return nil, err + } + pos += n + + // Filler [uint8] + pos++ + + // Charset [charset, collation uint8] + columns[i].charSet = data[pos] + pos += 2 + + // Length [uint32] + columns[i].length = binary.LittleEndian.Uint32(data[pos : pos+4]) + pos += 4 + + // Field type [uint8] + columns[i].fieldType = fieldType(data[pos]) + pos++ + + // Flags [uint16] + columns[i].flags = fieldFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) + pos += 2 + + // Decimals [uint8] + columns[i].decimals = data[pos] + //pos++ + + // Default value [len coded binary] + //if pos < len(data) { + // defaultVal, _, err = bytesToLengthCodedBinary(data[pos:]) + //} + } +} + +// Read Packets as Field Packets until EOF-Packet or an Error appears +// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow +func (rows *textRows) readRow(dest []driver.Value) error { + mc := rows.mc + + if rows.rs.done { + return io.EOF + } + + data, err := mc.readPacket() + if err != nil { + return err + } + + // EOF Packet + if data[0] == iEOF && len(data) == 5 { + // server_status [2 bytes] + rows.mc.status = readStatus(data[3:]) + rows.rs.done = true + if !rows.HasNextResultSet() { + rows.mc = nil + } + return io.EOF + } + if data[0] == iERR { + rows.mc = nil + return mc.handleErrorPacket(data) + } + + // RowSet Packet + var ( + n int + isNull bool + pos int = 0 + ) + + for i := range dest { + // Read bytes and convert to string + dest[i], isNull, n, err = readLengthEncodedString(data[pos:]) + pos += n + + if err != nil { + return err + } + + if isNull { + dest[i] = nil + continue + } + + if !mc.parseTime { + continue + } + + // Parse time field + switch rows.rs.columns[i].fieldType { + case fieldTypeTimestamp, + fieldTypeDateTime, + fieldTypeDate, + fieldTypeNewDate: + if dest[i], err = parseDateTime(dest[i].([]byte), mc.cfg.Loc); err != nil { + return err + } + } + } + + return nil +} + +// Reads Packets until EOF-Packet or an Error appears. Returns count of Packets read +func (mc *mysqlConn) readUntilEOF() error { + for { + data, err := mc.readPacket() + if err != nil { + return err + } + + switch data[0] { + case iERR: + return mc.handleErrorPacket(data) + case iEOF: + if len(data) == 5 { + mc.status = readStatus(data[3:]) + } + return nil + } + } +} + +/****************************************************************************** +* Prepared Statements * +******************************************************************************/ + +// Prepare Result Packets +// http://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html +func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) { + data, err := stmt.mc.readPacket() + if err == nil { + // packet indicator [1 byte] + if data[0] != iOK { + return 0, stmt.mc.handleErrorPacket(data) + } + + // statement id [4 bytes] + stmt.id = binary.LittleEndian.Uint32(data[1:5]) + + // Column count [16 bit uint] + columnCount := binary.LittleEndian.Uint16(data[5:7]) + + // Param count [16 bit uint] + stmt.paramCount = int(binary.LittleEndian.Uint16(data[7:9])) + + // Reserved [8 bit] + + // Warning count [16 bit uint] + + return columnCount, nil + } + return 0, err +} + +// http://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html +func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { + maxLen := stmt.mc.maxAllowedPacket - 1 + pktLen := maxLen + + // After the header (bytes 0-3) follows before the data: + // 1 byte command + // 4 bytes stmtID + // 2 bytes paramID + const dataOffset = 1 + 4 + 2 + + // Cannot use the write buffer since + // a) the buffer is too small + // b) it is in use + data := make([]byte, 4+1+4+2+len(arg)) + + copy(data[4+dataOffset:], arg) + + for argLen := len(arg); argLen > 0; argLen -= pktLen - dataOffset { + if dataOffset+argLen < maxLen { + pktLen = dataOffset + argLen + } + + stmt.mc.sequence = 0 + // Add command byte [1 byte] + data[4] = comStmtSendLongData + + // Add stmtID [32 bit] + data[5] = byte(stmt.id) + data[6] = byte(stmt.id >> 8) + data[7] = byte(stmt.id >> 16) + data[8] = byte(stmt.id >> 24) + + // Add paramID [16 bit] + data[9] = byte(paramID) + data[10] = byte(paramID >> 8) + + // Send CMD packet + err := stmt.mc.writePacket(data[:4+pktLen]) + if err == nil { + data = data[pktLen-dataOffset:] + continue + } + return err + + } + + // Reset Packet Sequence + stmt.mc.sequence = 0 + return nil +} + +// Execute Prepared Statement +// http://dev.mysql.com/doc/internals/en/com-stmt-execute.html +func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { + if len(args) != stmt.paramCount { + return fmt.Errorf( + "argument count mismatch (got: %d; has: %d)", + len(args), + stmt.paramCount, + ) + } + + const minPktLen = 4 + 1 + 4 + 1 + 4 + mc := stmt.mc + + // Determine threshold dynamically to avoid packet size shortage. + longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1) + if longDataSize < 64 { + longDataSize = 64 + } + + // Reset packet-sequence + mc.sequence = 0 + + var data []byte + var err error + + if len(args) == 0 { + data, err = mc.buf.takeBuffer(minPktLen) + } else { + data, err = mc.buf.takeCompleteBuffer() + // In this case the len(data) == cap(data) which is used to optimise the flow below. + } + if err != nil { + // cannot take the buffer. Something must be wrong with the connection + errLog.Print(err) + return errBadConnNoWrite + } + + // command [1 byte] + data[4] = comStmtExecute + + // statement_id [4 bytes] + data[5] = byte(stmt.id) + data[6] = byte(stmt.id >> 8) + data[7] = byte(stmt.id >> 16) + data[8] = byte(stmt.id >> 24) + + // flags (0: CURSOR_TYPE_NO_CURSOR) [1 byte] + data[9] = 0x00 + + // iteration_count (uint32(1)) [4 bytes] + data[10] = 0x01 + data[11] = 0x00 + data[12] = 0x00 + data[13] = 0x00 + + if len(args) > 0 { + pos := minPktLen + + var nullMask []byte + if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= cap(data) { + // buffer has to be extended but we don't know by how much so + // we depend on append after all data with known sizes fit. + // We stop at that because we deal with a lot of columns here + // which makes the required allocation size hard to guess. + tmp := make([]byte, pos+maskLen+typesLen) + copy(tmp[:pos], data[:pos]) + data = tmp + nullMask = data[pos : pos+maskLen] + // No need to clean nullMask as make ensures that. + pos += maskLen + } else { + nullMask = data[pos : pos+maskLen] + for i := range nullMask { + nullMask[i] = 0 + } + pos += maskLen + } + + // newParameterBoundFlag 1 [1 byte] + data[pos] = 0x01 + pos++ + + // type of each parameter [len(args)*2 bytes] + paramTypes := data[pos:] + pos += len(args) * 2 + + // value of each parameter [n bytes] + paramValues := data[pos:pos] + valuesCap := cap(paramValues) + + for i, arg := range args { + // build NULL-bitmap + if arg == nil { + nullMask[i/8] |= 1 << (uint(i) & 7) + paramTypes[i+i] = byte(fieldTypeNULL) + paramTypes[i+i+1] = 0x00 + continue + } + + if v, ok := arg.(json.RawMessage); ok { + arg = []byte(v) + } + // cache types and values + switch v := arg.(type) { + case int64: + paramTypes[i+i] = byte(fieldTypeLongLong) + paramTypes[i+i+1] = 0x00 + + if cap(paramValues)-len(paramValues)-8 >= 0 { + paramValues = paramValues[:len(paramValues)+8] + binary.LittleEndian.PutUint64( + paramValues[len(paramValues)-8:], + uint64(v), + ) + } else { + paramValues = append(paramValues, + uint64ToBytes(uint64(v))..., + ) + } + + case uint64: + paramTypes[i+i] = byte(fieldTypeLongLong) + paramTypes[i+i+1] = 0x80 // type is unsigned + + if cap(paramValues)-len(paramValues)-8 >= 0 { + paramValues = paramValues[:len(paramValues)+8] + binary.LittleEndian.PutUint64( + paramValues[len(paramValues)-8:], + uint64(v), + ) + } else { + paramValues = append(paramValues, + uint64ToBytes(uint64(v))..., + ) + } + + case float64: + paramTypes[i+i] = byte(fieldTypeDouble) + paramTypes[i+i+1] = 0x00 + + if cap(paramValues)-len(paramValues)-8 >= 0 { + paramValues = paramValues[:len(paramValues)+8] + binary.LittleEndian.PutUint64( + paramValues[len(paramValues)-8:], + math.Float64bits(v), + ) + } else { + paramValues = append(paramValues, + uint64ToBytes(math.Float64bits(v))..., + ) + } + + case bool: + paramTypes[i+i] = byte(fieldTypeTiny) + paramTypes[i+i+1] = 0x00 + + if v { + paramValues = append(paramValues, 0x01) + } else { + paramValues = append(paramValues, 0x00) + } + + case []byte: + // Common case (non-nil value) first + if v != nil { + paramTypes[i+i] = byte(fieldTypeString) + paramTypes[i+i+1] = 0x00 + + if len(v) < longDataSize { + paramValues = appendLengthEncodedInteger(paramValues, + uint64(len(v)), + ) + paramValues = append(paramValues, v...) + } else { + if err := stmt.writeCommandLongData(i, v); err != nil { + return err + } + } + continue + } + + // Handle []byte(nil) as a NULL value + nullMask[i/8] |= 1 << (uint(i) & 7) + paramTypes[i+i] = byte(fieldTypeNULL) + paramTypes[i+i+1] = 0x00 + + case string: + paramTypes[i+i] = byte(fieldTypeString) + paramTypes[i+i+1] = 0x00 + + if len(v) < longDataSize { + paramValues = appendLengthEncodedInteger(paramValues, + uint64(len(v)), + ) + paramValues = append(paramValues, v...) + } else { + if err := stmt.writeCommandLongData(i, []byte(v)); err != nil { + return err + } + } + + case time.Time: + paramTypes[i+i] = byte(fieldTypeString) + paramTypes[i+i+1] = 0x00 + + var a [64]byte + var b = a[:0] + + if v.IsZero() { + b = append(b, "0000-00-00"...) + } else { + b, err = appendDateTime(b, v.In(mc.cfg.Loc)) + if err != nil { + return err + } + } + + paramValues = appendLengthEncodedInteger(paramValues, + uint64(len(b)), + ) + paramValues = append(paramValues, b...) + + default: + return fmt.Errorf("cannot convert type: %T", arg) + } + } + + // Check if param values exceeded the available buffer + // In that case we must build the data packet with the new values buffer + if valuesCap != cap(paramValues) { + data = append(data[:pos], paramValues...) + if err = mc.buf.store(data); err != nil { + errLog.Print(err) + return errBadConnNoWrite + } + } + + pos += len(paramValues) + data = data[:pos] + } + + return mc.writePacket(data) +} + +func (mc *mysqlConn) discardResults() error { + for mc.status&statusMoreResultsExists != 0 { + resLen, err := mc.readResultSetHeaderPacket() + if err != nil { + return err + } + if resLen > 0 { + // columns + if err := mc.readUntilEOF(); err != nil { + return err + } + // rows + if err := mc.readUntilEOF(); err != nil { + return err + } + } + } + return nil +} + +// http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html +func (rows *binaryRows) readRow(dest []driver.Value) error { + data, err := rows.mc.readPacket() + if err != nil { + return err + } + + // packet indicator [1 byte] + if data[0] != iOK { + // EOF Packet + if data[0] == iEOF && len(data) == 5 { + rows.mc.status = readStatus(data[3:]) + rows.rs.done = true + if !rows.HasNextResultSet() { + rows.mc = nil + } + return io.EOF + } + mc := rows.mc + rows.mc = nil + + // Error otherwise + return mc.handleErrorPacket(data) + } + + // NULL-bitmap, [(column-count + 7 + 2) / 8 bytes] + pos := 1 + (len(dest)+7+2)>>3 + nullMask := data[1:pos] + + for i := range dest { + // Field is NULL + // (byte >> bit-pos) % 2 == 1 + if ((nullMask[(i+2)>>3] >> uint((i+2)&7)) & 1) == 1 { + dest[i] = nil + continue + } + + // Convert to byte-coded string + switch rows.rs.columns[i].fieldType { + case fieldTypeNULL: + dest[i] = nil + continue + + // Numeric Types + case fieldTypeTiny: + if rows.rs.columns[i].flags&flagUnsigned != 0 { + dest[i] = int64(data[pos]) + } else { + dest[i] = int64(int8(data[pos])) + } + pos++ + continue + + case fieldTypeShort, fieldTypeYear: + if rows.rs.columns[i].flags&flagUnsigned != 0 { + dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2])) + } else { + dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2]))) + } + pos += 2 + continue + + case fieldTypeInt24, fieldTypeLong: + if rows.rs.columns[i].flags&flagUnsigned != 0 { + dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4])) + } else { + dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4]))) + } + pos += 4 + continue + + case fieldTypeLongLong: + if rows.rs.columns[i].flags&flagUnsigned != 0 { + val := binary.LittleEndian.Uint64(data[pos : pos+8]) + if val > math.MaxInt64 { + dest[i] = uint64ToString(val) + } else { + dest[i] = int64(val) + } + } else { + dest[i] = int64(binary.LittleEndian.Uint64(data[pos : pos+8])) + } + pos += 8 + continue + + case fieldTypeFloat: + dest[i] = math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4])) + pos += 4 + continue + + case fieldTypeDouble: + dest[i] = math.Float64frombits(binary.LittleEndian.Uint64(data[pos : pos+8])) + pos += 8 + continue + + // Length coded Binary Strings + case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar, + fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB, + fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB, + fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON: + var isNull bool + var n int + dest[i], isNull, n, err = readLengthEncodedString(data[pos:]) + pos += n + if err == nil { + if !isNull { + continue + } else { + dest[i] = nil + continue + } + } + return err + + case + fieldTypeDate, fieldTypeNewDate, // Date YYYY-MM-DD + fieldTypeTime, // Time [-][H]HH:MM:SS[.fractal] + fieldTypeTimestamp, fieldTypeDateTime: // Timestamp YYYY-MM-DD HH:MM:SS[.fractal] + + num, isNull, n := readLengthEncodedInteger(data[pos:]) + pos += n + + switch { + case isNull: + dest[i] = nil + continue + case rows.rs.columns[i].fieldType == fieldTypeTime: + // database/sql does not support an equivalent to TIME, return a string + var dstlen uint8 + switch decimals := rows.rs.columns[i].decimals; decimals { + case 0x00, 0x1f: + dstlen = 8 + case 1, 2, 3, 4, 5, 6: + dstlen = 8 + 1 + decimals + default: + return fmt.Errorf( + "protocol error, illegal decimals value %d", + rows.rs.columns[i].decimals, + ) + } + dest[i], err = formatBinaryTime(data[pos:pos+int(num)], dstlen) + case rows.mc.parseTime: + dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc) + default: + var dstlen uint8 + if rows.rs.columns[i].fieldType == fieldTypeDate { + dstlen = 10 + } else { + switch decimals := rows.rs.columns[i].decimals; decimals { + case 0x00, 0x1f: + dstlen = 19 + case 1, 2, 3, 4, 5, 6: + dstlen = 19 + 1 + decimals + default: + return fmt.Errorf( + "protocol error, illegal decimals value %d", + rows.rs.columns[i].decimals, + ) + } + } + dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen) + } + + if err == nil { + pos += int(num) + continue + } else { + return err + } + + // Please report if this happens! + default: + return fmt.Errorf("unknown field type %d", rows.rs.columns[i].fieldType) + } + } + + return nil +} diff --git a/packets_test.go b/packets_test.go new file mode 100644 index 0000000..159d018 --- /dev/null +++ b/packets_test.go @@ -0,0 +1,336 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mariadb + +import ( + "bytes" + "errors" + "net" + "testing" + "time" +) + +var ( + errConnClosed = errors.New("connection is closed") + errConnTooManyReads = errors.New("too many reads") + errConnTooManyWrites = errors.New("too many writes") +) + +// struct to mock a net.Conn for testing purposes +type mockConn struct { + laddr net.Addr + raddr net.Addr + data []byte + written []byte + queuedReplies [][]byte + closed bool + read int + reads int + writes int + maxReads int + maxWrites int +} + +func (m *mockConn) Read(b []byte) (n int, err error) { + if m.closed { + return 0, errConnClosed + } + + m.reads++ + if m.maxReads > 0 && m.reads > m.maxReads { + return 0, errConnTooManyReads + } + + n = copy(b, m.data) + m.read += n + m.data = m.data[n:] + return +} +func (m *mockConn) Write(b []byte) (n int, err error) { + if m.closed { + return 0, errConnClosed + } + + m.writes++ + if m.maxWrites > 0 && m.writes > m.maxWrites { + return 0, errConnTooManyWrites + } + + n = len(b) + m.written = append(m.written, b...) + + if n > 0 && len(m.queuedReplies) > 0 { + m.data = m.queuedReplies[0] + m.queuedReplies = m.queuedReplies[1:] + } + return +} +func (m *mockConn) Close() error { + m.closed = true + return nil +} +func (m *mockConn) LocalAddr() net.Addr { + return m.laddr +} +func (m *mockConn) RemoteAddr() net.Addr { + return m.raddr +} +func (m *mockConn) SetDeadline(t time.Time) error { + return nil +} +func (m *mockConn) SetReadDeadline(t time.Time) error { + return nil +} +func (m *mockConn) SetWriteDeadline(t time.Time) error { + return nil +} + +// make sure mockConn implements the net.Conn interface +var _ net.Conn = new(mockConn) + +func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) { + conn := new(mockConn) + mc := &mysqlConn{ + buf: newBuffer(conn), + cfg: NewConfig(), + netConn: conn, + closech: make(chan struct{}), + maxAllowedPacket: defaultMaxAllowedPacket, + sequence: sequence, + } + return conn, mc +} + +func TestReadPacketSingleByte(t *testing.T) { + conn := new(mockConn) + mc := &mysqlConn{ + buf: newBuffer(conn), + } + + conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff} + conn.maxReads = 1 + packet, err := mc.readPacket() + if err != nil { + t.Fatal(err) + } + if len(packet) != 1 { + t.Fatalf("unexpected packet length: expected %d, got %d", 1, len(packet)) + } + if packet[0] != 0xff { + t.Fatalf("unexpected packet content: expected %x, got %x", 0xff, packet[0]) + } +} + +func TestReadPacketWrongSequenceID(t *testing.T) { + conn := new(mockConn) + mc := &mysqlConn{ + buf: newBuffer(conn), + } + + // too low sequence id + conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff} + conn.maxReads = 1 + mc.sequence = 1 + _, err := mc.readPacket() + if err != ErrPktSync { + t.Errorf("expected ErrPktSync, got %v", err) + } + + // reset + conn.reads = 0 + mc.sequence = 0 + mc.buf = newBuffer(conn) + + // too high sequence id + conn.data = []byte{0x01, 0x00, 0x00, 0x42, 0xff} + _, err = mc.readPacket() + if err != ErrPktSyncMul { + t.Errorf("expected ErrPktSyncMul, got %v", err) + } +} + +func TestReadPacketSplit(t *testing.T) { + conn := new(mockConn) + mc := &mysqlConn{ + buf: newBuffer(conn), + } + + data := make([]byte, maxPacketSize*2+4*3) + const pkt2ofs = maxPacketSize + 4 + const pkt3ofs = 2 * (maxPacketSize + 4) + + // case 1: payload has length maxPacketSize + data = data[:pkt2ofs+4] + + // 1st packet has maxPacketSize length and sequence id 0 + // ff ff ff 00 ... + data[0] = 0xff + data[1] = 0xff + data[2] = 0xff + + // mark the payload start and end of 1st packet so that we can check if the + // content was correctly appended + data[4] = 0x11 + data[maxPacketSize+3] = 0x22 + + // 2nd packet has payload length 0 and squence id 1 + // 00 00 00 01 + data[pkt2ofs+3] = 0x01 + + conn.data = data + conn.maxReads = 3 + packet, err := mc.readPacket() + if err != nil { + t.Fatal(err) + } + if len(packet) != maxPacketSize { + t.Fatalf("unexpected packet length: expected %d, got %d", maxPacketSize, len(packet)) + } + if packet[0] != 0x11 { + t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0]) + } + if packet[maxPacketSize-1] != 0x22 { + t.Fatalf("unexpected payload end: expected %x, got %x", 0x22, packet[maxPacketSize-1]) + } + + // case 2: payload has length which is a multiple of maxPacketSize + data = data[:cap(data)] + + // 2nd packet now has maxPacketSize length + data[pkt2ofs] = 0xff + data[pkt2ofs+1] = 0xff + data[pkt2ofs+2] = 0xff + + // mark the payload start and end of the 2nd packet + data[pkt2ofs+4] = 0x33 + data[pkt2ofs+maxPacketSize+3] = 0x44 + + // 3rd packet has payload length 0 and squence id 2 + // 00 00 00 02 + data[pkt3ofs+3] = 0x02 + + conn.data = data + conn.reads = 0 + conn.maxReads = 5 + mc.sequence = 0 + packet, err = mc.readPacket() + if err != nil { + t.Fatal(err) + } + if len(packet) != 2*maxPacketSize { + t.Fatalf("unexpected packet length: expected %d, got %d", 2*maxPacketSize, len(packet)) + } + if packet[0] != 0x11 { + t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0]) + } + if packet[2*maxPacketSize-1] != 0x44 { + t.Fatalf("unexpected payload end: expected %x, got %x", 0x44, packet[2*maxPacketSize-1]) + } + + // case 3: payload has a length larger maxPacketSize, which is not an exact + // multiple of it + data = data[:pkt2ofs+4+42] + data[pkt2ofs] = 0x2a + data[pkt2ofs+1] = 0x00 + data[pkt2ofs+2] = 0x00 + data[pkt2ofs+4+41] = 0x44 + + conn.data = data + conn.reads = 0 + conn.maxReads = 4 + mc.sequence = 0 + packet, err = mc.readPacket() + if err != nil { + t.Fatal(err) + } + if len(packet) != maxPacketSize+42 { + t.Fatalf("unexpected packet length: expected %d, got %d", maxPacketSize+42, len(packet)) + } + if packet[0] != 0x11 { + t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0]) + } + if packet[maxPacketSize+41] != 0x44 { + t.Fatalf("unexpected payload end: expected %x, got %x", 0x44, packet[maxPacketSize+41]) + } +} + +func TestReadPacketFail(t *testing.T) { + conn := new(mockConn) + mc := &mysqlConn{ + buf: newBuffer(conn), + closech: make(chan struct{}), + } + + // illegal empty (stand-alone) packet + conn.data = []byte{0x00, 0x00, 0x00, 0x00} + conn.maxReads = 1 + _, err := mc.readPacket() + if err != ErrInvalidConn { + t.Errorf("expected ErrInvalidConn, got %v", err) + } + + // reset + conn.reads = 0 + mc.sequence = 0 + mc.buf = newBuffer(conn) + + // fail to read header + conn.closed = true + _, err = mc.readPacket() + if err != ErrInvalidConn { + t.Errorf("expected ErrInvalidConn, got %v", err) + } + + // reset + conn.closed = false + conn.reads = 0 + mc.sequence = 0 + mc.buf = newBuffer(conn) + + // fail to read body + conn.maxReads = 1 + _, err = mc.readPacket() + if err != ErrInvalidConn { + t.Errorf("expected ErrInvalidConn, got %v", err) + } +} + +// https://github.com/go-sql-driver/mysql/pull/801 +// not-NUL terminated plugin_name in init packet +func TestRegression801(t *testing.T) { + conn := new(mockConn) + mc := &mysqlConn{ + buf: newBuffer(conn), + cfg: new(Config), + sequence: 42, + closech: make(chan struct{}), + } + + conn.data = []byte{72, 0, 0, 42, 10, 53, 46, 53, 46, 56, 0, 165, 0, 0, 0, + 60, 70, 63, 58, 68, 104, 34, 97, 0, 223, 247, 33, 2, 0, 15, 128, 21, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 98, 120, 114, 47, 85, 75, 109, 99, 51, 77, + 50, 64, 0, 109, 121, 115, 113, 108, 95, 110, 97, 116, 105, 118, 101, 95, + 112, 97, 115, 115, 119, 111, 114, 100} + conn.maxReads = 1 + + authData, pluginName, err := mc.readHandshakePacket() + if err != nil { + t.Fatalf("got error: %v", err) + } + + if pluginName != "mysql_native_password" { + t.Errorf("expected plugin name 'mysql_native_password', got '%s'", pluginName) + } + + expectedAuthData := []byte{60, 70, 63, 58, 68, 104, 34, 97, 98, 120, 114, + 47, 85, 75, 109, 99, 51, 77, 50, 64} + if !bytes.Equal(authData, expectedAuthData) { + t.Errorf("expected authData '%v', got '%v'", expectedAuthData, authData) + } +} diff --git a/result.go b/result.go new file mode 100644 index 0000000..d64dcd1 --- /dev/null +++ b/result.go @@ -0,0 +1,22 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mariadb + +type mysqlResult struct { + affectedRows int64 + insertId int64 +} + +func (res *mysqlResult) LastInsertId() (int64, error) { + return res.insertId, nil +} + +func (res *mysqlResult) RowsAffected() (int64, error) { + return res.affectedRows, nil +} @@ -0,0 +1,223 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mariadb + +import ( + "database/sql/driver" + "io" + "math" + "reflect" +) + +type resultSet struct { + columns []mysqlField + columnNames []string + done bool +} + +type mysqlRows struct { + mc *mysqlConn + rs resultSet + finish func() +} + +type binaryRows struct { + mysqlRows +} + +type textRows struct { + mysqlRows +} + +func (rows *mysqlRows) Columns() []string { + if rows.rs.columnNames != nil { + return rows.rs.columnNames + } + + columns := make([]string, len(rows.rs.columns)) + if rows.mc != nil && rows.mc.cfg.ColumnsWithAlias { + for i := range columns { + if tableName := rows.rs.columns[i].tableName; len(tableName) > 0 { + columns[i] = tableName + "." + rows.rs.columns[i].name + } else { + columns[i] = rows.rs.columns[i].name + } + } + } else { + for i := range columns { + columns[i] = rows.rs.columns[i].name + } + } + + rows.rs.columnNames = columns + return columns +} + +func (rows *mysqlRows) ColumnTypeDatabaseTypeName(i int) string { + return rows.rs.columns[i].typeDatabaseName() +} + +// func (rows *mysqlRows) ColumnTypeLength(i int) (length int64, ok bool) { +// return int64(rows.rs.columns[i].length), true +// } + +func (rows *mysqlRows) ColumnTypeNullable(i int) (nullable, ok bool) { + return rows.rs.columns[i].flags&flagNotNULL == 0, true +} + +func (rows *mysqlRows) ColumnTypePrecisionScale(i int) (int64, int64, bool) { + column := rows.rs.columns[i] + decimals := int64(column.decimals) + + switch column.fieldType { + case fieldTypeDecimal, fieldTypeNewDecimal: + if decimals > 0 { + return int64(column.length) - 2, decimals, true + } + return int64(column.length) - 1, decimals, true + case fieldTypeTimestamp, fieldTypeDateTime, fieldTypeTime: + return decimals, decimals, true + case fieldTypeFloat, fieldTypeDouble: + if decimals == 0x1f { + return math.MaxInt64, math.MaxInt64, true + } + return math.MaxInt64, decimals, true + } + + return 0, 0, false +} + +func (rows *mysqlRows) ColumnTypeScanType(i int) reflect.Type { + return rows.rs.columns[i].scanType() +} + +func (rows *mysqlRows) Close() (err error) { + if f := rows.finish; f != nil { + f() + rows.finish = nil + } + + mc := rows.mc + if mc == nil { + return nil + } + if err := mc.error(); err != nil { + return err + } + + // flip the buffer for this connection if we need to drain it. + // note that for a successful query (i.e. one where rows.next() + // has been called until it returns false), `rows.mc` will be nil + // by the time the user calls `(*Rows).Close`, so we won't reach this + // see: https://github.com/golang/go/commit/651ddbdb5056ded455f47f9c494c67b389622a47 + mc.buf.flip() + + // Remove unread packets from stream + if !rows.rs.done { + err = mc.readUntilEOF() + } + if err == nil { + if err = mc.discardResults(); err != nil { + return err + } + } + + rows.mc = nil + return err +} + +func (rows *mysqlRows) HasNextResultSet() (b bool) { + if rows.mc == nil { + return false + } + return rows.mc.status&statusMoreResultsExists != 0 +} + +func (rows *mysqlRows) nextResultSet() (int, error) { + if rows.mc == nil { + return 0, io.EOF + } + if err := rows.mc.error(); err != nil { + return 0, err + } + + // Remove unread packets from stream + if !rows.rs.done { + if err := rows.mc.readUntilEOF(); err != nil { + return 0, err + } + rows.rs.done = true + } + + if !rows.HasNextResultSet() { + rows.mc = nil + return 0, io.EOF + } + rows.rs = resultSet{} + return rows.mc.readResultSetHeaderPacket() +} + +func (rows *mysqlRows) nextNotEmptyResultSet() (int, error) { + for { + resLen, err := rows.nextResultSet() + if err != nil { + return 0, err + } + + if resLen > 0 { + return resLen, nil + } + + rows.rs.done = true + } +} + +func (rows *binaryRows) NextResultSet() error { + resLen, err := rows.nextNotEmptyResultSet() + if err != nil { + return err + } + + rows.rs.columns, err = rows.mc.readColumns(resLen) + return err +} + +func (rows *binaryRows) Next(dest []driver.Value) error { + if mc := rows.mc; mc != nil { + if err := mc.error(); err != nil { + return err + } + + // Fetch next row from stream + return rows.readRow(dest) + } + return io.EOF +} + +func (rows *textRows) NextResultSet() (err error) { + resLen, err := rows.nextNotEmptyResultSet() + if err != nil { + return err + } + + rows.rs.columns, err = rows.mc.readColumns(resLen) + return err +} + +func (rows *textRows) Next(dest []driver.Value) error { + if mc := rows.mc; mc != nil { + if err := mc.error(); err != nil { + return err + } + + // Fetch next row from stream + return rows.readRow(dest) + } + return io.EOF +} diff --git a/statement.go b/statement.go new file mode 100644 index 0000000..ca09247 --- /dev/null +++ b/statement.go @@ -0,0 +1,220 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mariadb + +import ( + "database/sql/driver" + "encoding/json" + "fmt" + "io" + "reflect" +) + +type mysqlStmt struct { + mc *mysqlConn + id uint32 + paramCount int +} + +func (stmt *mysqlStmt) Close() error { + if stmt.mc == nil || stmt.mc.closed.IsSet() { + // driver.Stmt.Close can be called more than once, thus this function + // has to be idempotent. + // See also Issue #450 and golang/go#16019. + //errLog.Print(ErrInvalidConn) + return driver.ErrBadConn + } + + err := stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id) + stmt.mc = nil + return err +} + +func (stmt *mysqlStmt) NumInput() int { + return stmt.paramCount +} + +func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter { + return converter{} +} + +func (stmt *mysqlStmt) CheckNamedValue(nv *driver.NamedValue) (err error) { + nv.Value, err = converter{}.ConvertValue(nv.Value) + return +} + +func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { + if stmt.mc.closed.IsSet() { + errLog.Print(ErrInvalidConn) + return nil, driver.ErrBadConn + } + // Send command + err := stmt.writeExecutePacket(args) + if err != nil { + return nil, stmt.mc.markBadConn(err) + } + + mc := stmt.mc + + mc.affectedRows = 0 + mc.insertId = 0 + + // Read Result + resLen, err := mc.readResultSetHeaderPacket() + if err != nil { + return nil, err + } + + if resLen > 0 { + // Columns + if err = mc.readUntilEOF(); err != nil { + return nil, err + } + + // Rows + if err := mc.readUntilEOF(); err != nil { + return nil, err + } + } + + if err := mc.discardResults(); err != nil { + return nil, err + } + + return &mysqlResult{ + affectedRows: int64(mc.affectedRows), + insertId: int64(mc.insertId), + }, nil +} + +func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { + return stmt.query(args) +} + +func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { + if stmt.mc.closed.IsSet() { + errLog.Print(ErrInvalidConn) + return nil, driver.ErrBadConn + } + // Send command + err := stmt.writeExecutePacket(args) + if err != nil { + return nil, stmt.mc.markBadConn(err) + } + + mc := stmt.mc + + // Read Result + resLen, err := mc.readResultSetHeaderPacket() + if err != nil { + return nil, err + } + + rows := new(binaryRows) + + if resLen > 0 { + rows.mc = mc + rows.rs.columns, err = mc.readColumns(resLen) + } else { + rows.rs.done = true + + switch err := rows.NextResultSet(); err { + case nil, io.EOF: + return rows, nil + default: + return nil, err + } + } + + return rows, err +} + +var jsonType = reflect.TypeOf(json.RawMessage{}) + +type converter struct{} + +// ConvertValue mirrors the reference/default converter in database/sql/driver +// with _one_ exception. We support uint64 with their high bit and the default +// implementation does not. This function should be kept in sync with +// database/sql/driver defaultConverter.ConvertValue() except for that +// deliberate difference. +func (c converter) ConvertValue(v interface{}) (driver.Value, error) { + if driver.IsValue(v) { + return v, nil + } + + if vr, ok := v.(driver.Valuer); ok { + sv, err := callValuerValue(vr) + if err != nil { + return nil, err + } + if driver.IsValue(sv) { + return sv, nil + } + // A value returend from the Valuer interface can be "a type handled by + // a database driver's NamedValueChecker interface" so we should accept + // uint64 here as well. + if u, ok := sv.(uint64); ok { + return u, nil + } + return nil, fmt.Errorf("non-Value type %T returned from Value", sv) + } + rv := reflect.ValueOf(v) + switch rv.Kind() { + case reflect.Ptr: + // indirect pointers + if rv.IsNil() { + return nil, nil + } else { + return c.ConvertValue(rv.Elem().Interface()) + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return rv.Int(), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return rv.Uint(), nil + case reflect.Float32, reflect.Float64: + return rv.Float(), nil + case reflect.Bool: + return rv.Bool(), nil + case reflect.Slice: + switch t := rv.Type(); { + case t == jsonType: + return v, nil + case t.Elem().Kind() == reflect.Uint8: + return rv.Bytes(), nil + default: + return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, t.Elem().Kind()) + } + case reflect.String: + return rv.String(), nil + } + return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind()) +} + +var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() + +// callValuerValue returns vr.Value(), with one exception: +// If vr.Value is an auto-generated method on a pointer type and the +// pointer is nil, it would panic at runtime in the panicwrap +// method. Treat it like nil instead. +// +// This is so people can implement driver.Value on value types and +// still use nil pointers to those types to mean nil/NULL, just like +// string/*string. +// +// This is an exact copy of the same-named unexported function from the +// database/sql package. +func callValuerValue(vr driver.Valuer) (v driver.Value, err error) { + if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Ptr && + rv.IsNil() && + rv.Type().Elem().Implements(valuerReflectType) { + return nil, nil + } + return vr.Value() +} diff --git a/statement_test.go b/statement_test.go new file mode 100644 index 0000000..01405cf --- /dev/null +++ b/statement_test.go @@ -0,0 +1,151 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mariadb + +import ( + "bytes" + "database/sql/driver" + "encoding/json" + "testing" +) + +func TestConvertDerivedString(t *testing.T) { + type derived string + + output, err := converter{}.ConvertValue(derived("value")) + if err != nil { + t.Fatal("Derived string type not convertible", err) + } + + if output != "value" { + t.Fatalf("Derived string type not converted, got %#v %T", output, output) + } +} + +func TestConvertDerivedByteSlice(t *testing.T) { + type derived []uint8 + + output, err := converter{}.ConvertValue(derived("value")) + if err != nil { + t.Fatal("Byte slice not convertible", err) + } + + if !bytes.Equal(output.([]byte), []byte("value")) { + t.Fatalf("Byte slice not converted, got %#v %T", output, output) + } +} + +func TestConvertDerivedUnsupportedSlice(t *testing.T) { + type derived []int + + _, err := converter{}.ConvertValue(derived{1}) + if err == nil || err.Error() != "unsupported type mysql.derived, a slice of int" { + t.Fatal("Unexpected error", err) + } +} + +func TestConvertDerivedBool(t *testing.T) { + type derived bool + + output, err := converter{}.ConvertValue(derived(true)) + if err != nil { + t.Fatal("Derived bool type not convertible", err) + } + + if output != true { + t.Fatalf("Derived bool type not converted, got %#v %T", output, output) + } +} + +func TestConvertPointer(t *testing.T) { + str := "value" + + output, err := converter{}.ConvertValue(&str) + if err != nil { + t.Fatal("Pointer type not convertible", err) + } + + if output != "value" { + t.Fatalf("Pointer type not converted, got %#v %T", output, output) + } +} + +func TestConvertSignedIntegers(t *testing.T) { + values := []interface{}{ + int8(-42), + int16(-42), + int32(-42), + int64(-42), + int(-42), + } + + for _, value := range values { + output, err := converter{}.ConvertValue(value) + if err != nil { + t.Fatalf("%T type not convertible %s", value, err) + } + + if output != int64(-42) { + t.Fatalf("%T type not converted, got %#v %T", value, output, output) + } + } +} + +type myUint64 struct { + value uint64 +} + +func (u myUint64) Value() (driver.Value, error) { + return u.value, nil +} + +func TestConvertUnsignedIntegers(t *testing.T) { + values := []interface{}{ + uint8(42), + uint16(42), + uint32(42), + uint64(42), + uint(42), + myUint64{uint64(42)}, + } + + for _, value := range values { + output, err := converter{}.ConvertValue(value) + if err != nil { + t.Fatalf("%T type not convertible %s", value, err) + } + + if output != uint64(42) { + t.Fatalf("%T type not converted, got %#v %T", value, output, output) + } + } + + output, err := converter{}.ConvertValue(^uint64(0)) + if err != nil { + t.Fatal("uint64 high-bit not convertible", err) + } + + if output != ^uint64(0) { + t.Fatalf("uint64 high-bit converted, got %#v %T", output, output) + } +} + +func TestConvertJSON(t *testing.T) { + raw := json.RawMessage("{}") + + out, err := converter{}.ConvertValue(raw) + + if err != nil { + t.Fatal("json.RawMessage was failed in convert", err) + } + + if _, ok := out.(json.RawMessage); !ok { + t.Fatalf("json.RawMessage converted, got %#v %T", out, out) + } +} diff --git a/transaction.go b/transaction.go new file mode 100644 index 0000000..ed0b5e1 --- /dev/null +++ b/transaction.go @@ -0,0 +1,31 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mariadb + +type mysqlTx struct { + mc *mysqlConn +} + +func (tx *mysqlTx) Commit() (err error) { + if tx.mc == nil || tx.mc.closed.IsSet() { + return ErrInvalidConn + } + err = tx.mc.exec("COMMIT") + tx.mc = nil + return +} + +func (tx *mysqlTx) Rollback() (err error) { + if tx.mc == nil || tx.mc.closed.IsSet() { + return ErrInvalidConn + } + err = tx.mc.exec("ROLLBACK") + tx.mc = nil + return +} diff --git a/utils.go b/utils.go new file mode 100644 index 0000000..fbbf49d --- /dev/null +++ b/utils.go @@ -0,0 +1,874 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mariadb + +import ( + "crypto/tls" + "database/sql" + "database/sql/driver" + "encoding/binary" + "errors" + "fmt" + "io" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" +) + +// Registry for custom tls.Configs +var ( + tlsConfigLock sync.RWMutex + tlsConfigRegistry map[string]*tls.Config +) + +// RegisterTLSConfig registers a custom tls.Config to be used with sql.Open. +// Use the key as a value in the DSN where tls=value. +// +// Note: The provided tls.Config is exclusively owned by the driver after +// registering it. +// +// rootCertPool := x509.NewCertPool() +// pem, err := ioutil.ReadFile("/path/ca-cert.pem") +// if err != nil { +// log.Fatal(err) +// } +// if ok := rootCertPool.AppendCertsFromPEM(pem); !ok { +// log.Fatal("Failed to append PEM.") +// } +// clientCert := make([]tls.Certificate, 0, 1) +// certs, err := tls.LoadX509KeyPair("/path/client-cert.pem", "/path/client-key.pem") +// if err != nil { +// log.Fatal(err) +// } +// clientCert = append(clientCert, certs) +// mysql.RegisterTLSConfig("custom", &tls.Config{ +// RootCAs: rootCertPool, +// Certificates: clientCert, +// }) +// db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom") +// +func RegisterTLSConfig(key string, config *tls.Config) error { + if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" || strings.ToLower(key) == "preferred" { + return fmt.Errorf("key '%s' is reserved", key) + } + + tlsConfigLock.Lock() + if tlsConfigRegistry == nil { + tlsConfigRegistry = make(map[string]*tls.Config) + } + + tlsConfigRegistry[key] = config + tlsConfigLock.Unlock() + return nil +} + +// DeregisterTLSConfig removes the tls.Config associated with key. +func DeregisterTLSConfig(key string) { + tlsConfigLock.Lock() + if tlsConfigRegistry != nil { + delete(tlsConfigRegistry, key) + } + tlsConfigLock.Unlock() +} + +func getTLSConfigClone(key string) (config *tls.Config) { + tlsConfigLock.RLock() + if v, ok := tlsConfigRegistry[key]; ok { + config = v.Clone() + } + tlsConfigLock.RUnlock() + return +} + +// Returns the bool value of the input. +// The 2nd return value indicates if the input was a valid bool value +func readBool(input string) (value bool, valid bool) { + switch input { + case "1", "true", "TRUE", "True": + return true, true + case "0", "false", "FALSE", "False": + return false, true + } + + // Not a valid bool value + return +} + +/****************************************************************************** +* Time related utils * +******************************************************************************/ + +func parseDateTime(b []byte, loc *time.Location) (time.Time, error) { + const base = "0000-00-00 00:00:00.000000" + switch len(b) { + case 10, 19, 21, 22, 23, 24, 25, 26: // up to "YYYY-MM-DD HH:MM:SS.MMMMMM" + if string(b) == base[:len(b)] { + return time.Time{}, nil + } + + year, err := parseByteYear(b) + if err != nil { + return time.Time{}, err + } + if year <= 0 { + year = 1 + } + + if b[4] != '-' { + return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[4]) + } + + m, err := parseByte2Digits(b[5], b[6]) + if err != nil { + return time.Time{}, err + } + if m <= 0 { + m = 1 + } + month := time.Month(m) + + if b[7] != '-' { + return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[7]) + } + + day, err := parseByte2Digits(b[8], b[9]) + if err != nil { + return time.Time{}, err + } + if day <= 0 { + day = 1 + } + if len(b) == 10 { + return time.Date(year, month, day, 0, 0, 0, 0, loc), nil + } + + if b[10] != ' ' { + return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[10]) + } + + hour, err := parseByte2Digits(b[11], b[12]) + if err != nil { + return time.Time{}, err + } + if b[13] != ':' { + return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[13]) + } + + min, err := parseByte2Digits(b[14], b[15]) + if err != nil { + return time.Time{}, err + } + if b[16] != ':' { + return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[16]) + } + + sec, err := parseByte2Digits(b[17], b[18]) + if err != nil { + return time.Time{}, err + } + if len(b) == 19 { + return time.Date(year, month, day, hour, min, sec, 0, loc), nil + } + + if b[19] != '.' { + return time.Time{}, fmt.Errorf("bad value for field: `%c`", b[19]) + } + nsec, err := parseByteNanoSec(b[20:]) + if err != nil { + return time.Time{}, err + } + return time.Date(year, month, day, hour, min, sec, nsec, loc), nil + default: + return time.Time{}, fmt.Errorf("invalid time bytes: %s", b) + } +} + +func parseByteYear(b []byte) (int, error) { + year, n := 0, 1000 + for i := 0; i < 4; i++ { + v, err := bToi(b[i]) + if err != nil { + return 0, err + } + year += v * n + n /= 10 + } + return year, nil +} + +func parseByte2Digits(b1, b2 byte) (int, error) { + d1, err := bToi(b1) + if err != nil { + return 0, err + } + d2, err := bToi(b2) + if err != nil { + return 0, err + } + return d1*10 + d2, nil +} + +func parseByteNanoSec(b []byte) (int, error) { + ns, digit := 0, 100000 // max is 6-digits + for i := 0; i < len(b); i++ { + v, err := bToi(b[i]) + if err != nil { + return 0, err + } + ns += v * digit + digit /= 10 + } + // nanoseconds has 10-digits. (needs to scale digits) + // 10 - 6 = 4, so we have to multiple 1000. + return ns * 1000, nil +} + +func bToi(b byte) (int, error) { + if b < '0' || b > '9' { + return 0, errors.New("not [0-9]") + } + return int(b - '0'), nil +} + +func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Value, error) { + switch num { + case 0: + return time.Time{}, nil + case 4: + return time.Date( + int(binary.LittleEndian.Uint16(data[:2])), // year + time.Month(data[2]), // month + int(data[3]), // day + 0, 0, 0, 0, + loc, + ), nil + case 7: + return time.Date( + int(binary.LittleEndian.Uint16(data[:2])), // year + time.Month(data[2]), // month + int(data[3]), // day + int(data[4]), // hour + int(data[5]), // minutes + int(data[6]), // seconds + 0, + loc, + ), nil + case 11: + return time.Date( + int(binary.LittleEndian.Uint16(data[:2])), // year + time.Month(data[2]), // month + int(data[3]), // day + int(data[4]), // hour + int(data[5]), // minutes + int(data[6]), // seconds + int(binary.LittleEndian.Uint32(data[7:11]))*1000, // nanoseconds + loc, + ), nil + } + return nil, fmt.Errorf("invalid DATETIME packet length %d", num) +} + +func appendDateTime(buf []byte, t time.Time) ([]byte, error) { + year, month, day := t.Date() + hour, min, sec := t.Clock() + nsec := t.Nanosecond() + + if year < 1 || year > 9999 { + return buf, errors.New("year is not in the range [1, 9999]: " + strconv.Itoa(year)) // use errors.New instead of fmt.Errorf to avoid year escape to heap + } + year100 := year / 100 + year1 := year % 100 + + var localBuf [len("2006-01-02T15:04:05.999999999")]byte // does not escape + localBuf[0], localBuf[1], localBuf[2], localBuf[3] = digits10[year100], digits01[year100], digits10[year1], digits01[year1] + localBuf[4] = '-' + localBuf[5], localBuf[6] = digits10[month], digits01[month] + localBuf[7] = '-' + localBuf[8], localBuf[9] = digits10[day], digits01[day] + + if hour == 0 && min == 0 && sec == 0 && nsec == 0 { + return append(buf, localBuf[:10]...), nil + } + + localBuf[10] = ' ' + localBuf[11], localBuf[12] = digits10[hour], digits01[hour] + localBuf[13] = ':' + localBuf[14], localBuf[15] = digits10[min], digits01[min] + localBuf[16] = ':' + localBuf[17], localBuf[18] = digits10[sec], digits01[sec] + + if nsec == 0 { + return append(buf, localBuf[:19]...), nil + } + nsec100000000 := nsec / 100000000 + nsec1000000 := (nsec / 1000000) % 100 + nsec10000 := (nsec / 10000) % 100 + nsec100 := (nsec / 100) % 100 + nsec1 := nsec % 100 + localBuf[19] = '.' + + // milli second + localBuf[20], localBuf[21], localBuf[22] = + digits01[nsec100000000], digits10[nsec1000000], digits01[nsec1000000] + // micro second + localBuf[23], localBuf[24], localBuf[25] = + digits10[nsec10000], digits01[nsec10000], digits10[nsec100] + // nano second + localBuf[26], localBuf[27], localBuf[28] = + digits01[nsec100], digits10[nsec1], digits01[nsec1] + + // trim trailing zeros + n := len(localBuf) + for n > 0 && localBuf[n-1] == '0' { + n-- + } + + return append(buf, localBuf[:n]...), nil +} + +// zeroDateTime is used in formatBinaryDateTime to avoid an allocation +// if the DATE or DATETIME has the zero value. +// It must never be changed. +// The current behavior depends on database/sql copying the result. +var zeroDateTime = []byte("0000-00-00 00:00:00.000000") + +const digits01 = "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789" +const digits10 = "0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999" + +func appendMicrosecs(dst, src []byte, decimals int) []byte { + if decimals <= 0 { + return dst + } + if len(src) == 0 { + return append(dst, ".000000"[:decimals+1]...) + } + + microsecs := binary.LittleEndian.Uint32(src[:4]) + p1 := byte(microsecs / 10000) + microsecs -= 10000 * uint32(p1) + p2 := byte(microsecs / 100) + microsecs -= 100 * uint32(p2) + p3 := byte(microsecs) + + switch decimals { + default: + return append(dst, '.', + digits10[p1], digits01[p1], + digits10[p2], digits01[p2], + digits10[p3], digits01[p3], + ) + case 1: + return append(dst, '.', + digits10[p1], + ) + case 2: + return append(dst, '.', + digits10[p1], digits01[p1], + ) + case 3: + return append(dst, '.', + digits10[p1], digits01[p1], + digits10[p2], + ) + case 4: + return append(dst, '.', + digits10[p1], digits01[p1], + digits10[p2], digits01[p2], + ) + case 5: + return append(dst, '.', + digits10[p1], digits01[p1], + digits10[p2], digits01[p2], + digits10[p3], + ) + } +} + +func formatBinaryDateTime(src []byte, length uint8) (driver.Value, error) { + // length expects the deterministic length of the zero value, + // negative time and 100+ hours are automatically added if needed + if len(src) == 0 { + return zeroDateTime[:length], nil + } + var dst []byte // return value + var p1, p2, p3 byte // current digit pair + + switch length { + case 10, 19, 21, 22, 23, 24, 25, 26: + default: + t := "DATE" + if length > 10 { + t += "TIME" + } + return nil, fmt.Errorf("illegal %s length %d", t, length) + } + switch len(src) { + case 4, 7, 11: + default: + t := "DATE" + if length > 10 { + t += "TIME" + } + return nil, fmt.Errorf("illegal %s packet length %d", t, len(src)) + } + dst = make([]byte, 0, length) + // start with the date + year := binary.LittleEndian.Uint16(src[:2]) + pt := year / 100 + p1 = byte(year - 100*uint16(pt)) + p2, p3 = src[2], src[3] + dst = append(dst, + digits10[pt], digits01[pt], + digits10[p1], digits01[p1], '-', + digits10[p2], digits01[p2], '-', + digits10[p3], digits01[p3], + ) + if length == 10 { + return dst, nil + } + if len(src) == 4 { + return append(dst, zeroDateTime[10:length]...), nil + } + dst = append(dst, ' ') + p1 = src[4] // hour + src = src[5:] + + // p1 is 2-digit hour, src is after hour + p2, p3 = src[0], src[1] + dst = append(dst, + digits10[p1], digits01[p1], ':', + digits10[p2], digits01[p2], ':', + digits10[p3], digits01[p3], + ) + return appendMicrosecs(dst, src[2:], int(length)-20), nil +} + +func formatBinaryTime(src []byte, length uint8) (driver.Value, error) { + // length expects the deterministic length of the zero value, + // negative time and 100+ hours are automatically added if needed + if len(src) == 0 { + return zeroDateTime[11 : 11+length], nil + } + var dst []byte // return value + + switch length { + case + 8, // time (can be up to 10 when negative and 100+ hours) + 10, 11, 12, 13, 14, 15: // time with fractional seconds + default: + return nil, fmt.Errorf("illegal TIME length %d", length) + } + switch len(src) { + case 8, 12: + default: + return nil, fmt.Errorf("invalid TIME packet length %d", len(src)) + } + // +2 to enable negative time and 100+ hours + dst = make([]byte, 0, length+2) + if src[0] == 1 { + dst = append(dst, '-') + } + days := binary.LittleEndian.Uint32(src[1:5]) + hours := int64(days)*24 + int64(src[5]) + + if hours >= 100 { + dst = strconv.AppendInt(dst, hours, 10) + } else { + dst = append(dst, digits10[hours], digits01[hours]) + } + + min, sec := src[6], src[7] + dst = append(dst, ':', + digits10[min], digits01[min], ':', + digits10[sec], digits01[sec], + ) + return appendMicrosecs(dst, src[8:], int(length)-9), nil +} + +/****************************************************************************** +* Convert from and to bytes * +******************************************************************************/ + +func uint64ToBytes(n uint64) []byte { + return []byte{ + byte(n), + byte(n >> 8), + byte(n >> 16), + byte(n >> 24), + byte(n >> 32), + byte(n >> 40), + byte(n >> 48), + byte(n >> 56), + } +} + +func uint64ToString(n uint64) []byte { + var a [20]byte + i := 20 + + // U+0030 = 0 + // ... + // U+0039 = 9 + + var q uint64 + for n >= 10 { + i-- + q = n / 10 + a[i] = uint8(n-q*10) + 0x30 + n = q + } + + i-- + a[i] = uint8(n) + 0x30 + + return a[i:] +} + +// treats string value as unsigned integer representation +func stringToInt(b []byte) int { + val := 0 + for i := range b { + val *= 10 + val += int(b[i] - 0x30) + } + return val +} + +// returns the string read as a bytes slice, whether the value is NULL, +// the number of bytes read and an error, in case the string is longer than +// the input slice +func readLengthEncodedString(b []byte) ([]byte, bool, int, error) { + // Get length + num, isNull, n := readLengthEncodedInteger(b) + if num < 1 { + return b[n:n], isNull, n, nil + } + + n += int(num) + + // Check data length + if len(b) >= n { + return b[n-int(num) : n : n], false, n, nil + } + return nil, false, n, io.EOF +} + +// returns the number of bytes skipped and an error, in case the string is +// longer than the input slice +func skipLengthEncodedString(b []byte) (int, error) { + // Get length + num, _, n := readLengthEncodedInteger(b) + if num < 1 { + return n, nil + } + + n += int(num) + + // Check data length + if len(b) >= n { + return n, nil + } + return n, io.EOF +} + +// returns the number read, whether the value is NULL and the number of bytes read +func readLengthEncodedInteger(b []byte) (uint64, bool, int) { + // See issue #349 + if len(b) == 0 { + return 0, true, 1 + } + + switch b[0] { + // 251: NULL + case 0xfb: + return 0, true, 1 + + // 252: value of following 2 + case 0xfc: + return uint64(b[1]) | uint64(b[2])<<8, false, 3 + + // 253: value of following 3 + case 0xfd: + return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16, false, 4 + + // 254: value of following 8 + case 0xfe: + return uint64(b[1]) | uint64(b[2])<<8 | uint64(b[3])<<16 | + uint64(b[4])<<24 | uint64(b[5])<<32 | uint64(b[6])<<40 | + uint64(b[7])<<48 | uint64(b[8])<<56, + false, 9 + } + + // 0-250: value of first byte + return uint64(b[0]), false, 1 +} + +// encodes a uint64 value and appends it to the given bytes slice +func appendLengthEncodedInteger(b []byte, n uint64) []byte { + switch { + case n <= 250: + return append(b, byte(n)) + + case n <= 0xffff: + return append(b, 0xfc, byte(n), byte(n>>8)) + + case n <= 0xffffff: + return append(b, 0xfd, byte(n), byte(n>>8), byte(n>>16)) + } + return append(b, 0xfe, byte(n), byte(n>>8), byte(n>>16), byte(n>>24), + byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56)) +} + +// reserveBuffer checks cap(buf) and expand buffer to len(buf) + appendSize. +// If cap(buf) is not enough, reallocate new buffer. +func reserveBuffer(buf []byte, appendSize int) []byte { + newSize := len(buf) + appendSize + if cap(buf) < newSize { + // Grow buffer exponentially + newBuf := make([]byte, len(buf)*2+appendSize) + copy(newBuf, buf) + buf = newBuf + } + return buf[:newSize] +} + +// escapeBytesBackslash escapes []byte with backslashes (\) +// This escapes the contents of a string (provided as []byte) by adding backslashes before special +// characters, and turning others into specific escape sequences, such as +// turning newlines into \n and null bytes into \0. +// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L823-L932 +func escapeBytesBackslash(buf, v []byte) []byte { + pos := len(buf) + buf = reserveBuffer(buf, len(v)*2) + + for _, c := range v { + switch c { + case '\x00': + buf[pos+1] = '0' + buf[pos] = '\\' + pos += 2 + case '\n': + buf[pos+1] = 'n' + buf[pos] = '\\' + pos += 2 + case '\r': + buf[pos+1] = 'r' + buf[pos] = '\\' + pos += 2 + case '\x1a': + buf[pos+1] = 'Z' + buf[pos] = '\\' + pos += 2 + case '\'': + buf[pos+1] = '\'' + buf[pos] = '\\' + pos += 2 + case '"': + buf[pos+1] = '"' + buf[pos] = '\\' + pos += 2 + case '\\': + buf[pos+1] = '\\' + buf[pos] = '\\' + pos += 2 + default: + buf[pos] = c + pos++ + } + } + + return buf[:pos] +} + +// escapeStringBackslash is similar to escapeBytesBackslash but for string. +func escapeStringBackslash(buf []byte, v string) []byte { + pos := len(buf) + buf = reserveBuffer(buf, len(v)*2) + + for i := 0; i < len(v); i++ { + c := v[i] + switch c { + case '\x00': + buf[pos+1] = '0' + buf[pos] = '\\' + pos += 2 + case '\n': + buf[pos+1] = 'n' + buf[pos] = '\\' + pos += 2 + case '\r': + buf[pos+1] = 'r' + buf[pos] = '\\' + pos += 2 + case '\x1a': + buf[pos+1] = 'Z' + buf[pos] = '\\' + pos += 2 + case '\'': + buf[pos+1] = '\'' + buf[pos] = '\\' + pos += 2 + case '"': + buf[pos+1] = '"' + buf[pos] = '\\' + pos += 2 + case '\\': + buf[pos+1] = '\\' + buf[pos] = '\\' + pos += 2 + default: + buf[pos] = c + pos++ + } + } + + return buf[:pos] +} + +// escapeBytesQuotes escapes apostrophes in []byte by doubling them up. +// This escapes the contents of a string by doubling up any apostrophes that +// it contains. This is used when the NO_BACKSLASH_ESCAPES SQL_MODE is in +// effect on the server. +// https://github.com/mysql/mysql-server/blob/mysql-5.7.5/mysys/charset.c#L963-L1038 +func escapeBytesQuotes(buf, v []byte) []byte { + pos := len(buf) + buf = reserveBuffer(buf, len(v)*2) + + for _, c := range v { + if c == '\'' { + buf[pos+1] = '\'' + buf[pos] = '\'' + pos += 2 + } else { + buf[pos] = c + pos++ + } + } + + return buf[:pos] +} + +// escapeStringQuotes is similar to escapeBytesQuotes but for string. +func escapeStringQuotes(buf []byte, v string) []byte { + pos := len(buf) + buf = reserveBuffer(buf, len(v)*2) + + for i := 0; i < len(v); i++ { + c := v[i] + if c == '\'' { + buf[pos+1] = '\'' + buf[pos] = '\'' + pos += 2 + } else { + buf[pos] = c + pos++ + } + } + + return buf[:pos] +} + +/****************************************************************************** +* Sync utils * +******************************************************************************/ + +// noCopy may be embedded into structs which must not be copied +// after the first use. +// +// See https://github.com/golang/go/issues/8005#issuecomment-190753527 +// for details. +type noCopy struct{} + +// Lock is a no-op used by -copylocks checker from `go vet`. +func (*noCopy) Lock() {} + +// Unlock is a no-op used by -copylocks checker from `go vet`. +// noCopy should implement sync.Locker from Go 1.11 +// https://github.com/golang/go/commit/c2eba53e7f80df21d51285879d51ab81bcfbf6bc +// https://github.com/golang/go/issues/26165 +func (*noCopy) Unlock() {} + +// atomicBool is a wrapper around uint32 for usage as a boolean value with +// atomic access. +type atomicBool struct { + _noCopy noCopy + value uint32 +} + +// IsSet returns whether the current boolean value is true +func (ab *atomicBool) IsSet() bool { + return atomic.LoadUint32(&ab.value) > 0 +} + +// Set sets the value of the bool regardless of the previous value +func (ab *atomicBool) Set(value bool) { + if value { + atomic.StoreUint32(&ab.value, 1) + } else { + atomic.StoreUint32(&ab.value, 0) + } +} + +// TrySet sets the value of the bool and returns whether the value changed +func (ab *atomicBool) TrySet(value bool) bool { + if value { + return atomic.SwapUint32(&ab.value, 1) == 0 + } + return atomic.SwapUint32(&ab.value, 0) > 0 +} + +// atomicError is a wrapper for atomically accessed error values +type atomicError struct { + _noCopy noCopy + value atomic.Value +} + +// Set sets the error value regardless of the previous value. +// The value must not be nil +func (ae *atomicError) Set(value error) { + ae.value.Store(value) +} + +// Value returns the current error value +func (ae *atomicError) Value() error { + if v := ae.value.Load(); v != nil { + // this will panic if the value doesn't implement the error interface + return v.(error) + } + return nil +} + +func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { + dargs := make([]driver.Value, len(named)) + for n, param := range named { + if len(param.Name) > 0 { + // TODO: support the use of Named Parameters #561 + return nil, errors.New("mysql: driver does not support the use of Named Parameters") + } + dargs[n] = param.Value + } + return dargs, nil +} + +func mapIsolationLevel(level driver.IsolationLevel) (string, error) { + switch sql.IsolationLevel(level) { + case sql.LevelRepeatableRead: + return "REPEATABLE READ", nil + case sql.LevelReadCommitted: + return "READ COMMITTED", nil + case sql.LevelReadUncommitted: + return "READ UNCOMMITTED", nil + case sql.LevelSerializable: + return "SERIALIZABLE", nil + default: + return "", fmt.Errorf("mysql: unsupported isolation level: %v", level) + } +} diff --git a/utils_test.go b/utils_test.go new file mode 100644 index 0000000..7a87414 --- /dev/null +++ b/utils_test.go @@ -0,0 +1,510 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mariadb + +import ( + "bytes" + "database/sql" + "database/sql/driver" + "encoding/binary" + "testing" + "time" +) + +func TestLengthEncodedInteger(t *testing.T) { + var integerTests = []struct { + num uint64 + encoded []byte + }{ + {0x0000000000000000, []byte{0x00}}, + {0x0000000000000012, []byte{0x12}}, + {0x00000000000000fa, []byte{0xfa}}, + {0x0000000000000100, []byte{0xfc, 0x00, 0x01}}, + {0x0000000000001234, []byte{0xfc, 0x34, 0x12}}, + {0x000000000000ffff, []byte{0xfc, 0xff, 0xff}}, + {0x0000000000010000, []byte{0xfd, 0x00, 0x00, 0x01}}, + {0x0000000000123456, []byte{0xfd, 0x56, 0x34, 0x12}}, + {0x0000000000ffffff, []byte{0xfd, 0xff, 0xff, 0xff}}, + {0x0000000001000000, []byte{0xfe, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}}, + {0x123456789abcdef0, []byte{0xfe, 0xf0, 0xde, 0xbc, 0x9a, 0x78, 0x56, 0x34, 0x12}}, + {0xffffffffffffffff, []byte{0xfe, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}}, + } + + for _, tst := range integerTests { + num, isNull, numLen := readLengthEncodedInteger(tst.encoded) + if isNull { + t.Errorf("%x: expected %d, got NULL", tst.encoded, tst.num) + } + if num != tst.num { + t.Errorf("%x: expected %d, got %d", tst.encoded, tst.num, num) + } + if numLen != len(tst.encoded) { + t.Errorf("%x: expected size %d, got %d", tst.encoded, len(tst.encoded), numLen) + } + encoded := appendLengthEncodedInteger(nil, num) + if !bytes.Equal(encoded, tst.encoded) { + t.Errorf("%v: expected %x, got %x", num, tst.encoded, encoded) + } + } +} + +func TestFormatBinaryDateTime(t *testing.T) { + rawDate := [11]byte{} + binary.LittleEndian.PutUint16(rawDate[:2], 1978) // years + rawDate[2] = 12 // months + rawDate[3] = 30 // days + rawDate[4] = 15 // hours + rawDate[5] = 46 // minutes + rawDate[6] = 23 // seconds + binary.LittleEndian.PutUint32(rawDate[7:], 987654) // microseconds + expect := func(expected string, inlen, outlen uint8) { + actual, _ := formatBinaryDateTime(rawDate[:inlen], outlen) + bytes, ok := actual.([]byte) + if !ok { + t.Errorf("formatBinaryDateTime must return []byte, was %T", actual) + } + if string(bytes) != expected { + t.Errorf( + "expected %q, got %q for length in %d, out %d", + expected, actual, inlen, outlen, + ) + } + } + expect("0000-00-00", 0, 10) + expect("0000-00-00 00:00:00", 0, 19) + expect("1978-12-30", 4, 10) + expect("1978-12-30 15:46:23", 7, 19) + expect("1978-12-30 15:46:23.987654", 11, 26) +} + +func TestFormatBinaryTime(t *testing.T) { + expect := func(expected string, src []byte, outlen uint8) { + actual, _ := formatBinaryTime(src, outlen) + bytes, ok := actual.([]byte) + if !ok { + t.Errorf("formatBinaryDateTime must return []byte, was %T", actual) + } + if string(bytes) != expected { + t.Errorf( + "expected %q, got %q for src=%q and outlen=%d", + expected, actual, src, outlen) + } + } + + // binary format: + // sign (0: positive, 1: negative), days(4), hours, minutes, seconds, micro(4) + + // Zeros + expect("00:00:00", []byte{}, 8) + expect("00:00:00.0", []byte{}, 10) + expect("00:00:00.000000", []byte{}, 15) + + // Without micro(4) + expect("12:34:56", []byte{0, 0, 0, 0, 0, 12, 34, 56}, 8) + expect("-12:34:56", []byte{1, 0, 0, 0, 0, 12, 34, 56}, 8) + expect("12:34:56.00", []byte{0, 0, 0, 0, 0, 12, 34, 56}, 11) + expect("24:34:56", []byte{0, 1, 0, 0, 0, 0, 34, 56}, 8) + expect("-99:34:56", []byte{1, 4, 0, 0, 0, 3, 34, 56}, 8) + expect("103079215103:34:56", []byte{0, 255, 255, 255, 255, 23, 34, 56}, 8) + + // With micro(4) + expect("12:34:56.00", []byte{0, 0, 0, 0, 0, 12, 34, 56, 99, 0, 0, 0}, 11) + expect("12:34:56.000099", []byte{0, 0, 0, 0, 0, 12, 34, 56, 99, 0, 0, 0}, 15) +} + +func TestEscapeBackslash(t *testing.T) { + expect := func(expected, value string) { + actual := string(escapeBytesBackslash([]byte{}, []byte(value))) + if actual != expected { + t.Errorf( + "expected %s, got %s", + expected, actual, + ) + } + + actual = string(escapeStringBackslash([]byte{}, value)) + if actual != expected { + t.Errorf( + "expected %s, got %s", + expected, actual, + ) + } + } + + expect("foo\\0bar", "foo\x00bar") + expect("foo\\nbar", "foo\nbar") + expect("foo\\rbar", "foo\rbar") + expect("foo\\Zbar", "foo\x1abar") + expect("foo\\\"bar", "foo\"bar") + expect("foo\\\\bar", "foo\\bar") + expect("foo\\'bar", "foo'bar") +} + +func TestEscapeQuotes(t *testing.T) { + expect := func(expected, value string) { + actual := string(escapeBytesQuotes([]byte{}, []byte(value))) + if actual != expected { + t.Errorf( + "expected %s, got %s", + expected, actual, + ) + } + + actual = string(escapeStringQuotes([]byte{}, value)) + if actual != expected { + t.Errorf( + "expected %s, got %s", + expected, actual, + ) + } + } + + expect("foo\x00bar", "foo\x00bar") // not affected + expect("foo\nbar", "foo\nbar") // not affected + expect("foo\rbar", "foo\rbar") // not affected + expect("foo\x1abar", "foo\x1abar") // not affected + expect("foo''bar", "foo'bar") // affected + expect("foo\"bar", "foo\"bar") // not affected +} + +func TestAtomicBool(t *testing.T) { + var ab atomicBool + if ab.IsSet() { + t.Fatal("Expected value to be false") + } + + ab.Set(true) + if ab.value != 1 { + t.Fatal("Set(true) did not set value to 1") + } + if !ab.IsSet() { + t.Fatal("Expected value to be true") + } + + ab.Set(true) + if !ab.IsSet() { + t.Fatal("Expected value to be true") + } + + ab.Set(false) + if ab.value != 0 { + t.Fatal("Set(false) did not set value to 0") + } + if ab.IsSet() { + t.Fatal("Expected value to be false") + } + + ab.Set(false) + if ab.IsSet() { + t.Fatal("Expected value to be false") + } + if ab.TrySet(false) { + t.Fatal("Expected TrySet(false) to fail") + } + if !ab.TrySet(true) { + t.Fatal("Expected TrySet(true) to succeed") + } + if !ab.IsSet() { + t.Fatal("Expected value to be true") + } + + ab.Set(true) + if !ab.IsSet() { + t.Fatal("Expected value to be true") + } + if ab.TrySet(true) { + t.Fatal("Expected TrySet(true) to fail") + } + if !ab.TrySet(false) { + t.Fatal("Expected TrySet(false) to succeed") + } + if ab.IsSet() { + t.Fatal("Expected value to be false") + } + + // we've "tested" them ¯\_(ツ)_/¯ + ab._noCopy.Lock() + defer ab._noCopy.Unlock() +} + +func TestAtomicError(t *testing.T) { + var ae atomicError + if ae.Value() != nil { + t.Fatal("Expected value to be nil") + } + + ae.Set(ErrMalformPkt) + if v := ae.Value(); v != ErrMalformPkt { + if v == nil { + t.Fatal("Value is still nil") + } + t.Fatal("Error did not match") + } + ae.Set(ErrPktSync) + if ae.Value() == ErrMalformPkt { + t.Fatal("Error still matches old error") + } + if v := ae.Value(); v != ErrPktSync { + t.Fatal("Error did not match") + } +} + +func TestIsolationLevelMapping(t *testing.T) { + data := []struct { + level driver.IsolationLevel + expected string + }{ + { + level: driver.IsolationLevel(sql.LevelReadCommitted), + expected: "READ COMMITTED", + }, + { + level: driver.IsolationLevel(sql.LevelRepeatableRead), + expected: "REPEATABLE READ", + }, + { + level: driver.IsolationLevel(sql.LevelReadUncommitted), + expected: "READ UNCOMMITTED", + }, + { + level: driver.IsolationLevel(sql.LevelSerializable), + expected: "SERIALIZABLE", + }, + } + + for i, td := range data { + if actual, err := mapIsolationLevel(td.level); actual != td.expected || err != nil { + t.Fatal(i, td.expected, actual, err) + } + } + + // check unsupported mapping + expectedErr := "mysql: unsupported isolation level: 7" + actual, err := mapIsolationLevel(driver.IsolationLevel(sql.LevelLinearizable)) + if actual != "" || err == nil { + t.Fatal("Expected error on unsupported isolation level") + } + if err.Error() != expectedErr { + t.Fatalf("Expected error to be %q, got %q", expectedErr, err) + } +} + +func TestAppendDateTime(t *testing.T) { + tests := []struct { + t time.Time + str string + }{ + { + t: time.Date(1234, 5, 6, 0, 0, 0, 0, time.UTC), + str: "1234-05-06", + }, + { + t: time.Date(4567, 12, 31, 12, 0, 0, 0, time.UTC), + str: "4567-12-31 12:00:00", + }, + { + t: time.Date(2020, 5, 30, 12, 34, 0, 0, time.UTC), + str: "2020-05-30 12:34:00", + }, + { + t: time.Date(2020, 5, 30, 12, 34, 56, 0, time.UTC), + str: "2020-05-30 12:34:56", + }, + { + t: time.Date(2020, 5, 30, 22, 33, 44, 123000000, time.UTC), + str: "2020-05-30 22:33:44.123", + }, + { + t: time.Date(2020, 5, 30, 22, 33, 44, 123456000, time.UTC), + str: "2020-05-30 22:33:44.123456", + }, + { + t: time.Date(2020, 5, 30, 22, 33, 44, 123456789, time.UTC), + str: "2020-05-30 22:33:44.123456789", + }, + { + t: time.Date(9999, 12, 31, 23, 59, 59, 999999999, time.UTC), + str: "9999-12-31 23:59:59.999999999", + }, + { + t: time.Date(1, 1, 1, 0, 0, 0, 0, time.UTC), + str: "0001-01-01", + }, + } + for _, v := range tests { + buf := make([]byte, 0, 32) + buf, _ = appendDateTime(buf, v.t) + if str := string(buf); str != v.str { + t.Errorf("appendDateTime(%v), have: %s, want: %s", v.t, str, v.str) + } + } + + // year out of range + { + v := time.Date(0, 1, 1, 0, 0, 0, 0, time.UTC) + buf := make([]byte, 0, 32) + _, err := appendDateTime(buf, v) + if err == nil { + t.Error("want an error") + return + } + } + { + v := time.Date(10000, 1, 1, 0, 0, 0, 0, time.UTC) + buf := make([]byte, 0, 32) + _, err := appendDateTime(buf, v) + if err == nil { + t.Error("want an error") + return + } + } +} + +func TestParseDateTime(t *testing.T) { + cases := []struct { + name string + str string + }{ + { + name: "parse date", + str: "2020-05-13", + }, + { + name: "parse null date", + str: sDate0, + }, + { + name: "parse datetime", + str: "2020-05-13 21:30:45", + }, + { + name: "parse null datetime", + str: sDateTime0, + }, + { + name: "parse datetime nanosec 1-digit", + str: "2020-05-25 23:22:01.1", + }, + { + name: "parse datetime nanosec 2-digits", + str: "2020-05-25 23:22:01.15", + }, + { + name: "parse datetime nanosec 3-digits", + str: "2020-05-25 23:22:01.159", + }, + { + name: "parse datetime nanosec 4-digits", + str: "2020-05-25 23:22:01.1594", + }, + { + name: "parse datetime nanosec 5-digits", + str: "2020-05-25 23:22:01.15949", + }, + { + name: "parse datetime nanosec 6-digits", + str: "2020-05-25 23:22:01.159491", + }, + } + + for _, loc := range []*time.Location{ + time.UTC, + time.FixedZone("test", 8*60*60), + } { + for _, cc := range cases { + t.Run(cc.name+"-"+loc.String(), func(t *testing.T) { + var want time.Time + if cc.str != sDate0 && cc.str != sDateTime0 { + var err error + want, err = time.ParseInLocation(timeFormat[:len(cc.str)], cc.str, loc) + if err != nil { + t.Fatal(err) + } + } + got, err := parseDateTime([]byte(cc.str), loc) + if err != nil { + t.Fatal(err) + } + + if !want.Equal(got) { + t.Fatalf("want: %v, but got %v", want, got) + } + }) + } + } +} + +func TestParseDateTimeFail(t *testing.T) { + cases := []struct { + name string + str string + wantErr string + }{ + { + name: "parse invalid time", + str: "hello", + wantErr: "invalid time bytes: hello", + }, + { + name: "parse year", + str: "000!-00-00 00:00:00.000000", + wantErr: "not [0-9]", + }, + { + name: "parse month", + str: "0000-!0-00 00:00:00.000000", + wantErr: "not [0-9]", + }, + { + name: `parse "-" after parsed year`, + str: "0000:00-00 00:00:00.000000", + wantErr: "bad value for field: `:`", + }, + { + name: `parse "-" after parsed month`, + str: "0000-00:00 00:00:00.000000", + wantErr: "bad value for field: `:`", + }, + { + name: `parse " " after parsed date`, + str: "0000-00-00+00:00:00.000000", + wantErr: "bad value for field: `+`", + }, + { + name: `parse ":" after parsed date`, + str: "0000-00-00 00-00:00.000000", + wantErr: "bad value for field: `-`", + }, + { + name: `parse ":" after parsed hour`, + str: "0000-00-00 00:00-00.000000", + wantErr: "bad value for field: `-`", + }, + { + name: `parse "." after parsed sec`, + str: "0000-00-00 00:00:00?000000", + wantErr: "bad value for field: `?`", + }, + } + + for _, cc := range cases { + t.Run(cc.name, func(t *testing.T) { + got, err := parseDateTime([]byte(cc.str), time.UTC) + if err == nil { + t.Fatal("want error") + } + if cc.wantErr != err.Error() { + t.Fatalf("want `%s`, but got `%s`", cc.wantErr, err) + } + if !got.IsZero() { + t.Fatal("want zero time") + } + }) + } +} |