#include <cstdlib>
#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <thread>

#include "cppconn/driver.h"
#include "cppconn/prepared_statement.h"
using namespace std;
/// Wrapper around MySQL Connector/C++ that lazily opens one sql::Connection per
/// calling thread (keyed by this_thread::get_id()) and caches it.
///
/// Entries are removed from the cache only when a setter changes a credential
/// or the object is destroyed; nothing removes the entry of a thread that has
/// exited. Credentials and the cache are guarded by mutexForInfo.
class MySQL {
public:
  MySQL(const string &hostName, const string &userName,
        const string &password);
  ~MySQL();

  // Holds a mutex and live connections: not copyable (and, with a
  // user-declared destructor and copy constructor, not movable either).
  MySQL(const MySQL &) = delete;
  MySQL &operator=(const MySQL &) = delete;

  /// Returns the calling thread's connection, opening it on first use.
  shared_ptr<sql::Connection> GetConnection();
  /// Creates a plain statement on the calling thread's connection.
  unique_ptr<sql::Statement> GetStatement();
  /// Prepares `sql` on the calling thread's connection.
  unique_ptr<sql::PreparedStatement>
  GetPreparedStatement(const string &sql);

  /// Each setter is a no-op when the value is unchanged; otherwise it stores
  /// the new value and drops every cached connection so later calls reconnect.
  void SetHostName(const string &hostName);
  void SetUserName(const string &userName);
  void SetPassword(const string &password);

private:
  string hostName;
  string userName;
  string password;
  mutex mutexForInfo;
  map<thread::id, shared_ptr<sql::Connection>> connectionInfos;
};
MySQL::MySQL(const string &hostName, const string &userName,
const string &password)
: hostName(hostName), userName(userName), password(password),
connectionInfos({}) {}
// Destroying connectionInfos releases this object's reference to every cached
// connection, so an explicit clear() in the body is unnecessary.
MySQL::~MySQL() = default;
/// Returns the calling thread's connection. On the thread's first call it
/// opens one with the current credentials and caches it, keyed by
/// this_thread::get_id().
///
/// @throws sql::SQLException (including derived types, unsliced) from the driver
///         if connecting fails. Nothing is cached in that case, so a later
///         call retries.
///
/// The mutex is held across the connect call, so first-time connections from
/// different threads are serialized.
shared_ptr<sql::Connection> MySQL::GetConnection() {
  lock_guard<mutex> lock(this->mutexForInfo);
  const auto threadId = this_thread::get_id();

  const auto found = this->connectionInfos.find(threadId);
  if (found != this->connectionInfos.end()) {
    return found->second;
  }

  // Connect BEFORE touching the map. The original wrote
  // `connectionInfos[id].reset(connect(...))`, which default-inserts an
  // empty entry before connect() runs. If connect() threw, a null
  // connection stayed cached and every later call returned it.
  shared_ptr<sql::Connection> connection(get_driver_instance()->connect(
      this->hostName, this->userName, this->password));
  this->connectionInfos.emplace(threadId, connection);
  return connection;
}
unique_ptr<sql::Statement> MySQL::GetStatement() {
return unique_ptr<sql::Statement>(this->GetConnection()->createStatement());
}
/// Prepares `sql` on the calling thread's cached connection.
/// @throws sql::SQLException if connecting or preparing fails.
/// NOTE(review): like GetStatement(), the result does not hold a reference to
/// its connection; a concurrent setter call may destroy it underneath.
unique_ptr<sql::PreparedStatement>
MySQL::GetPreparedStatement(const string &sql) {
  // sql::Connection's factory method is prepareStatement(). The original called
  // a non-existent preparedStatement(), which does not compile.
  return unique_ptr<sql::PreparedStatement>(
      this->GetConnection()->prepareStatement(sql));
}
/// Replaces the host. If it actually changes, every cached per-thread
/// connection is dropped so the next GetConnection() reconnects to the new
/// host. Assigning the current value is a no-op.
void MySQL::SetHostName(const string &hostName) {
  lock_guard<mutex> guard(this->mutexForInfo);
  if (this->hostName != hostName) {
    this->hostName = hostName;
    this->connectionInfos.clear();
  }
}
/// Replaces the user name. If it actually changes, every cached per-thread
/// connection is dropped so the next GetConnection() logs in as the new user.
/// Assigning the current value is a no-op.
void MySQL::SetUserName(const string &userName) {
  lock_guard<mutex> guard(this->mutexForInfo);
  if (this->userName != userName) {
    this->userName = userName;
    this->connectionInfos.clear();
  }
}
/// Replaces the password. If it actually changes, every cached per-thread
/// connection is dropped so the next GetConnection() authenticates with the
/// new password. Assigning the current value is a no-op.
void MySQL::SetPassword(const string &password) {
  lock_guard<mutex> guard(this->mutexForInfo);
  if (this->password != password) {
    this->password = password;
    this->connectionInfos.clear();
  }
}
// Connection settings for the local test server, plus the scratch schema that
// initialize() creates and finalize() drops. tableName is qualified with
// databaseName, so statements can reference it without selecting a default
// schema.
const string hostName = "tcp://127.0.0.1:3306";
const string userName = "root";
const string password = "root";
const string databaseName = "test";
const string tableName = databaseName + ".temp";
// Wrapper shared by every test below; GetConnection() gives each thread its
// own cached connection.
MySQL mysql{hostName, userName, password};
/// (Re)creates the scratch schema. It drops databaseName if present, creates
/// it again, then creates tableName with an `id` and a `name` column.
/// @return false (after printing the driver's message) if any statement throws.
bool initialize() {
  cout << "----- " << __func__ << " start -----" << endl;
  const string setupSql[] = {
      string("DROP DATABASE IF EXISTS ") + databaseName + ";",
      string("CREATE DATABASE IF NOT EXISTS ") + databaseName + ";",
      string("CREATE TABLE ") + tableName + "(id INT(10), name VARCHAR(50));",
  };
  try {
    const auto statement = mysql.GetStatement();
    for (const string &sqlText : setupSql) {
      statement->execute(sqlText);
    }
  } catch (sql::SQLException &e) {
    cout << e.what() << endl;
    return false;
  }
  cout << "----- " << __func__ << " end -----" << endl << endl;
  return true;
}
/// Drops the scratch database that initialize() created.
/// @return false (after printing the driver's message) on failure.
bool finalize() {
  cout << "----- " << __func__ << " start -----" << endl;
  const string dropDatabase =
      string("DROP DATABASE IF EXISTS ") + databaseName + ";";
  try {
    const auto statement = mysql.GetStatement();
    statement->execute(dropDatabase);
  } catch (sql::SQLException &e) {
    cout << e.what() << endl;
    return false;
  }
  cout << "----- " << __func__ << " end -----" << endl << endl;
  return true;
}
/// Exercises execute() in two passes: literal SQL through a plain Statement,
/// then bound parameters through a PreparedStatement. Each pass inserts rows
/// (1, 'a') and (2, 'b') and then deletes every row.
/// @return false (after printing the driver's message) if any step throws.
bool execute_test() {
  cout << "----- " << __func__ << " start -----" << endl;
  const string insertInto = string("INSERT INTO ") + tableName;
  const string deleteAll = string("DELETE FROM ") + tableName + ";";

  const auto viaStatement = [&insertInto, &deleteAll]() {
    const auto statement = mysql.GetStatement();
    statement->execute(insertInto + "(id, name) VALUES(1, 'a');");
    statement->execute(insertInto + "(id, name) VALUES(2, 'b');");
    statement->execute(deleteAll);
  };

  const auto viaPreparedStatement = [&insertInto, &deleteAll]() {
    {
      // One prepared INSERT, executed once per (id, name) binding.
      const auto insert =
          mysql.GetPreparedStatement(insertInto + "(id, name) VALUES(?, ?);");
      insert->setInt(1, 1);
      insert->setString(2, "a");
      insert->execute();
      insert->setInt(1, 2);
      insert->setString(2, "b");
      insert->execute();
    }
    // The INSERT statement above has already been destroyed at this point.
    mysql.GetPreparedStatement(deleteAll)->execute();
  };

  try {
    viaStatement();
    viaPreparedStatement();
  } catch (sql::SQLException &e) {
    cout << e.what() << endl;
    return false;
  }
  cout << "----- " << __func__ << " end -----" << endl << endl;
  return true;
}
bool executeQuery_test() {
cout << "----- " << __func__ << " start -----" << endl;
auto insert_data = []() {
auto statement = mysql.GetStatement();
statement->execute(string("INSERT INTO ") + tableName +
"(id, name) VALUES(1, 'a');");
statement->execute(string("INSERT INTO ") + tableName +
"(id, name) VALUES(2, 'b');");
cout << "~~~ : (" << statement->getUpdateCount() << ")" << endl;
};
auto statementTest = []() {
auto result =
unique_ptr<sql::ResultSet>(mysql.GetStatement()->executeQuery(
string("SELECT * FROM ") + tableName + " ORDER BY id ASC;"));
while (result->next()) {
cout << result->getInt(1) << ", " << result->getString(2) << endl;
}
};
auto preparedStatementTest = []() {
auto preparedStatement = mysql.GetPreparedStatement(
string("SELECT * FROM ") + tableName + ";");
auto result =
unique_ptr<sql::ResultSet>(preparedStatement->executeQuery());
while (result->next()) {
cout << result->getInt(1) << ", " << result->getString(2) << endl;
}
};
auto delete_data = []() {
mysql.GetStatement()->execute(string("DELETE FROM ") + tableName + ";");
};
try {
insert_data();
statementTest();
preparedStatementTest();
delete_data();
} catch (sql::SQLException &e) {
cout << e.what() << endl;
return false;
}
cout << "----- " << __func__ << " end -----" << endl << endl;
return true;
}
/// Exercises executeUpdate() and prints the value each call returns. It inserts
/// ids 1 and 2, renames id 1, renames every row, renames id 1 again via a
/// PreparedStatement, and finally deletes everything.
/// @return false (after printing the driver's message) if any step throws.
bool executeUpdate_test() {
  cout << "----- " << __func__ << " start -----" << endl;

  const auto seedRows = []() {
    const auto statement = mysql.GetStatement();
    const char *const rows[] = {"(id, name) VALUES(1, 'a');",
                                "(id, name) VALUES(2, 'b');"};
    for (const char *row : rows) {
      statement->execute(string("INSERT INTO ") + tableName + row);
    }
  };

  const auto updateViaStatement = []() {
    const auto statement = mysql.GetStatement();
    const string renameFirst =
        string("UPDATE ") + tableName + " SET name='c' WHERE id=1;";
    cout << statement->executeUpdate(renameFirst) << endl;
    const string renameAll = string("UPDATE ") + tableName + " SET name='d';";
    cout << statement->executeUpdate(renameAll) << endl;
  };

  const auto updateViaPreparedStatement = []() {
    const auto prepared = mysql.GetPreparedStatement(
        string("UPDATE ") + tableName + " SET name=? WHERE id=?;");
    prepared->setString(1, "a");
    prepared->setInt(2, 1);
    cout << prepared->executeUpdate() << endl;
  };

  const auto removeAll = []() {
    const auto removed = mysql.GetStatement()->executeUpdate(
        string("DELETE FROM ") + tableName + ";");
    cout << removed << endl;
  };

  try {
    seedRows();
    updateViaStatement();
    updateViaPreparedStatement();
    removeAll();
  } catch (sql::SQLException &e) {
    cout << e.what() << endl;
    return false;
  }
  cout << "----- " << __func__ << " end -----" << endl << endl;
  return true;
}
bool transaction_test_1() {
cout << "----- " << __func__ << " start -----" << endl;
auto select = [](MySQL &mysql) {
auto result =
unique_ptr<sql::ResultSet>(mysql.GetStatement()->executeQuery(
string("SELECT * FROM ") + tableName + " ORDER BY id ASC;"));
while (result->next()) {
cout << "\t" << result->getInt(1) << ", " << result->getString(2)
<< endl;
}
};
try {
auto connection = mysql.GetConnection();
auto statement = mysql.GetStatement();
connection->setAutoCommit(false);
{
statement->execute(string("INSERT INTO ") + tableName +
"(id, name) VALUES(1, 'a');");
cout << "commit before start" << endl;
cout << "\t~~~ transaction connection ~~~" << endl;
select(mysql);
cout << "\t~~~ other connection ~~~" << endl;
MySQL temp(hostName, userName, password);
select(temp);
cout << "commit before end" << endl;
connection->commit();
cout << "commit after start" << endl;
cout << "\t~~~ transaction connection ~~~" << endl;
select(mysql);
cout << "\t~~~ other connection ~~~" << endl;
select(temp);
cout << "commit after end" << endl;
}
{
statement->execute(string("INSERT INTO ") + tableName +
"(id, name) VALUES(2, 'b');");
cout << "rollback before start" << endl;
cout << "\t~~~ transaction connection ~~~" << endl;
MySQL temp(hostName, userName, password);
select(mysql);
cout << "\t~~~ other connection ~~~" << endl;
select(temp);
cout << "rollback before end" << endl;
connection->rollback();
cout << "rollback after start" << endl;
cout << "\t~~~ transaction connection ~~~" << endl;
select(mysql);
cout << "\t~~~ other connection ~~~" << endl;
select(temp);
cout << "rollback after end" << endl;
}
connection->setAutoCommit(true);
statement->execute(string("DELETE FROM ") + tableName + ";");
} catch (sql::SQLException &e) {
cout << e.what() << endl;
return false;
}
cout << "----- " << __func__ << " end -----" << endl << endl;
return true;
}
bool transaction_test_2() {
cout << "----- " << __func__ << " start -----" << endl;
auto select = [](MySQL &mysql) {
auto result =
unique_ptr<sql::ResultSet>(mysql.GetStatement()->executeQuery(
string("SELECT * FROM ") + tableName + " ORDER BY id ASC;"));
while (result->next()) {
cout << "\t" << result->getInt(1) << ", " << result->getString(2)
<< endl;
}
};
try {
auto statement = mysql.GetStatement();
{
statement->execute("START TRANSACTION;");
statement->execute(string("INSERT INTO ") + tableName +
"(id, name) VALUES(1, 'a');");
cout << "commit before start" << endl;
cout << "\t~~~ transaction connection ~~~" << endl;
select(mysql);
cout << "\t~~~ other connection ~~~" << endl;
MySQL temp(hostName, userName, password);
select(temp);
cout << "commit before end" << endl;
statement->execute("COMMIT;");
cout << "commit after start" << endl;
cout << "\t~~~ transaction connection ~~~" << endl;
select(mysql);
cout << "\t~~~ other connection ~~~" << endl;
select(temp);
cout << "commit after end" << endl;
}
{
statement->execute("START TRANSACTION;");
statement->execute(string("INSERT INTO ") + tableName +
"(id, name) VALUES(2, 'b');");
cout << "rollback before start" << endl;
cout << "\t~~~ transaction connection ~~~" << endl;
MySQL temp(hostName, userName, password);
select(mysql);
cout << "\t~~~ other connection ~~~" << endl;
select(temp);
cout << "rollback before end" << endl;
statement->execute("ROLLBACK;");
cout << "rollback after start" << endl;
cout << "\t~~~ transaction connection ~~~" << endl;
select(mysql);
cout << "\t~~~ other connection ~~~" << endl;
select(temp);
cout << "rollback after end" << endl;
}
statement->execute(string("DELETE FROM ") + tableName + ";");
} catch (sql::SQLException &e) {
cout << e.what() << endl;
return false;
}
cout << "----- " << __func__ << " end -----" << endl << endl;
return true;
}
bool savepoint_test() {
cout << "----- " << __func__ << " start -----" << endl;
auto select = [](MySQL &mysql) {
auto result =
unique_ptr<sql::ResultSet>(mysql.GetStatement()->executeQuery(
string("SELECT * FROM ") + tableName + " ORDER BY id ASC;"));
while (result->next()) {
cout << "\t" << result->getInt(1) << ", " << result->getString(2)
<< endl;
}
};
try {
auto connection = mysql.GetConnection();
auto statement = mysql.GetStatement();
connection->setAutoCommit(false);
statement->execute(string("INSERT INTO ") + tableName +
"(id, name) VALUES(1, 'a');");
auto savePoint = unique_ptr<sql::Savepoint>(
connection->setSavepoint("save_point_1"));
statement->execute(string("INSERT INTO ") + tableName +
"(id, name) VALUES(2, 'b');");
cout << "rollback before start" << endl;
select(mysql);
cout << "rollback before end" << endl;
connection->rollback(savePoint.get());
connection->releaseSavepoint(savePoint.get());
cout << "rollback after start" << endl;
select(mysql);
cout << "rollback after end" << endl;
connection->commit();
connection->setAutoCommit(true);
} catch (sql::SQLException &e) {
cout << e.what() << endl;
return false;
}
cout << "----- " << __func__ << " end -----" << endl << endl;
return true;
}
/// Runs each step in order and stops at the first one that returns false,
/// exiting with EXIT_FAILURE. finalize() (which drops the scratch database) is
/// therefore skipped when an earlier step fails.
int main() {
  cout << "main start" << endl << endl;
  using Step = bool (*)();
  const Step steps[] = {initialize,         execute_test,
                        executeQuery_test,  executeUpdate_test,
                        transaction_test_1, transaction_test_2,
                        savepoint_test,     finalize};
  for (const Step step : steps) {
    if (!step()) {
      return EXIT_FAILURE;
    }
  }
  cout << "main end" << endl;
  return EXIT_SUCCESS;
}