diff --git a/pywxdump/api/local_server.py b/pywxdump/api/local_server.py index 47a0b6f..279ac13 100644 --- a/pywxdump/api/local_server.py +++ b/pywxdump/api/local_server.py @@ -21,9 +21,6 @@ from pywxdump.api.rjson import ReJson, RqJson from pywxdump.api.utils import get_conf, get_conf_wxids, set_conf, error9999, gen_base64, validate_title, \ get_conf_local_wxid, ls_loger, random_str from pywxdump import get_wx_info, WX_OFFS, batch_decrypt, BiasAddr, merge_db, decrypt_merge, merge_real_time_db - -from pywxdump.db import DBHandler, download_file, export_csv, export_json - ls_api = Blueprint('ls_api', __name__, template_folder='../ui/web', static_folder='../ui/web/assets/', ) ls_api.debug = False diff --git a/pywxdump/api/remote_server.py b/pywxdump/api/remote_server.py index 225c150..f1418e6 100644 --- a/pywxdump/api/remote_server.py +++ b/pywxdump/api/remote_server.py @@ -24,7 +24,8 @@ from pywxdump.api.utils import get_conf, get_conf_wxids, set_conf, error9999, ge get_conf_local_wxid from pywxdump import get_wx_info, WX_OFFS, batch_decrypt, BiasAddr, merge_db, decrypt_merge, merge_real_time_db -from pywxdump.db import DBHandler, download_file, export_csv, export_json, dat2img +from pywxdump.db import DBHandler, download_file, dat2img +from pywxdump.db.export import export_csv, export_json # app = Flask(__name__, static_folder='../ui/web/dist', static_url_path='/') @@ -70,7 +71,7 @@ def user_session_list(): my_wxid = get_conf(g.caf, g.at, "last") if not my_wxid: return ReJson(1001, body="my_wxid is required") db_config = get_conf(g.caf, my_wxid, "db_config") - db = DBHandler(db_config) + db = DBHandler(db_config, my_wxid=my_wxid) ret = db.get_session_list() return ReJson(0, list(ret.values())) @@ -85,7 +86,7 @@ def user_labels_dict(): my_wxid = get_conf(g.caf, g.at, "last") if not my_wxid: return ReJson(1001, body="my_wxid is required") db_config = get_conf(g.caf, my_wxid, "db_config") - db = DBHandler(db_config) + db = DBHandler(db_config, my_wxid=my_wxid) user_labels_dict = db.get_labels() return ReJson(0, user_labels_dict) @@ -114,7 +115,7 @@ def user_list(): my_wxid = get_conf(g.caf, g.at, "last") if not my_wxid: return ReJson(1001, body="my_wxid is required") db_config = get_conf(g.caf, my_wxid, "db_config") - db = DBHandler(db_config) + db = DBHandler(db_config, my_wxid=my_wxid) users = db.get_user(word, wxids, labels) return ReJson(0, users) @@ -203,12 +204,8 @@ def msg_count(): my_wxid = get_conf(g.caf, g.at, "last") if not my_wxid: return ReJson(1001, body="my_wxid is required") db_config = get_conf(g.caf, my_wxid, "db_config") - db = DBHandler(db_config) - chat_count = db.get_msg_count(wxid) - chat_count1 = db.get_plc_msg_count(wxid) - # 合并两个字典,相同key,则将value相加 - count = {k: chat_count.get(k, 0) + chat_count1.get(k, 0) for k in - list(set(list(chat_count.keys()) + list(chat_count1.keys())))} + db = DBHandler(db_config, my_wxid=my_wxid) + count = db.get_msgs_count(wxid) return ReJson(0, count) @@ -234,11 +231,9 @@ def get_msgs(): if not isinstance(start, int) and not isinstance(limit, int): return ReJson(1002, body=f"start or limit is not int {start} {limit}") - db = DBHandler(db_config) - msgs, wxid_list = db.get_msgs(wxid=wxid, start_index=start, page_size=limit) - wxid_list.append(my_wxid) - user = db.get_user(wxids=wxid_list) - return ReJson(0, {"msg_list": msgs, "user_list": user}) + db = DBHandler(db_config, my_wxid=my_wxid) + msgs, users = db.get_msgs(wxid=wxid, start_index=start, page_size=limit) + return ReJson(0, {"msg_list": msgs, "user_list": users}) @rs_api.route('/api/rs/video/', methods=["GET", 'POST']) @@ -281,7 +276,7 @@ def get_audio(savePath): if not os.path.exists(os.path.dirname(savePath)): os.makedirs(os.path.dirname(savePath)) - db = DBHandler(db_config) + db = DBHandler(db_config, my_wxid=my_wxid) wave_data = db.get_audio(MsgSvrID, is_play=False, is_wave=True, save_path=savePath, rate=24000) if not wave_data: return ReJson(1001, body="wave_data is required") @@ -422,7 +417,7 @@ def get_export_csv(): if not os.path.exists(outpath): os.makedirs(outpath) - code, ret = export_csv(wxid, outpath, db_config) + code, ret = export_csv(wxid, outpath, db_config, my_wxid=my_wxid) if code: return ReJson(0, ret) else: @@ -447,7 +442,7 @@ def get_export_json(): if not os.path.exists(outpath): os.makedirs(outpath) - code, ret = export_json(wxid, outpath, db_config) + code, ret = export_json(wxid, outpath, db_config, my_wxid=my_wxid) if code: return ReJson(0, ret) else: @@ -475,7 +470,7 @@ def get_date_count(): my_wxid = get_conf(g.caf, g.at, "last") if not my_wxid: return ReJson(1001, body="my_wxid is required") db_config = get_conf(g.caf, my_wxid, "db_config") - db = DBHandler(db_config) + db = DBHandler(db_config, my_wxid=my_wxid) date_count = db.get_date_count(wxid=word, start_time=start_time, end_time=end_time, time_format=time_format) return ReJson(0, date_count) @@ -496,7 +491,8 @@ def get_top_talker_count(): my_wxid = get_conf(g.caf, g.at, "last") if not my_wxid: return ReJson(1001, body="my_wxid is required") db_config = get_conf(g.caf, my_wxid, "db_config") - date_count = DBHandler(db_config).get_top_talker_count(top=top, start_time=start_time, end_time=end_time) + date_count = DBHandler(db_config, my_wxid=my_wxid).get_top_talker_count(top=top, start_time=start_time, + end_time=end_time) return ReJson(0, date_count) @@ -519,7 +515,7 @@ def wordcloud(): my_wxid = get_conf(g.caf, g.at, "last") if not my_wxid: return ReJson(1001, body="my_wxid is required") db_config = get_conf(g.caf, my_wxid, "db_config") - db = DBHandler(db_config) + db = DBHandler(db_config, my_wxid=my_wxid) if target == "signature": users = db.get_user() diff --git a/pywxdump/db/__init__.py b/pywxdump/db/__init__.py index dcbd2d1..d1ccfab 100644 --- a/pywxdump/db/__init__.py +++ b/pywxdump/db/__init__.py @@ -18,16 +18,16 @@ from .dbPublicMsg import PublicMsgHandler from .dbOpenIMMedia import OpenIMMediaHandler from .dbSns import SnsHandler -from .export.exportCSV import export_csv -from .export.exportJSON import export_json - class DBHandler(MicroHandler, MediaHandler, OpenIMContactHandler, PublicMsgHandler, OpenIMMediaHandler, FavoriteHandler, SnsHandler): _class_name = "DBHandler" - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, db_config, my_wxid, *args, **kwargs): + self.config = db_config + self.my_wxid = my_wxid + + super().__init__(self.config) # 加速查询索引 self.Micro_add_index() self.Msg_add_index() @@ -64,11 +64,21 @@ class DBHandler(MicroHandler, MediaHandler, OpenIMContactHandler, PublicMsgHandl msgs0, wxid_list0 = self.get_msg_list(wxid=wxid, start_index=start_index, page_size=page_size, msg_type=msg_type, msg_sub_type=msg_sub_type, start_createtime=start_createtime, - end_createtime=end_createtime) + end_createtime=end_createtime, my_talker=self.my_wxid) msgs1, wxid_list1 = self.get_plc_msg_list(wxid=wxid, start_index=start_index, page_size=page_size, msg_type=msg_type, msg_sub_type=msg_sub_type, start_createtime=start_createtime, - end_createtime=end_createtime) + end_createtime=end_createtime, my_talker=self.my_wxid) msgs = msgs0 + msgs1 wxid_list = wxid_list0 + wxid_list1 - return msgs, wxid_list + + users = self.get_user(wxids=wxid_list) + return msgs, users + + def get_msgs_count(self, wxids: list = ""): + chat_count = self.get_m_msg_count(wxids) + chat_count1 = self.get_plc_msg_count(wxids) + # 合并两个字典,相同key,则将value相加 + count = {k: chat_count.get(k, 0) + chat_count1.get(k, 0) for k in + list(set(list(chat_count.keys()) + list(chat_count1.keys())))} + return count diff --git a/pywxdump/db/dbMSG.py b/pywxdump/db/dbMSG.py index 2c31933..c3e9e22 100644 --- a/pywxdump/db/dbMSG.py +++ b/pywxdump/db/dbMSG.py @@ -35,13 +35,13 @@ class MsgHandler(DatabaseBase): self.execute("CREATE INDEX IF NOT EXISTS idx_MSG_StrTalker_CreateTime ON MSG(StrTalker, CreateTime);") @db_error - def get_msg_count(self, wxids: list = ""): + def get_m_msg_count(self, wxids: list = ""): """ 获取聊天记录数量,根据wxid获取单个联系人的聊天记录数量,不传wxid则获取所有联系人的聊天记录数量 :param wxids: wxid list :return: 聊天记录数量列表 {wxid: chat_count, total: total_count} """ - if isinstance(wxids, str): + if isinstance(wxids, str) and wxids: wxids = [wxids] if wxids: wxids = "('" + "','".join(wxids) + "')" @@ -67,7 +67,7 @@ class MsgHandler(DatabaseBase): @db_error def get_msg_list(self, wxid="", start_index=0, page_size=500, msg_type: str = "", msg_sub_type: str = "", - start_createtime=None, end_createtime=None): + start_createtime=None, end_createtime=None, my_talker="我"): """ 获取聊天记录列表 :param wxid: wxid @@ -77,6 +77,7 @@ class MsgHandler(DatabaseBase): :param msg_sub_type: 消息子类型 :param start_createtime: 开始时间 :param end_createtime: 结束时间 + :param my_talker: 我 :return: 聊天记录列表 {"id": _id, "MsgSvrID": str(MsgSvrID), "type_name": type_name, "is_sender": IsSender, "talker": talker, "room_name": StrTalker, "msg": msg, "src": src, "extra": {}, "CreateTime": CreateTime, } @@ -110,10 +111,9 @@ class MsgHandler(DatabaseBase): if not result: return [], [] - result_data = (self.get_msg_detail(row) for row in result) + result_data = (self.get_msg_detail(row, my_talker=my_talker) for row in result) rdata = list(result_data) # 转为列表 wxid_list = {d['talker'] for d in rdata} # 创建一个无重复的 wxid 列表 - return rdata, list(wxid_list) @db_error @@ -201,7 +201,7 @@ class MsgHandler(DatabaseBase): # 单条消息处理 @db_error - def get_msg_detail(self, row): + def get_msg_detail(self, row, my_talker="我"): """ 获取单条消息详情,格式化输出 """ @@ -365,7 +365,7 @@ class MsgHandler(DatabaseBase): talker = "未知" if IsSender == 1: - talker = "我" + talker = my_talker else: if StrTalker.endswith("@chatroom"): bytes_extra = get_BytesExtra(BytesExtra) diff --git a/pywxdump/db/dbPublicMsg.py b/pywxdump/db/dbPublicMsg.py index ca256d8..4b75cb6 100644 --- a/pywxdump/db/dbPublicMsg.py +++ b/pywxdump/db/dbPublicMsg.py @@ -36,7 +36,7 @@ class PublicMsgHandler(MsgHandler): """ if not self.tables_exist("PublicMsg"): return {} - if isinstance(wxids, str): + if isinstance(wxids, str) and wxids: wxids = [wxids] if wxids: wxids = "('" + "','".join(wxids) + "')" @@ -60,7 +60,7 @@ class PublicMsgHandler(MsgHandler): @db_error def get_plc_msg_list(self, wxid="", start_index=0, page_size=500, msg_type: str = "", msg_sub_type: str = "", - start_createtime=None, end_createtime=None): + start_createtime=None, end_createtime=None, my_talker="我"): """ 获取聊天记录列表 :param wxid: wxid @@ -103,7 +103,7 @@ class PublicMsgHandler(MsgHandler): if not result: return [], [] - result_data = (self.get_msg_detail(row) for row in result) + result_data = (self.get_msg_detail(row, my_talker=my_talker) for row in result) rdata = list(result_data) # 转为列表 wxid_list = {d['talker'] for d in rdata} # 创建一个无重复的 wxid 列表 diff --git a/pywxdump/db/export/__init__.py b/pywxdump/db/export/__init__.py index 5fcb078..b5c7ad4 100644 --- a/pywxdump/db/export/__init__.py +++ b/pywxdump/db/export/__init__.py @@ -5,7 +5,5 @@ # Author: xaoyaoo # Date: 2024/04/20 # ------------------------------------------------------------------------------- - - -if __name__ == '__main__': - pass +from .exportCSV import export_csv +from .exportJSON import export_json \ No newline at end of file diff --git a/pywxdump/db/export/exportCSV.py b/pywxdump/db/export/exportCSV.py index 2ef4258..1bacc0c 100644 --- a/pywxdump/db/export/exportCSV.py +++ b/pywxdump/db/export/exportCSV.py @@ -8,18 +8,19 @@ import csv import json import os -from ..dbMSG import MsgHandler + +from pywxdump import DBHandler -def export_csv(wxid, outpath, db_config, page_size=5000): +def export_csv(wxid, outpath, db_config, my_wxid="我", page_size=5000): if not os.path.exists(outpath): outpath = os.path.join(os.getcwd(), "export" + os.sep + wxid) if not os.path.exists(outpath): os.makedirs(outpath) - pmsg = MsgHandler(db_config) + db = DBHandler(db_config, my_wxid) - count = pmsg.get_msg_count(wxid) + count = db.get_msgs_count(wxid) chatCount = count.get(wxid, 0) if chatCount == 0: return False, "没有聊天记录" @@ -27,9 +28,11 @@ def export_csv(wxid, outpath, db_config, page_size=5000): if page_size > chatCount: page_size = chatCount + 1 + users = {} for i in range(0, chatCount, page_size): start_index = i - data, wxid_list = pmsg.get_msg_list(wxid, start_index, page_size) + data, users_t = db.get_msg_list(wxid, start_index, page_size) + users.update(users_t) if len(data) == 0: return False, "没有聊天记录" @@ -52,7 +55,8 @@ def export_csv(wxid, outpath, db_config, page_size=5000): src = row.get("src", "") CreateTime = row.get("CreateTime", "") csv_writer.writerow([id, MsgSvrID, type_name, is_sender, talker, room_name, msg, src, CreateTime]) - + with open(os.path.join(outpath, "users.json"), "w", encoding="utf-8") as f: + json.dump(users, f, ensure_ascii=False, indent=4) return True, f"导出成功: {outpath}" diff --git a/pywxdump/db/export/exportJSON.py b/pywxdump/db/export/exportJSON.py index 68942bc..29e9023 100644 --- a/pywxdump/db/export/exportJSON.py +++ b/pywxdump/db/export/exportJSON.py @@ -7,31 +7,35 @@ # ------------------------------------------------------------------------------- import json import os -from ..dbMSG import MsgHandler +from ..__init__ import DBHandler -def export_json(wxid, outpath, db_config): +def export_json(wxid, outpath, db_config, my_wxid="我", indent=4): if not os.path.exists(outpath): outpath = os.path.join(os.getcwd(), "export" + os.sep + wxid) if not os.path.exists(outpath): os.makedirs(outpath) - pmsg = MsgHandler(db_config) + db = DBHandler(db_config, my_wxid) - count = pmsg.get_msg_count(wxid) + count = db.get_msgs_count(wxid) chatCount = count.get(wxid, 0) if chatCount == 0: return False, "没有聊天记录" - + users = {} page_size = chatCount + 1 for i in range(0, chatCount, page_size): start_index = i - data, wxid_list = pmsg.get_msg_list(wxid, start_index, page_size) + data, users_t = db.get_msgs(wxid, start_index, page_size) + users.update(users_t) if len(data) == 0: return False, "没有聊天记录" + save_path = os.path.join(outpath, f"{wxid}_{i}_{i + page_size}.json") with open(save_path, "w", encoding="utf-8") as f: - json.dump(data, f, ensure_ascii=False, indent=4) + json.dump(data, f, ensure_ascii=False, indent=indent) + with open(os.path.join(outpath, "users.json"), "w", encoding="utf-8") as f: + json.dump(users, f, ensure_ascii=False, indent=indent) return True, f"导出成功: {outpath}"