Impl GetDbNames, GetDbTalbes and ExecDbQuery

This commit is contained in:
Changhua 2023-02-17 22:06:40 +08:00
parent 33a5ed4033
commit 95ce4578bf
7 changed files with 277 additions and 56 deletions

View File

@ -2,6 +2,7 @@
#include <map>
#include <string>
#include <vector>
using namespace std;
@ -16,3 +17,19 @@ typedef struct {
string province;
string city;
} RpcContact_t;
typedef vector<string> DbNames_t;
typedef struct {
string name;
string sql;
} DbTable_t;
typedef vector<DbTable_t> DbTables_t;
typedef struct {
int32_t type;
string column;
vector<uint8_t> content;
} DbField_t;
typedef vector<DbField_t> DbRow_t;
typedef vector<DbRow_t> DbRows_t;

View File

@ -23,6 +23,7 @@ bool encode_string(pb_ostream_t *stream, const pb_field_t *field, void *const *a
const char *str = (const char *)*arg;
if (!pb_encode_tag_for_field(stream, field)) {
LOG_ERROR("Encoding failed: {}", PB_GET_ERROR(stream));
return false;
}
@ -35,11 +36,24 @@ bool decode_string(pb_istream_t *stream, const pb_field_t *field, void **arg)
size_t len = stream->bytes_left;
str->resize(len);
if (!pb_read(stream, (uint8_t *)str->data(), len)) {
LOG_ERROR("Decoding failed: {}", PB_GET_ERROR(stream));
return false;
}
return true;
}
bool encode_bytes(pb_ostream_t *stream, const pb_field_t *field, void *const *arg)
{
vector<uint8_t> *v = (vector<uint8_t> *)*arg;
if (!pb_encode_tag_for_field(stream, field)) {
LOG_ERROR("Encoding failed: {}", PB_GET_ERROR(stream));
return false;
}
return pb_encode_string(stream, (uint8_t *)v->data(), v->size());
}
bool encode_types(pb_ostream_t *stream, const pb_field_t *field, void *const *arg)
{
MsgTypes_t *m = (MsgTypes_t *)*arg;
@ -51,10 +65,12 @@ bool encode_types(pb_ostream_t *stream, const pb_field_t *field, void *const *ar
message.value.arg = (void *)it->second.c_str();
if (!pb_encode_tag_for_field(stream, field)) {
LOG_ERROR("Encoding failed: {}", PB_GET_ERROR(stream));
return false;
}
if (!pb_encode_submessage(stream, MsgTypes_TypesEntry_fields, &message)) {
LOG_ERROR("Encoding failed: {}", PB_GET_ERROR(stream));
return false;
}
}
@ -90,10 +106,112 @@ bool encode_contacts(pb_ostream_t *stream, const pb_field_t *field, void *const
message.gender = (*it).gender;
if (!pb_encode_tag_for_field(stream, field)) {
LOG_ERROR("Encoding failed: {}", PB_GET_ERROR(stream));
return false;
}
if (!pb_encode_submessage(stream, RpcContact_fields, &message)) {
LOG_ERROR("Encoding failed: {}", PB_GET_ERROR(stream));
return false;
}
}
return true;
}
bool encode_dbnames(pb_ostream_t *stream, const pb_field_t *field, void *const *arg)
{
vector<string> *v = (vector<string> *)*arg;
DbNames message = DbNames_init_default;
for (auto it = v->begin(); it != v->end(); it++) {
message.names.funcs.encode = &encode_string;
message.names.arg = (void *)(*it).c_str();
if (!pb_encode_tag_for_field(stream, field)) {
LOG_ERROR("Encoding failed: {}", PB_GET_ERROR(stream));
return false;
}
if (!pb_encode_submessage(stream, DbNames_fields, &message)) {
LOG_ERROR("Encoding failed: {}", PB_GET_ERROR(stream));
return false;
}
}
return true;
}
bool encode_tables(pb_ostream_t *stream, const pb_field_t *field, void *const *arg)
{
DbTables_t *v = (DbTables_t *)*arg;
DbTable message = DbTable_init_default;
for (auto it = v->begin(); it != v->end(); it++) {
message.name.funcs.encode = &encode_string;
message.name.arg = (void *)(*it).name.c_str();
message.sql.funcs.encode = &encode_string;
message.sql.arg = (void *)(*it).sql.c_str();
if (!pb_encode_tag_for_field(stream, field)) {
LOG_ERROR("Encoding failed: {}", PB_GET_ERROR(stream));
return false;
}
if (!pb_encode_submessage(stream, DbTable_fields, &message)) {
LOG_ERROR("Encoding failed: {}", PB_GET_ERROR(stream));
return false;
}
}
return true;
}
static bool encode_fields(pb_ostream_t *stream, const pb_field_t *field, void *const *arg)
{
DbRow_t *v = (DbRow_t *)*arg;
DbField message = DbField_init_default;
for (auto it = v->begin(); it != v->end(); it++) {
message.type = (*it).type;
message.column.arg = (void *)(*it).column.c_str();
message.column.funcs.encode = &encode_string;
message.content.arg = (void *)&(*it).content;
message.content.funcs.encode = &encode_bytes;
if (!pb_encode_tag_for_field(stream, field)) {
LOG_ERROR("Encoding failed: {}", PB_GET_ERROR(stream));
return false;
}
if (!pb_encode_submessage(stream, DbField_fields, &message)) {
LOG_ERROR("Encoding failed: {}", PB_GET_ERROR(stream));
return false;
}
}
return true;
}
bool encode_rows(pb_ostream_t *stream, const pb_field_t *field, void *const *arg)
{
DbRows_t *v = (DbRows_t *)*arg;
DbRow message = DbRow_init_default;
for (auto it = v->begin(); it != v->end(); it++) {
message.fields.arg = (void *)&(*it);
message.fields.funcs.encode = &encode_fields;
if (!pb_encode_tag_for_field(stream, field)) {
LOG_ERROR("Encoding failed: {}", PB_GET_ERROR(stream));
return false;
}
if (!pb_encode_submessage(stream, DbRow_fields, &message)) {
LOG_ERROR("Encoding failed: {}", PB_GET_ERROR(stream));
return false;
}
}

View File

@ -8,3 +8,6 @@ bool encode_string(pb_ostream_t *stream, const pb_field_t *field, void *const *a
bool decode_string(pb_istream_t *stream, const pb_field_t *field, void **arg);
bool encode_types(pb_ostream_t *stream, const pb_field_t *field, void *const *arg);
bool encode_contacts(pb_ostream_t *stream, const pb_field_t *field, void *const *arg);
bool encode_dbnames(pb_ostream_t *stream, const pb_field_t *field, void *const *arg);
bool encode_tables(pb_ostream_t *stream, const pb_field_t *field, void *const *arg);
bool encode_rows(pb_ostream_t *stream, const pb_field_t *field, void *const *arg);

View File

@ -3,3 +3,7 @@
* fallback_type:FT_POINTER
MsgTypes* fallback_type:FT_CALLBACK
RpcContact* fallback_type:FT_CALLBACK
DbNames* fallback_type:FT_CALLBACK
DbTable* fallback_type:FT_CALLBACK
DbField* fallback_type:FT_CALLBACK
DbRow* fallback_type:FT_CALLBACK

View File

@ -1,13 +1,9 @@
#include <algorithm>
#include <map>
#include <string>
#if 0
#include <iterator>
#include "exec_sql.h"
#include "load_calls.h"
#include "util.h"
using namespace std;
#define SQLITE_OK 0 /* Successful result */
#define SQLITE_ERROR 1 /* Generic error */
#define SQLITE_INTERNAL 2 /* Internal logic error in SQLite */
@ -71,22 +67,6 @@ typedef const void *(__cdecl *Sqlite3_column_blob)(DWORD *, int);
typedef int(__cdecl *Sqlite3_column_bytes)(DWORD *, int);
typedef int(__cdecl *Sqlite3_finalize)(DWORD *);
static int cbGetTables(void *ret, int argc, char **argv, char **azColName)
{
wcf::DbTables *tbls = (wcf::DbTables *)ret;
wcf::DbTable *tbl = tbls->add_tables();
for (int i = 0; i < argc; i++) {
if (strcmp(azColName[i], "name") == 0) {
tbl->set_name(argv[i] ? argv[i] : "");
} else if (strcmp(azColName[i], "sql") == 0) {
string sql(argv[i]);
sql.erase(std::remove(sql.begin(), sql.end(), '\t'), sql.end());
tbl->set_sql(sql.c_str());
}
}
return 0;
}
dbMap_t GetDbHandles()
{
if (!dbMap.empty())
@ -109,37 +89,60 @@ dbMap_t GetDbHandles()
return dbMap;
}
void GetDbNames(wcf::DbNames *names)
DbNames_t GetDbNames()
{
DbNames_t names;
if (dbMap.empty()) {
dbMap = GetDbHandles();
}
for (auto &[k, v] : dbMap) {
auto *name = names->add_names();
name->assign(k);
names.push_back(k);
}
return names;
}
void GetDbTables(const string db, wcf::DbTables *tables)
static int cbGetTables(void *ret, int argc, char **argv, char **azColName)
{
DbTables_t *tbls = (DbTables_t *)ret;
DbTable_t tbl;
for (int i = 0; i < argc; i++) {
if (strcmp(azColName[i], "name") == 0) {
tbl.name = argv[i] ? argv[i] : "";
} else if (strcmp(azColName[i], "sql") == 0) {
string sql(argv[i]);
sql.erase(std::remove(sql.begin(), sql.end(), '\t'), sql.end());
tbl.sql = sql.c_str();
}
}
tbls->push_back(tbl);
return 0;
}
DbTables_t GetDbTables(const string db)
{
DbTables_t tables;
if (dbMap.empty()) {
dbMap = GetDbHandles();
}
auto it = dbMap.find(db);
if (it == dbMap.end()) {
return; // DB not found
return tables; // DB not found
}
const char *sql = "select name, sql from sqlite_master where type=\"table\";";
Sqlite3_exec p_Sqlite3_exec = (Sqlite3_exec)(g_WeChatWinDllAddr + g_WxCalls.sql.exec);
p_Sqlite3_exec(it->second, sql, (sqlite3_callback)cbGetTables, tables, 0);
p_Sqlite3_exec(it->second, sql, (sqlite3_callback)cbGetTables, (void *)&tables, 0);
return tables;
}
void ExecDbQuery(const string db, const string sql, wcf::DbRows *rows)
DbRows_t ExecDbQuery(const string db, const string sql)
{
DbRows_t rows;
Sqlite3_prepare func_prepare = (Sqlite3_prepare)(g_WeChatWinDllAddr + 0x14227F0);
Sqlite3_step func_step = (Sqlite3_step)(g_WeChatWinDllAddr + 0x13EA780);
Sqlite3_column_count func_column_count = (Sqlite3_column_count)(g_WeChatWinDllAddr + 0x13EACD0);
@ -156,22 +159,26 @@ void ExecDbQuery(const string db, const string sql, wcf::DbRows *rows)
DWORD *stmt;
int rc = func_prepare(dbMap[db], sql.c_str(), -1, &stmt, 0);
if (rc != SQLITE_OK) {
return;
return rows;
}
while (func_step(stmt) == SQLITE_ROW) {
wcf::DbRow *row = rows->add_rows();
int col_count = func_column_count(stmt);
DbRow_t row;
int col_count = func_column_count(stmt);
for (int i = 0; i < col_count; i++) {
wcf::DbField *field = row->add_fields();
field->set_type(func_column_type(stmt, i));
field->set_column(func_column_name(stmt, i));
DbField_t field;
field.type = func_column_type(stmt, i);
field.column = func_column_name(stmt, i);
int length = func_column_bytes(stmt, i);
const void *blob = func_column_blob(stmt, i);
if (length && (field->type() != 5)) {
field->set_content(string((char *)blob, length));
if (length && (field.type != 5)) {
field.content.reserve(length);
copy((uint8_t *)blob, (uint8_t *)blob + length, back_inserter(field.content));
}
row.push_back(field);
}
rows.push_back(row);
}
return rows;
}
#endif

View File

@ -1,11 +1,7 @@
#pragma once
#if 0
#include <string>
#include <vector>
#include "../proto/wcf.grpc.pb.h"
#include "pb_types.h"
void GetDbNames(wcf::DbNames *names);
void GetDbTables(const std::string db, wcf::DbTables *tables);
void ExecDbQuery(const std::string db, const std::string sql, wcf::DbRows *rows);
#endif
DbNames_t GetDbNames();
DbTables_t GetDbTables(const string db);
DbRows_t ExecDbQuery(const string db, const string sql);

View File

@ -25,7 +25,7 @@
#include "spy_types.h"
#include "util.h"
#define G_BUF_SIZE (1024 * 1024)
#define G_BUF_SIZE (16 * 1024 * 1024)
extern int IsLogin(void); // Defined in spy.cpp
extern std::string GetSelfWxid(); // Defined in spy.cpp
@ -51,7 +51,7 @@ bool func_is_login(uint8_t *out, size_t *len)
pb_ostream_t stream = pb_ostream_from_buffer(out, *len);
if (!pb_encode(&stream, Response_fields, &rsp)) {
printf("Encoding failed: %s\n", PB_GET_ERROR(&stream));
LOG_ERROR("Encoding failed: {}", PB_GET_ERROR(&stream));
return false;
}
*len = stream.bytes_written;
@ -68,7 +68,7 @@ bool func_get_self_wxid(uint8_t *out, size_t *len)
pb_ostream_t stream = pb_ostream_from_buffer(out, *len);
if (!pb_encode(&stream, Response_fields, &rsp)) {
printf("Encoding failed: %s\n", PB_GET_ERROR(&stream));
LOG_ERROR("Encoding failed: {}", PB_GET_ERROR(&stream));
return false;
}
*len = stream.bytes_written;
@ -88,7 +88,7 @@ bool func_get_msg_types(uint8_t *out, size_t *len)
pb_ostream_t stream = pb_ostream_from_buffer(out, *len);
if (!pb_encode(&stream, Response_fields, &rsp)) {
printf("Encoding failed: %s\n", PB_GET_ERROR(&stream));
LOG_ERROR("Encoding failed: {}", PB_GET_ERROR(&stream));
return false;
}
*len = stream.bytes_written;
@ -102,13 +102,73 @@ bool func_get_contacts(uint8_t *out, size_t *len)
rsp.func = Functions_FUNC_GET_CONTACTS;
rsp.which_msg = Response_contacts_tag;
vector<RpcContact_t> contacts = GetContacts();
rsp.msg.types.types.funcs.encode = encode_contacts;
rsp.msg.types.types.arg = &contacts;
vector<RpcContact_t> contacts = GetContacts();
rsp.msg.contacts.contacts.funcs.encode = encode_contacts;
rsp.msg.contacts.contacts.arg = &contacts;
pb_ostream_t stream = pb_ostream_from_buffer(out, *len);
if (!pb_encode(&stream, Response_fields, &rsp)) {
printf("Encoding failed: %s\n", PB_GET_ERROR(&stream));
LOG_ERROR("Encoding failed: {}", PB_GET_ERROR(&stream));
return false;
}
*len = stream.bytes_written;
return true;
}
bool func_get_db_names(uint8_t *out, size_t *len)
{
Response rsp = Response_init_default;
rsp.func = Functions_FUNC_GET_DB_NAMES;
rsp.which_msg = Response_dbs_tag;
DbNames_t dbnames = GetDbNames();
rsp.msg.dbs.names.funcs.encode = encode_dbnames;
rsp.msg.dbs.names.arg = &dbnames;
pb_ostream_t stream = pb_ostream_from_buffer(out, *len);
if (!pb_encode(&stream, Response_fields, &rsp)) {
LOG_ERROR("Encoding failed: {}", PB_GET_ERROR(&stream));
return false;
}
*len = stream.bytes_written;
return true;
}
bool func_get_db_tables(char *db, uint8_t *out, size_t *len)
{
Response rsp = Response_init_default;
rsp.func = Functions_FUNC_GET_DB_TABLES;
rsp.which_msg = Response_tables_tag;
DbTables_t tables = GetDbTables(db);
rsp.msg.tables.tables.funcs.encode = encode_tables;
rsp.msg.tables.tables.arg = &tables;
pb_ostream_t stream = pb_ostream_from_buffer(out, *len);
if (!pb_encode(&stream, Response_fields, &rsp)) {
LOG_ERROR("Encoding failed: {}", PB_GET_ERROR(&stream));
return false;
}
*len = stream.bytes_written;
return true;
}
bool func_exec_db_query(char *db, char *sql, uint8_t *out, size_t *len)
{
Response rsp = Response_init_default;
rsp.func = Functions_FUNC_GET_DB_TABLES;
rsp.which_msg = Response_rows_tag;
DbRows_t rows = ExecDbQuery(db, sql);
rsp.msg.rows.rows.arg = &rows;
rsp.msg.rows.rows.funcs.encode = encode_rows;
pb_ostream_t stream = pb_ostream_from_buffer(out, *len);
if (!pb_encode(&stream, Response_fields, &rsp)) {
LOG_ERROR("Encoding failed: {}", PB_GET_ERROR(&stream));
return false;
}
*len = stream.bytes_written;
@ -149,6 +209,21 @@ static bool dispatcher(uint8_t *in, size_t in_len, uint8_t *out, size_t *out_len
ret = func_get_contacts(out, out_len);
break;
}
case Functions_FUNC_GET_DB_NAMES: {
LOG_INFO("[Functions_FUNC_GET_DB_NAMES]");
ret = func_get_db_names(out, out_len);
break;
}
case Functions_FUNC_GET_DB_TABLES: {
LOG_INFO("[Functions_FUNC_GET_DB_TABLES]");
ret = func_get_db_tables(req.msg.str, out, out_len);
break;
}
case Functions_FUNC_EXEC_DB_QUERY: {
LOG_INFO("[Functions_FUNC_EXEC_DB_QUERY]");
ret = func_exec_db_query(req.msg.query.db, req.msg.query.sql, out, out_len);
break;
}
default: {
LOG_ERROR("[UNKNOW FUNCTION]");
break;
@ -189,7 +264,8 @@ static int RunServer()
log_buffer(in, in_len);
if (dispatcher(in, in_len, gBuffer, &out_len)) {
log_buffer(gBuffer, out_len);
LOG_INFO("Send data length {}", out_len);
// log_buffer(gBuffer, out_len);
rv = nng_send(sock, gBuffer, out_len, 0);
if (rv != 0) {
LOG_ERROR("nng_send: {}", rv);
@ -199,7 +275,7 @@ static int RunServer()
// Error
LOG_ERROR("Dispatcher failed...");
rv = nng_send(sock, gBuffer, 0, 0);
break;
// break;
}
nng_free(in, in_len);
}