今天来聊一下池化技术中的 MYSQL 连接池,首先了解一下实现的环境和库:

  • Ubuntu 22.04
  • libmysqlclient-dev

我个人是比较喜欢 Ubuntu 的,不过用什么发行版不冲突,有三方库就好了。

封装

查询结果类

我们首先将查询到的结果抽象出来,应该是一个二维数据,我们可以将其封装为一个类来获取 key 和对应的 value 值。

先看头文件,经过我的讲解之后再看源文件:

#ifndef RESULTSET_H
#define RESULTSET_H

#include <iostream>
#include <mutex>
#include <mysql/mysql.h>
#include <map>
#include <unistd.h>
#include <vector>

class CResultSet {
public:
    CResultSet(MYSQL_RES* res); // 用一个查询结果构造结果类
    int getInt(const char* key); // 获取int类型的结果
    std::string getString(const char* key); // 获取char*类型的结果
    bool Next(); // 下一行

private:
    int getIndex(const char* key); // 获取字段索引编号

private:
    MYSQL_RES* res_; // 查询结果
    MYSQL_ROW row_; // 当前行
    std::map<std::string, int> key_map_; // 字段索引
};

#endif

源文件解析

CResultSet::CResultSet(MYSQL_RES* res) {
    res_ = res;

    int num_fields = mysql_num_fields(res_); // 字段数量
    MYSQL_FIELD* fields = mysql_fetch_fields(res_); // 字段名
    for (int i = 0; i < num_fields; i++) {
        key_map_.insert(std::make_pair(fields[i].name, i)); // 字段名:编号 配对存储在索引中
        LogDebug("num_fields fields[{}].name: {}", i, fields[i].name);
    }
}
```cpp
CResultSet::~CResultSet() {
    if (res_) {
        mysql_free_result(res_); // 释放查询结果
        res_ = NULL;
    }
}
```cpp
int CResultSet::getInt(const char* key) {
    int idx = getIndex(key); // 获取字段的编号
    if (idx == -1) {
        return -1;
    }
    else {
        return atoi(row_[idx]); // 根据编号获取行中对应字段的值
    }
}
```cpp
// 头同上
char* CResultSet::getString(const char* key) {
    int idx = getIndex(key);
    if (idx == -1) {
        return NULL;
    }
    else {
        return row_[idx];
    }
}
```cpp
bool CResultSet::Next() {
    row_ = mysql_fetch_row(res_); // 有下一行就返回下一行
    if (row_) {
        return true;
    }
    else {
        return false;
    }
}
```cpp
int CResultSet::getIndex(const char* key) {
    std::map<std::string, int>::iterator ite = key_map_.find(key); // 迭代器

    if (ite == key_map_.end()) { // 没有找到
        return -1;
    }
    else {
        return ite->second; // 返回编号
    }
}

数据准备类

#ifndef PREPARESTATEMENT_H
#define PREPARESTATEMENT_H

#include <cstdint>
#include <mysql/mysql.h>
#include <iostream>
#include <string>

class CPrepareStatement {

public:
    CPrepareStatement();
    virtual ~CPrepareStatement();
    bool Init(MYSQL* mysql, std::string& sql); // 初始化
    void setParam(uint32_t index, int& value); // 设置键值
    void setParam(uint32_t index, uint32_t& value);
    void setParam(uint32_t index, std::string& value);
    void setParam(uint32_t index, const std::string& value);

    bool ExecuteUpdate(); // 执行SQL语句
    int getInsertId(); // 获取插入操作返回的唯一ID,确保唯一性
private:
    MYSQL_STMT* stmt_;[ ]()// SQL预处理模板
    MYSQL_BIND* param_bind_; // 参数绑定
    uint32_t param_cnt_; // 参数数量
};

#endif

源文件

#include "preparestatement.h"
#include "DLog.h"
#include "field_types.h"
#include "mysql.h"
#include <cstdint>
#include <string>

CPrepareStatement::CPrepareStatement() {
    stmt_ = NULL;
    param_bind_ = NULL;
    param_cnt_ = 0;
}

CPrepareStatement::~CPrepareStatement() {
    if (stmt_) {
        mysql_stmt_close(stmt_); // 关闭预处理模板
        stmt_ = NULL;
    }

    if (param_bind_) {
        delete[] param_bind_; // 释放绑定测参数
        param_bind_ = NULL;
    }
}

bool CPrepareStatement::Init(MYSQL* mysql, std::string& sql) {
    mysql_ping(mysql); // ping一下 保持连接畅通

    stmt_ = mysql_stmt_init(mysql);
    if (!stmt_) {
        LogErr("mysql stmt init failed!");
        return false;
    }
    // 准备模板的SQL语句
    if (mysql_stmt_prepare(stmt_, sql.c_str(), sql.size())) {
        LogErr("mysql stmt prepare failed: {}!", mysql_stmt_errno(stmt_));
        return false;
    }
    // 参数统计
    param_cnt_ = mysql_stmt_param_count(stmt_);
    if (param_cnt_ > 0) {
        param_bind_ = new MYSQL_BIND[param_cnt_];
        if (!param_bind_) {
            LogErr("new failed!");
            return false;
        }
        memset(param_bind_, 0, param_cnt_ * sizeof(MYSQL_BIND));
    }

    return true;
}

void CPrepareStatement::setParam(uint32_t index, int& value) {

    if (index >= param_cnt_) {
        LogErr("index too large long: {}!", index);
        return;
    }

    param_bind_[index].buffer_type = MYSQL_TYPE_LONG; // 类型
    param_bind_[index].buffer = &value; // 参数
}

void CPrepareStatement::setParam(uint32_t index, uint32_t& value) {

    if (index >= param_cnt_) {
        LogErr("index too large long: {}!", index);
        return;
    }

    param_bind_[index].buffer_type = MYSQL_TYPE_LONG;
    param_bind_[index].buffer = &value;
}

void CPrepareStatement::setParam(uint32_t index, std::string& value) {

    if (index >= param_cnt_) {
        LogErr("index too large long: {}!", index);
        return;
    }

    param_bind_[index].buffer_type = MYSQL_TYPE_STRING;
    param_bind_[index].buffer = (char*)value.c_str();
    param_bind_[index].buffer_length = value.size(); // 字符串长度
}

void CPrepareStatement::setParam(uint32_t index, const std::string& value) {

    if (index >= param_cnt_) {
        LogErr("index too large long: {}", index);
        return;
    }

    param_bind_[index].buffer_type = MYSQL_TYPE_STRING;
    param_bind_[index].buffer = (char*)value.c_str();
    param_bind_[index].buffer_length = value.size();
}

bool CPrepareStatement::ExecuteUpdate() {

    if (!stmt_) {
        return false;
    }

    if (mysql_stmt_bind_param(stmt_, param_bind_)) {
        LogErr("stmt bind param failed: {}!", mysql_stmt_error(stmt_));
        return false;
    }

    if (mysql_stmt_execute(stmt_)) {
        LogErr("stmt execute failed: {}!", mysql_stmt_error(stmt_));
        return false;
    }

    if (mysql_stmt_affected_rows(stmt_) == 0) {
        LogErr("stmt affect rows failed: {}!", mysql_stmt_error(stmt_));
        return false;
    }

    return true;
}

int CPrepareStatement::getInsertId() {
    return mysql_stmt_insert_id(stmt_); // 获取插入操作的ID
}