import pymysql
import time
import threading
import unittest
class MySQL:
    """Hand out one pymysql connection per calling thread.

    A connection is opened lazily the first time a thread calls
    :meth:`get_connection`, cached in ``connection_infos`` under that
    thread's identifier, and reused by later calls from the same thread.
    """

    def __init__(self,
                 host,
                 port,
                 user,
                 password,
                 database=None,
                 charset='utf8',
                 autocommit=True):
        """Store the connection settings; no connection is opened here.

        Args:
            host: Server host, forwarded to ``pymysql.connect``.
            port: Server port, forwarded to ``pymysql.connect``.
            user: Login name, forwarded to ``pymysql.connect``.
            password: Login password, forwarded to ``pymysql.connect``.
            database: Optional default database; ``None`` selects none.
            charset: Connection character set.
            autocommit: Autocommit flag handed to every new connection.
        """
        self.host = host
        self.port = port
        self.user = user
        self.password = password
        self.database = database
        self.charset = charset
        self.autocommit = autocommit
        # Maps threading.get_ident() -> the connection owned by that thread.
        self.connection_infos = {}
        # Serialises every lookup/creation done in get_connection().
        self.lock = threading.Lock()

    def __del__(self):
        """Close every cached connection on a best-effort basis."""
        # getattr keeps the finalizer harmless on a half-built instance.
        connections = getattr(self, 'connection_infos', None)
        if not connections:
            return
        for connection in list(connections.values()):
            try:
                connection.close()
            except Exception:
                # A connection that is already closed or broken must not
                # prevent the remaining connections from being released.
                pass
        connections.clear()

    def get_connection(self):
        """Return the calling thread's connection, opening it if needed.

        The connection is pinged on every call before it is returned.
        NOTE(review): with pymysql's default ``ping`` arguments this is
        expected to re-establish a dropped connection -- confirm against
        the installed pymysql version.
        """
        ident = threading.get_ident()
        with self.lock:
            connection = self.connection_infos.get(ident)
            if connection is None:
                connection = pymysql.connect(host=self.host,
                                             port=self.port,
                                             user=self.user,
                                             password=self.password,
                                             database=self.database,
                                             charset=self.charset,
                                             autocommit=self.autocommit)
                self.connection_infos[ident] = connection
            connection.ping()
            return connection

    def get_cursor(self, cursor=None):
        """Return a new cursor on the calling thread's connection.

        Args:
            cursor: Forwarded unchanged to the connection's ``cursor()``
                method (presumably a pymysql cursor class; ``None`` leaves
                the choice to the connection).
        """
        return self.get_connection().cursor(cursor)
