3 분 소요

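개요

pymysql로 스레드마다 별도의 커넥션을 사용하는 MySQL 래퍼 클래스를 만들고, unittest로 삽입·조회·트랜잭션·멀티스레드 동작을 검증하는 예제입니다. 테스트를 실행하려면 127.0.0.1:3306에서 root/root 계정으로 접속할 수 있는 MySQL 서버가 필요합니다.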


예제

  • 코드

        import pymysql
        import time
        import threading
        import unittest


        class MySQL:
            """pymysql wrapper that keeps one lazily created connection per thread.

            Connections are stored in ``connection_infos`` keyed by
            ``threading.get_ident()``, so each calling thread works on its own
            connection and connections are never shared between threads.
            """

            def __init__(self,
                         host,
                         port,
                         user,
                         password,
                         database=None,
                         charset='utf8',
                         autocommit=True):
                """Store connection parameters; no connection is opened here.

                Args:
                    host: MySQL server host name or address.
                    port: MySQL server TCP port.
                    user: Login user name.
                    password: Login password.
                    database: Default database passed to ``pymysql.connect``,
                        or None to connect without selecting one.
                    charset: Connection character set passed to
                        ``pymysql.connect``. NOTE: MySQL's ``utf8`` is the
                        3-byte utf8mb3; pass ``'utf8mb4'`` to store 4-byte
                        characters.
                    autocommit: Autocommit mode passed to ``pymysql.connect``.
                """
                self.host = host
                self.port = port
                self.user = user
                self.password = password
                self.database = database
                self.charset = charset
                self.autocommit = autocommit

                # thread ident -> connection created by (and for) that thread
                self.connection_infos = {}
                # Guards reads/writes of connection_infos across threads.
                self.lock = threading.Lock()

            def __del__(self):
                """Close every connection this wrapper opened, best effort."""
                # If __init__ raised before this attribute was assigned, there
                # is nothing to close (and it must not raise AttributeError).
                connection_infos = getattr(self, 'connection_infos', None)
                if not connection_infos:
                    return

                for connection in list(connection_infos.values()):
                    try:
                        connection.close()
                    except pymysql.Error:
                        # pymysql raises Error("Already closed") for a
                        # connection that was closed elsewhere; nothing to do.
                        pass

                connection_infos.clear()

            def get_connection(self):
                """Return the calling thread's connection, creating it on first use.

                ``ping()`` is called on the connection before it is returned
                (per the pymysql docs, ``ping`` reconnects by default when the
                link was dropped).

                Returns:
                    The pymysql connection owned by the current thread.
                """
                ident = threading.get_ident()

                with self.lock:
                    connection = self.connection_infos.get(ident)

                if connection is None:
                    # Only this thread ever creates the entry for its own ident,
                    # so the slow connect can run without holding the lock.
                    connection = pymysql.connect(
                        host=self.host,
                        port=self.port,
                        user=self.user,
                        password=self.password,
                        database=self.database,
                        charset=self.charset,
                        autocommit=self.autocommit)
                    with self.lock:
                        self.connection_infos[ident] = connection

                # Network round trip on a connection no other thread uses; done
                # outside the lock so threads do not serialize on each other.
                connection.ping()

                return connection

            def get_cursor(self, cursor=None):
                """Return a new cursor on the calling thread's connection.

                Args:
                    cursor: Optional cursor class forwarded to the connection's
                        ``cursor()``; None uses the connection's default class.
                """
                return self.get_connection().cursor(cursor)


        class TestMySQL(unittest.TestCase):
            """Integration tests for ``MySQL`` that run against a live server.

            The tests connect to 127.0.0.1:3306 as ``root``/``root``, create a
            scratch database named ``temp_<unix time>`` holding a ``temp``
            table, and drop that database when the class finishes.
            """

            @classmethod
            def setUpClass(cls):
                """Create the scratch database and an empty test table."""
                cls.database_name = 'temp_' + str(int(time.time()))
                cls.table_name = 'temp'

                cls.mysql = MySQL(host='127.0.0.1',
                                  port=3306,
                                  user='root',
                                  password='root')

                with cls.mysql.get_cursor() as cursor:
                    cursor.execute('CREATE DATABASE IF NOT EXISTS ' +
                                   cls.database_name + ';')

                # The wrapper has no default database, so select it on this
                # (main) thread's connection; worker threads do the same.
                cls.mysql.get_connection().select_db(cls.database_name)

                with cls.mysql.get_cursor() as cursor:
                    cursor.execute('DROP TABLE IF EXISTS ' + cls.table_name + ';')
                    cursor.execute('CREATE TABLE IF NOT EXISTS ' + cls.table_name +
                                   '(id INT(10), name VARCHAR(50));')

            @classmethod
            def tearDownClass(cls):
                """Drop the scratch database created in ``setUpClass``."""
                with cls.mysql.get_cursor() as cursor:
                    cursor.execute('DROP DATABASE IF EXISTS ' +
                                   cls.database_name + ';')

            def setUp(self):
                """Open a cursor on the main thread's connection for the test."""
                self.cursor = self.mysql.get_cursor()

            def tearDown(self):
                """Remove every row the test wrote and close its cursor."""
                try:
                    # If a transaction test failed between begin() and
                    # commit()/rollback(), its transaction is still open and the
                    # DELETE below would run inside it. Rolling back first ends
                    # any such transaction; with none open it has no effect.
                    self.mysql.get_connection().rollback()
                    self.cursor.execute('DELETE FROM ' + self.table_name + ';')
                finally:
                    self.cursor.close()

            def worker(self, mysql, count):
                """Insert ``count`` rows using the calling thread's own connection.

                Meant to run in a worker thread: it first selects the test
                database on that thread's connection.
                """
                mysql.get_connection().select_db(self.database_name)
                query = 'INSERT INTO ' + self.table_name + '(id, name) VALUES(1, \'a\');'

                with mysql.get_cursor() as cursor:
                    for _ in range(count):
                        self.assertEqual(cursor.execute(query), 1)

            def test_multi_thread(self):
                """Rows inserted concurrently from many threads are all stored."""
                count = 100
                thread_count = 100
                errors = []

                def run():
                    # Exceptions (including assertion failures) raised in a
                    # thread never reach unittest; collect them so the main
                    # thread can report them.
                    try:
                        self.worker(self.mysql, count)
                    except Exception as error:
                        errors.append(error)

                threads = [threading.Thread(target=run) for _ in range(thread_count)]
                for t in threads:
                    t.start()
                for t in threads:
                    t.join()

                self.assertEqual(errors, [])

                query = 'SELECT * FROM ' + self.table_name + ' ORDER BY id ASC;'
                self.assertEqual(self.cursor.execute(query), thread_count * count)

            def test_insert(self):
                """Literal, parameterized and executemany inserts report row counts."""
                query = 'INSERT INTO ' + self.table_name + '(id, name) VALUES(1, \'a\');'
                self.assertEqual(self.cursor.execute(query), 1)

                query = 'INSERT INTO ' + self.table_name + '(id, name) VALUES(2, \'b\'), (3, \'c\');'
                self.assertEqual(self.cursor.execute(query), 2)

                query = 'INSERT INTO ' + self.table_name + '(id, name) VALUES(%s, %s);'
                args = (4, 'd')
                self.assertEqual(
                    self.cursor.execute(query, args),
                    1,
                )

                query = 'INSERT INTO ' + self.table_name + '(id, name) VALUES(%s, %s);'
                args = [[5, 'e'], [6, 'f'], [7, 'g']]
                self.assertEqual(self.cursor.executemany(query, args), 3)

            def test_select(self):
                """Rows written by ``test_insert`` read back via fetchone/many/all."""
                self.test_insert()

                query = 'SELECT * FROM ' + self.table_name + ' ORDER BY id ASC;'

                self.assertEqual(self.cursor.execute(query), 7)

                self.assertEqual(self.cursor.description[0][0], 'id')
                self.assertEqual(self.cursor.description[1][0], 'name')
                self.assertEqual(self.cursor.rowcount, 7)

                row_id, name = self.cursor.fetchone()
                self.assertEqual(row_id, 1)
                self.assertEqual(name, 'a')

                rows = {2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g'}
                for row_id, name in self.cursor.fetchmany(2):
                    self.assertEqual(name, rows[row_id])

                for row_id, name in self.cursor.fetchall():
                    self.assertEqual(name, rows[row_id])

            def test_transaction_commit(self):
                """A second connection sees rows from an explicit transaction only after commit."""
                # Kept in a local so its connection stays open for the whole test.
                mysql_another = MySQL(host='127.0.0.1',
                                      port=3306,
                                      user='root',
                                      password='root',
                                      database=self.database_name)
                cursor_another = mysql_another.get_cursor()

                self.mysql.get_connection().begin()

                self.assertEqual(
                    self.cursor.execute('INSERT INTO ' + self.table_name +
                                        '(id, name) VALUES(1, \'a\');'), 1)

                self.assertEqual(
                    self.cursor.execute('SELECT * FROM ' + self.table_name + ';'), 1)
                self.assertEqual(self.cursor.rowcount, 1)

                self.assertEqual(
                    cursor_another.execute('SELECT * FROM ' + self.table_name + ';'),
                    0)
                self.assertEqual(cursor_another.rowcount, 0)

                self.mysql.get_connection().commit()

                self.assertEqual(
                    self.cursor.execute('SELECT * FROM ' + self.table_name + ';'), 1)
                self.assertEqual(self.cursor.rowcount, 1)

                self.assertEqual(
                    cursor_another.execute('SELECT * FROM ' + self.table_name + ';'),
                    1)
                self.assertEqual(cursor_another.rowcount, 1)

                # After commit the connection is back in autocommit mode.
                self.assertEqual(
                    self.cursor.execute('INSERT INTO ' + self.table_name +
                                        '(id, name) VALUES(2, \'b\');'), 1)

                self.assertEqual(
                    self.cursor.execute('SELECT * FROM ' + self.table_name + ';'), 2)
                self.assertEqual(self.cursor.rowcount, 2)

                self.assertEqual(
                    cursor_another.execute('SELECT * FROM ' + self.table_name + ';'),
                    2)
                self.assertEqual(cursor_another.rowcount, 2)

            def test_transaction_rollback(self):
                """Rows from a rolled-back transaction vanish and were never visible elsewhere."""
                # Kept in a local so its connection stays open for the whole test.
                mysql_another = MySQL(host='127.0.0.1',
                                      port=3306,
                                      user='root',
                                      password='root',
                                      database=self.database_name)
                cursor_another = mysql_another.get_cursor()

                self.mysql.get_connection().begin()

                self.assertEqual(
                    self.cursor.execute('INSERT INTO ' + self.table_name +
                                        '(id, name) VALUES(1, \'a\');'), 1)

                self.assertEqual(
                    self.cursor.execute('SELECT * FROM ' + self.table_name + ';'), 1)
                self.assertEqual(self.cursor.rowcount, 1)

                self.assertEqual(
                    cursor_another.execute('SELECT * FROM ' + self.table_name + ';'),
                    0)
                self.assertEqual(cursor_another.rowcount, 0)

                self.mysql.get_connection().rollback()

                self.assertEqual(
                    self.cursor.execute('SELECT * FROM ' + self.table_name + ';'), 0)
                self.assertEqual(self.cursor.rowcount, 0)

                self.assertEqual(
                    cursor_another.execute('SELECT * FROM ' + self.table_name + ';'),
                    0)
                self.assertEqual(cursor_another.rowcount, 0)

                # After rollback the connection is back in autocommit mode.
                self.assertEqual(
                    self.cursor.execute('INSERT INTO ' + self.table_name +
                                        '(id, name) VALUES(2, \'b\');'), 1)

                self.assertEqual(
                    self.cursor.execute('SELECT * FROM ' + self.table_name + ';'), 1)
                self.assertEqual(self.cursor.rowcount, 1)

                self.assertEqual(
                    cursor_another.execute('SELECT * FROM ' + self.table_name + ';'),
                    1)
                self.assertEqual(cursor_another.rowcount, 1)

            def test_mogrify(self):
                """``mogrify`` returns the SQL string with arguments interpolated."""
                query = 'SELECT * FROM test ORDER BY id ASC;'
                self.assertEqual(self.cursor.mogrify(query), query)

                query = 'INSERT INTO test(id, name) VALUES(%s, %s);'
                args = (1, 'a')
                self.assertEqual(self.cursor.mogrify(query, args),
                                 'INSERT INTO test(id, name) VALUES(1, \'a\');')


        def _main():
            """Script entry point: only prints a marker line.

            The test cases are run with ``python -m unittest`` instead.
            """
            print('main call')


        if __name__ == '__main__':
            _main()
  • 실행 결과
    • python ./main.py

        main call
    • python -m unittest ./main.py

        ......
        ----------------------------------------------------------------------
        Ran 6 tests in 1.659s

        OK