class TestMySQL(unittest.TestCase):
    """Integration tests for ``MySQL`` that run against a live server.

    The suite connects to 127.0.0.1:3306 as ``root``/``root``. A scratch
    database named ``temp_<unix time>`` holding a single table ``temp`` is
    created once per class and dropped afterwards; every test starts with
    an empty table because ``tearDown`` deletes all rows.
    """

    @classmethod
    def setUpClass(cls):
        """Create the scratch database and table shared by all tests."""
        cls.database_name = 'temp_' + str(int(time.time()))
        cls.table_name = 'temp'
        cls.mysql = MySQL(host='127.0.0.1',
                          port=3306,
                          user='root',
                          password='root')
        cls.mysql.get_cursor().execute('CREATE DATABASE IF NOT EXISTS ' +
                                       cls.database_name + ';')
        cls.mysql.get_connection().select_db(cls.database_name)
        cls.mysql.get_cursor().execute('DROP TABLE IF EXISTS ' +
                                       cls.table_name + ';')
        cls.mysql.get_cursor().execute('CREATE TABLE IF NOT EXISTS ' +
                                       cls.table_name +
                                       '(id INT(10), name VARCHAR(50));')

    @classmethod
    def tearDownClass(cls):
        """Drop the scratch database created by ``setUpClass``."""
        cls.mysql.get_cursor().execute('DROP DATABASE IF EXISTS ' +
                                       cls.database_name + ';')

    def setUp(self):
        """Open a fresh cursor and reset the worker error log."""
        self.cursor = self.mysql.get_cursor()
        # Exceptions raised inside worker threads are recorded here because
        # they would otherwise never reach the test runner.
        self.worker_errors = []

    def tearDown(self):
        """Empty the table and release the per-test cursor."""
        self.cursor.execute('DELETE FROM ' + self.table_name + ';')
        self.cursor.close()

    def _second_client(self):
        """Build an independent client bound to the scratch database.

        The caller must keep the returned object referenced for as long as
        it uses cursors from it: the client's finalizer closes its
        connections.
        """
        return MySQL(host='127.0.0.1',
                     port=3306,
                     user='root',
                     password='root',
                     database=self.database_name)

    def _insert_row(self, row_id, name):
        """Insert one literal row through ``self.cursor`` and check the count."""
        query = "INSERT INTO {}(id, name) VALUES({}, '{}');".format(
            self.table_name, row_id, name)
        self.assertEqual(self.cursor.execute(query), 1)

    def _assert_row_count(self, cursor, expected):
        """Select every row via *cursor* and check both reported counts."""
        query = 'SELECT * FROM ' + self.table_name + ';'
        self.assertEqual(cursor.execute(query), expected)
        self.assertEqual(cursor.rowcount, expected)

    def worker(self, mysql, count):
        """Insert *count* rows from the calling thread.

        Any failure (assertion or database error) is appended to
        ``self.worker_errors`` instead of being raised, so that the spawning
        test can report it from the main thread.
        """
        query = 'INSERT INTO ' + self.table_name + '(id, name) VALUES(1, \'a\');'
        try:
            # Each thread gets its own connection, which has no default
            # database until one is selected.
            mysql.get_connection().select_db(self.database_name)
            cursor = mysql.get_cursor()
            try:
                for _ in range(count):
                    self.assertEqual(cursor.execute(query), 1)
            finally:
                cursor.close()
        except Exception as error:
            self.worker_errors.append(error)

    def test_multi_thread(self):
        """Concurrent inserts from many threads must all be stored."""
        count = 100
        thread_count = 100
        threads = [
            threading.Thread(target=self.worker, args=(self.mysql, count))
            for _ in range(thread_count)
        ]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
        # Surface worker failures first; they explain a wrong row count.
        self.assertEqual(self.worker_errors, [])
        query = 'SELECT * FROM ' + self.table_name + ' ORDER BY id ASC;'
        self.assertEqual(self.cursor.execute(query), thread_count * count)

    def test_insert(self):
        """Literal, multi-row, parameterised and executemany inserts."""
        query = 'INSERT INTO ' + self.table_name + '(id, name) VALUES(1, \'a\');'
        self.assertEqual(self.cursor.execute(query), 1)
        query = 'INSERT INTO ' + self.table_name + '(id, name) VALUES(2, \'b\'), (3, \'c\');'
        self.assertEqual(self.cursor.execute(query), 2)
        query = 'INSERT INTO ' + self.table_name + '(id, name) VALUES(%s, %s);'
        self.assertEqual(self.cursor.execute(query, (4, 'd')), 1)
        rows = [[5, 'e'], [6, 'f'], [7, 'g']]
        self.assertEqual(self.cursor.executemany(query, rows), 3)

    def test_select(self):
        """Rows written by ``test_insert`` come back in id order."""
        self.test_insert()
        query = 'SELECT * FROM ' + self.table_name + ' ORDER BY id ASC;'
        self.assertEqual(self.cursor.execute(query), 7)
        self.assertEqual(self.cursor.description[0][0], 'id')
        self.assertEqual(self.cursor.description[1][0], 'name')
        self.assertEqual(self.cursor.rowcount, 7)
        row_id, name = self.cursor.fetchone()
        self.assertEqual(row_id, 1)
        self.assertEqual(name, 'a')
        expected = {2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g'}
        for row_id, name in self.cursor.fetchmany(2):
            self.assertEqual(name, expected[row_id])
        for row_id, name in self.cursor.fetchall():
            self.assertEqual(name, expected[row_id])

    def test_transaction_commit(self):
        """Rows become visible to another connection only after commit."""
        mysql_another = self._second_client()
        cursor_another = mysql_another.get_cursor()

        self.mysql.get_connection().begin()
        self._insert_row(1, 'a')
        self._assert_row_count(self.cursor, 1)
        self._assert_row_count(cursor_another, 0)

        self.mysql.get_connection().commit()
        self._assert_row_count(self.cursor, 1)
        self._assert_row_count(cursor_another, 1)

        # Outside an explicit transaction the insert is visible at once.
        self._insert_row(2, 'b')
        self._assert_row_count(self.cursor, 2)
        self._assert_row_count(cursor_another, 2)

    def test_transaction_rollback(self):
        """Rolled-back rows vanish and never reach another connection."""
        mysql_another = self._second_client()
        cursor_another = mysql_another.get_cursor()

        self.mysql.get_connection().begin()
        self._insert_row(1, 'a')
        self._assert_row_count(self.cursor, 1)
        self._assert_row_count(cursor_another, 0)

        self.mysql.get_connection().rollback()
        self._assert_row_count(self.cursor, 0)
        self._assert_row_count(cursor_another, 0)

        # Outside an explicit transaction the insert is visible at once.
        self._insert_row(2, 'b')
        self._assert_row_count(self.cursor, 1)
        self._assert_row_count(cursor_another, 1)

    def test_mogrify(self):
        """``mogrify`` returns the exact SQL text that would be sent."""
        query = 'SELECT * FROM test ORDER BY id ASC;'
        self.assertEqual(self.cursor.mogrify(query), query)
        query = 'INSERT INTO test(id, name) VALUES(%s, %s);'
        args = (1, 'a')
        self.assertEqual(self.cursor.mogrify(query, args),
                         'INSERT INTO test(id, name) VALUES(1, \'a\');')
if __name__ == "__main__":
    # Direct execution only prints a marker line; the test suite is not
    # started from here.
    print("main call")