From 4de1c26f250289ffa4f50aa59d6bbe92a6960268 Mon Sep 17 00:00:00 2001 From: cllcode <2440893398@qq.com> Date: Sat, 13 Jul 2024 21:56:42 +0800 Subject: [PATCH] =?UTF-8?q?1.=E5=A2=9E=E5=8A=A0=E8=8E=B7=E5=8F=96=E6=96=87?= =?UTF-8?q?=E4=BB=B6=E5=A4=A7=E5=B0=8F=E3=80=81=E6=96=87=E4=BB=B6=E5=88=A0?= =?UTF-8?q?=E9=99=A4=E6=96=B9=E6=B3=95=EF=BC=8C=E5=B9=B6=E8=A1=A5=E5=85=85?= =?UTF-8?q?=E6=B3=A8=E9=87=8A=202.=E5=A2=9E=E5=8A=A0=E5=AF=B9=E8=B1=A1?= =?UTF-8?q?=E5=AD=98=E5=82=A8=E7=9A=84=E9=85=8D=E7=BD=AE=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pywxdump/api/api.py | 87 ++++-- pywxdump/cli.py | 77 +++++- pywxdump/common/__init__.py | 4 + .../common/config/oss_config/s3_config.py | 58 ++++ .../config/oss_config/storage_config.py | 60 +++++ .../oss_config/storage_config_factory.py | 25 ++ pywxdump/common/config/oss_config_manager.py | 39 +++ pywxdump/common/config/server_config.py | 97 +++++++ pywxdump/dbpreprocess/dbbase.py | 31 ++- pywxdump/file/Attachment.py | 123 +++++++++ pywxdump/file/AttachmentAbstract.py | 26 -- pywxdump/file/AttachmentContext.py | 130 ++++++++- pywxdump/file/ConfigurableAttachment.py | 12 + pywxdump/file/LocalAttachment.py | 161 +++++++++++- pywxdump/file/S3Attachment.py | 248 ++++++++++++++---- pywxdump/file/__init__.py | 4 + pywxdump/server.py | 47 ++-- pywxdump/wx_info/get_wx_info.py | 2 +- pywxdump/wx_info/merge_db.py | 4 +- requirements.txt | 3 +- tests/file/TestLocalAttachment.py | 63 +++++ tests/file/TestS3Attachment.py | 71 +++++ 22 files changed, 1222 insertions(+), 150 deletions(-) create mode 100644 pywxdump/common/__init__.py create mode 100644 pywxdump/common/config/oss_config/s3_config.py create mode 100644 pywxdump/common/config/oss_config/storage_config.py create mode 100644 pywxdump/common/config/oss_config/storage_config_factory.py create mode 100644 pywxdump/common/config/oss_config_manager.py create mode 100644 pywxdump/common/config/server_config.py create mode 100644 pywxdump/file/Attachment.py delete mode 100644 pywxdump/file/AttachmentAbstract.py create mode 100644 pywxdump/file/ConfigurableAttachment.py create mode 100644 pywxdump/file/__init__.py create mode 100644 tests/file/TestLocalAttachment.py create mode 100644 tests/file/TestS3Attachment.py diff --git a/pywxdump/api/api.py b/pywxdump/api/api.py index 70f4ccd..6f0c687 100644 --- a/pywxdump/api/api.py +++ b/pywxdump/api/api.py @@ -5,14 +5,14 @@ # Author: xaoyaoo # Date: 2024/01/02 # ------------------------------------------------------------------------------- -import base64 -import json import logging -import os -import re import time import shutil import sys + +from pywxdump.common.config.oss_config.storage_config_factory import StorageConfigFactory +from pywxdump.common.config.oss_config_manager import OSSConfigManager + pythoncom = __import__('pythoncom') if sys.platform == "win32" else None import pywxdump from pywxdump.file import AttachmentContext @@ -55,8 +55,8 @@ def init_last(): 是否初始化 :return: """ - my_wxid = request.json.get("my_wxid", "") - my_wxid = my_wxid.strip().strip("'").strip('"') if isinstance(my_wxid, str) else "" + my_wxid_dict = request.json.get("my_wxid", {}) + my_wxid = my_wxid_dict.get("wxid", "") if not my_wxid: my_wxid = read_session(g.sf, "test", "last") if my_wxid: @@ -64,6 +64,9 @@ def init_last(): merge_path = read_session(g.sf, my_wxid, "merge_path") wx_path = read_session(g.sf, my_wxid, "wx_path") key = read_session(g.sf, my_wxid, "key") + # 如果有oss_config则设置对象存储配置 + oss_config = read_session(g.sf, my_wxid, "oss_config") + ossConfig(oss_config) rdata = { "merge_path": merge_path, "wx_path": wx_path, @@ -96,11 +99,15 @@ def init_key(): return ReJson(1002, body=f"my_wxid is required: {my_wxid}") old_merge_save_path = read_session(g.sf, my_wxid, "merge_path") - if isinstance(old_merge_save_path, str) and old_merge_save_path and AttachmentContext.exists(old_merge_save_path): - pmsg = ParsingMSG(old_merge_save_path) - pmsg.close_all_connection() + # 如果有oss_config则设置对象存储配置 + oss_config = read_session(g.sf, my_wxid, "oss_config") + ossConfig(oss_config) + # 如果存在旧地连接则关闭连接 + if isinstance(old_merge_save_path, str) and old_merge_save_path: + ParsingMSG.terminate_connection(old_merge_save_path) - out_path = AttachmentContext.join(g.tmp_path, "decrypted", my_wxid) if my_wxid else AttachmentContext.join(g.tmp_path, "decrypted") + out_path = AttachmentContext.join(g.tmp_path, "decrypted", my_wxid) if my_wxid else AttachmentContext.join( + g.tmp_path, "decrypted") # 检查文件夹中文件是否被占用 if AttachmentContext.exists(out_path): try: @@ -115,7 +122,7 @@ def init_key(): if code: # 移动merge_save_path到g.tmp_path/my_wxid if not AttachmentContext.exists(AttachmentContext.join(g.tmp_path, my_wxid)): - os.makedirs(AttachmentContext.join(g.tmp_path, my_wxid)) + AttachmentContext.makedirs(AttachmentContext.join(g.tmp_path, my_wxid)) merge_save_path_new = AttachmentContext.join(g.tmp_path, my_wxid, "merge_all.db") shutil.move(merge_save_path, str(merge_save_path_new)) @@ -144,6 +151,19 @@ def init_key(): return ReJson(2001, body=merge_save_path) +def ossConfig(oss_config: str): + """ + 设置对象存储配置 + + :param oss_config: 对象存储配置 + + :return: None + """ + if oss_config: + storageConfig = StorageConfigFactory.create(oss_config) + OSSConfigManager().load_config(storageConfig) + + @api.route('/api/init_nokey', methods=["GET", 'POST']) @error9999 def init_nokey(): @@ -154,6 +174,9 @@ def init_nokey(): merge_path = request.json.get("merge_path", "").strip().strip("'").strip('"') wx_path = request.json.get("wx_path", "").strip().strip("'").strip('"') my_wxid = request.json.get("my_wxid", "").strip().strip("'").strip('"') + # 如果有oss_config则设置对象存储配置 + oss_config = request.json.get("oss_config", "").strip().strip("'").strip('"') + ossConfig(oss_config) if not wx_path: return ReJson(1002, body=f"wx_path is required: {wx_path}") @@ -170,6 +193,7 @@ def init_nokey(): save_session(g.sf, my_wxid, "wx_path", wx_path) save_session(g.sf, my_wxid, "key", key) save_session(g.sf, my_wxid, "my_wxid", my_wxid) + save_session(g.sf, my_wxid, "oss_config", oss_config) save_session(g.sf, "test", "last", my_wxid) rdata = { "merge_path": merge_path, @@ -181,6 +205,25 @@ def init_nokey(): return ReJson(0, rdata) +# 查询支持的对象存储,及配置 +@api.route('/api/check_storage_type', methods=["GET", 'POST']) +@error9999 +def check_storage_type(): + path = request.json.get("path", "").strip().strip("'").strip('"') + if not path: + return ReJson(1002, body=f"path is required: {path}") + + selectStorageConfig = None + for key, storageConfig in StorageConfigFactory.registry.items(): + if storageConfig.isSupported(path): + selectStorageConfig = storageConfig + if selectStorageConfig: + return ReJson(0, {"is_supported": True, "storage_type": selectStorageConfig.type(), + "config_items": selectStorageConfig.describe()}) + else: + return ReJson(0, {"is_supported": False}) + + # END 以上为初始化相关 *************************************************************************************************** @@ -338,7 +381,7 @@ def get_imgsrc(imgsrc): img_tmp_path = AttachmentContext.join(g.tmp_path, my_wxid, "imgsrc") if not AttachmentContext.exists(img_tmp_path): - os.makedirs(img_tmp_path) + AttachmentContext.makedirs(img_tmp_path) file_name = imgsrc.replace("http://", "").replace("https://", "").replace("/", "_").replace("?", "_") file_name = file_name + ".jpg" # 如果文件名过长,则将文件明分为目录和文件名 @@ -434,10 +477,10 @@ def get_video(videoPath): # 复制文件到临时文件夹 video_save_path = AttachmentContext.join(video_tmp_path, videoPath) if not AttachmentContext.exists(AttachmentContext.dirname(video_save_path)): - os.makedirs(AttachmentContext.dirname(video_save_path)) + AttachmentContext.makedirs(AttachmentContext.dirname(video_save_path)) if AttachmentContext.exists(video_save_path): return AttachmentContext.send_attachment(video_save_path) - AttachmentContext.download_file(original_img_path,video_save_path) + AttachmentContext.download_file(original_img_path, video_save_path) return AttachmentContext.send_attachment(original_img_path) @@ -457,7 +500,7 @@ def get_audio(savePath): # 判断savePath路径的文件夹是否存在 if not AttachmentContext.exists(AttachmentContext.dirname(savePath)): - os.makedirs(AttachmentContext.dirname(savePath)) + AttachmentContext.makedirs(AttachmentContext.dirname(savePath)) parsing_media_msg = ParsingMediaMSG(merge_path) wave_data = parsing_media_msg.get_audio(MsgSvrID, is_play=False, is_wave=True, save_path=savePath, rate=24000) @@ -484,8 +527,8 @@ def get_file_info(): all_file_path = AttachmentContext.join(wx_path, file_path) if not AttachmentContext.exists(all_file_path): return ReJson(5002) - file_name = os.path.basename(all_file_path) - file_size = os.path.getsize(all_file_path) + file_name = AttachmentContext.basename(all_file_path) + file_size = AttachmentContext.getsize(all_file_path) return ReJson(0, {"file_name": file_name, "file_size": str(file_size)}) @@ -528,12 +571,12 @@ def get_export_endb(): outpath = AttachmentContext.join(g.tmp_path, "export", my_wxid, "endb") if not AttachmentContext.exists(outpath): - os.makedirs(outpath) + AttachmentContext.makedirs(outpath) for wxdb in wxdbpaths: # 复制wxdb->outpath, os.path.basename(wxdb) assert isinstance(outpath, str) # 为了解决pycharm的警告, 无实际意义 - shutil.copy(wxdb, AttachmentContext.join(outpath, os.path.basename(wxdb))) + shutil.copy(wxdb, AttachmentContext.join(outpath, AttachmentContext.basename(wxdb))) return ReJson(0, body=outpath) @@ -558,7 +601,7 @@ def get_export_dedb(): outpath = AttachmentContext.join(g.tmp_path, "export", my_wxid, "dedb") if not AttachmentContext.exists(outpath): - os.makedirs(outpath) + AttachmentContext.makedirs(outpath) code, merge_save_path = decrypt_merge(wx_path=wx_path, key=key, outpath=outpath) time.sleep(1) @@ -589,7 +632,7 @@ def get_export_csv(): outpath = AttachmentContext.join(g.tmp_path, "export", my_wxid, "csv", wxid) if not AttachmentContext.exists(outpath): - os.makedirs(outpath) + AttachmentContext.makedirs(outpath) code, ret = export_csv(wxid, outpath, read_session(g.sf, my_wxid, "merge_path")) if code: @@ -613,7 +656,7 @@ def get_export_json(): outpath = AttachmentContext.join(g.tmp_path, "export", my_wxid, "json", wxid) if not AttachmentContext.exists(outpath): - os.makedirs(outpath) + AttachmentContext.makedirs(outpath) code, ret = export_json(wxid, outpath, read_session(g.sf, my_wxid, "merge_path")) if code: diff --git a/pywxdump/cli.py b/pywxdump/cli.py index 7f44d94..8e64811 100644 --- a/pywxdump/cli.py +++ b/pywxdump/cli.py @@ -6,11 +6,18 @@ # Date: 2023/10/14 # ------------------------------------------------------------------------------- import argparse +import json +import os import sys -import time from pywxdump import * import pywxdump +from pywxdump.common.config.oss_config.storage_config import DescriptionBuilder +from pywxdump.common.config.oss_config.storage_config_factory import StorageConfigFactory +from pywxdump.common.config.oss_config_manager import OSSConfigManager +from pywxdump.common.config.server_config import ServerConfig +from pywxdump.common.constants import VERSION_LIST_PATH +from pywxdump.file import AttachmentContext wxdump_ascii = r""" ██████╗ ██╗ ██╗██╗ ██╗██╗ ██╗██████╗ ██╗ ██╗███╗ ███╗██████╗ @@ -264,31 +271,63 @@ class MainShowChatRecords(BaseSubMainClass): default="", metavar="") parser.add_argument("--online", action='store_true', help="(可选)是否在线查看(局域网查看)", required=False, default=False) + for key, storageConfig in StorageConfigFactory.registry.items(): + storageConfigDescribe = storageConfig.describe() + for description in storageConfigDescribe: + parser.add_argument( + f"--{key}_{description[DescriptionBuilder.KEY]}", + help=description[DescriptionBuilder.PLACEHOLDER], + default=False, required=False, type=str, metavar="" + ) # parser.add_argument("-k", "--key", type=str, help="(可选)密钥", required=False, metavar="") return parser def run(self, args): print(f"[*] PyWxDump v{pywxdump.__version__}") + + server_config = ServerConfig.builder() + server_config.merge_path(args.merge_path) + server_config.wx_path(args.wx_path) + server_config.my_wxid(args.my_wxid) + server_config.online(args.online) + server_config.port(9000) + + for key, storageConfig in StorageConfigFactory.registry.items(): + storageConfigDescribe = storageConfig.describe() + config = {} + for description in storageConfigDescribe: + value = getattr(args, f"{key}_{description[DescriptionBuilder.KEY]}") + if value: + config[description[DescriptionBuilder.KEY]] = value + if config: + oss_config = storageConfig.value_of(config) + server_config.oss_config(oss_config) + # 加载对象存储配置 + OSSConfigManager().load_config(oss_config) + start_config = server_config.build() + + # 参数检查 + if not self._dbshow_parameter_check(start_config): + return + + start_falsk(start_config) + + def _dbshow_parameter_check(self, server_config) -> bool: # (merge)和(msg_path,micro_path,media_path) 二选一 # if not args.merge_path and not (args.msg_path and args.micro_path and args.media_path): # print("[-] 请输入数据库路径([merge_path] or [msg_path, micro_path, media_path])") # return # 目前仅能支持merge database - if not args.merge_path: + if not server_config.merge_path: print("[-] 请输入数据库路径([merge_path])") - return + return False # 从命令行参数获取值 - merge_path = args.merge_path - - online = args.online - - if not os.path.exists(merge_path): + if not AttachmentContext.exists(server_config.merge_path): print("[-] 输入数据库路径不存在") - return - - start_falsk(merge_path=merge_path, wx_path=args.wx_path, key="", my_wxid=args.my_wxid, online=online) + return False + return True class MainExportChatRecords(BaseSubMainClass): @@ -338,7 +377,13 @@ class MainUi(BaseSubMainClass): debug = args.debug isopenBrowser = args.isOpenBrowser - start_falsk(port=port, online=online, debug=debug, isopenBrowser=isopenBrowser) + server_config = ServerConfig.builder() + server_config.is_open_browser(isopenBrowser) + server_config.debug(debug) + server_config.port(port) + server_config.online(online) + + start_falsk(server_config.build()) class MainApi(BaseSubMainClass): @@ -359,7 +404,13 @@ class MainApi(BaseSubMainClass): port = args.port debug = args.debug - start_falsk(port=port, online=online, debug=debug, isopenBrowser=False) + server_config = ServerConfig.builder() + server_config.debug(debug) + server_config.port(port) + server_config.is_open_browser(False) + server_config.online(online) + + start_falsk(server_config.build()) def console_run(): diff --git a/pywxdump/common/__init__.py b/pywxdump/common/__init__.py new file mode 100644 index 0000000..36213f2 --- /dev/null +++ b/pywxdump/common/__init__.py @@ -0,0 +1,4 @@ + + +if __name__ == '__main__': + pass diff --git a/pywxdump/common/config/oss_config/s3_config.py b/pywxdump/common/config/oss_config/s3_config.py new file mode 100644 index 0000000..485b97f --- /dev/null +++ b/pywxdump/common/config/oss_config/s3_config.py @@ -0,0 +1,58 @@ +from pywxdump.common.config.oss_config.storage_config import StorageConfig, TYPE_KEY, DescriptionBuilder +from pywxdump.common.config.oss_config.storage_config_factory import StorageConfigFactory + +S3 = "s3" +ENDPOINT_URL = "endpoint_url" +SECRET_KEY = "secret_key" +ACCESS_KEY = "access_key" + + +@StorageConfigFactory.register(S3) +class S3Config(StorageConfig): + + def __init__(self, access_key, secret_key, endpoint_url): + self.access_key = access_key + self.secret_key = secret_key + self.endpoint_url = endpoint_url + + @classmethod + def type(cls) -> str: + return S3 + + @classmethod + def describe(cls): + builder = DescriptionBuilder() + builder.add_description(ENDPOINT_URL, "https://cos..myqcloud.com", ENDPOINT_URL) + builder.add_description(ACCESS_KEY, "腾讯云的SecretId", ACCESS_KEY) + builder.add_description(SECRET_KEY, "腾讯云的SecretKey", SECRET_KEY) + return builder.build() + + def get_config(self): + return { + TYPE_KEY: S3, + ACCESS_KEY: self.access_key, + SECRET_KEY: self.secret_key, + ENDPOINT_URL: self.endpoint_url + } + + def validate_config(self): + if not self.access_key or not self.secret_key or not self.endpoint_url: + raise ValueError("S3 configuration is not valid") + + @classmethod + def value_of(cls, config: dict): + access_key = config.get(ACCESS_KEY) + secret_key = config.get(SECRET_KEY) + endpoint_url = config.get(ENDPOINT_URL) + s3_config = cls(access_key, secret_key, endpoint_url) + s3_config.validate_config() + return s3_config + + @classmethod + def isSupported(cls, path: str) -> bool: + if not path: + return False + if path.startswith(f"{S3}://"): + return True + + diff --git a/pywxdump/common/config/oss_config/storage_config.py b/pywxdump/common/config/oss_config/storage_config.py new file mode 100644 index 0000000..fe25d18 --- /dev/null +++ b/pywxdump/common/config/oss_config/storage_config.py @@ -0,0 +1,60 @@ +from abc import ABC, abstractmethod + + + +TYPE_KEY = "oss_type_key" + + +class StorageConfig(ABC): + + @classmethod + @abstractmethod + def type(cls): + """返回实现类的类型""" + pass + + @classmethod + @abstractmethod + def describe(cls): + """返回对象存储的描述""" + pass + + @abstractmethod + def get_config(self): + """返回存储配置的字典""" + pass + + @abstractmethod + def validate_config(self): + """验证配置是否合法""" + pass + + @classmethod + @abstractmethod + def value_of(cls, config: dict): + pass + + @classmethod + @abstractmethod + def isSupported(cls, path: str) -> bool: + pass + + +class DescriptionBuilder: + LABEL = "label" + PLACEHOLDER = "placeholder" + KEY = "key" + + def __init__(self): + self.description = [] + + def add_description(self, label, placeholder, key): + self.description.append({ + self.LABEL: label, + self.PLACEHOLDER: placeholder, + self.KEY: key + }) + return self + + def build(self): + return self.description diff --git a/pywxdump/common/config/oss_config/storage_config_factory.py b/pywxdump/common/config/oss_config/storage_config_factory.py new file mode 100644 index 0000000..cd62214 --- /dev/null +++ b/pywxdump/common/config/oss_config/storage_config_factory.py @@ -0,0 +1,25 @@ +import json +from abc import ABC + +from pywxdump.common.config.oss_config.storage_config import TYPE_KEY, StorageConfig + + +class StorageConfigFactory(ABC): + registry = {} + + @classmethod + def register(cls, type_name): + def inner_wrapper(subclass): + cls.registry[type_name] = subclass + return subclass + + return inner_wrapper + + @staticmethod + def create(json_str) -> StorageConfig: + config_dict = json.loads(json_str) + config_type = config_dict.get(TYPE_KEY) + subclass = StorageConfigFactory.registry.get(config_type) + if subclass is None: + raise ValueError(f'Unknown config type: {config_type}') + return subclass.value_of(config_dict) diff --git a/pywxdump/common/config/oss_config_manager.py b/pywxdump/common/config/oss_config_manager.py new file mode 100644 index 0000000..5a6caa8 --- /dev/null +++ b/pywxdump/common/config/oss_config_manager.py @@ -0,0 +1,39 @@ +from typing import Type + +from pywxdump.common.config.oss_config.storage_config import StorageConfig +from pywxdump.file.ConfigurableAttachment import ConfigurableAttachment + + +def singleton(cls): + instances = {} + + def get_instance(*args, **kwargs): + if cls not in instances: + instances[cls] = cls(*args, **kwargs) + return instances[cls] + + return get_instance + + +@singleton +class OSSConfigManager: + def __init__(self): + self._config_instances = {} + self._attachment_instance = {} + + def load_config(self, config: StorageConfig): + config.validate_config() + self._config_instances[config.type()] = config + # 清除旧的实例 + self._attachment_instance[config.type()] = None + + def get_config(self, config_type: str) -> StorageConfig: + return self._config_instances.get(config_type) + + def get_attachment(self, config_type: str, instance_class: Type[ConfigurableAttachment]): + if config_type not in self._config_instances: + raise ValueError(f"Config not found: {config_type}") + if not self._attachment_instance[config_type]: + config = self._config_instances[config_type] + self._attachment_instance[config_type] = instance_class.load_config(config) + return self._attachment_instance[config_type] diff --git a/pywxdump/common/config/server_config.py b/pywxdump/common/config/server_config.py new file mode 100644 index 0000000..9cc1956 --- /dev/null +++ b/pywxdump/common/config/server_config.py @@ -0,0 +1,97 @@ +import json +from dataclasses import dataclass + +from pywxdump.common.config.oss_config.storage_config import StorageConfig + + +@dataclass +class ServerConfig: + """ + :param merge_path: 合并后的数据库路径 默认"" + :param wx_path: 微信文件夹的路径(用于显示图片) 默认"" + :param key: 密钥 默认"" + :param my_wxid: 微信账号(本人微信id) 默认"" + :param port: 端口号 默认5000 + :param online: 是否在线查看(局域网查看) 默认 False + :param debug: 是否开启debug模式 默认 False + :param is_open_browser: 是否自动打开浏览器 默认 True + :param oss_config: 对象存储配置 默认 None + """ + merge_path: str = "" + wx_path: str = "" + key: str = "" + my_wxid: str = "" + port: int = 5000 + online: bool = False + debug: bool = False + is_open_browser: bool = True + oss_config: dict = None + + @classmethod + def builder(cls): + return ServerConfig.Builder() + + class Builder: + def __init__(self): + self._merge_path = "" + self._wx_path = "" + self._key = "" + self._my_wxid = "" + self._port = 5000 + self._online = False + self._debug = False + self._is_open_browser = True + self._oss_config = None + + def merge_path(self, merge_path: str): + self._merge_path = merge_path + return self + + def wx_path(self, wx_path: str): + self._wx_path = wx_path + return self + + def key(self, key: str): + self._key = key + return self + + def my_wxid(self, my_wxid: str): + self._my_wxid = my_wxid + return self + + def port(self, port: int): + self._port = port + return self + + def online(self, online: bool): + self._online = online + return self + + def debug(self, debug: bool): + self._debug = debug + return self + + def is_open_browser(self, is_open_browser: bool): + self._is_open_browser = is_open_browser + return self + + def oss_config(self, oss_config: StorageConfig): + oss_config.validate_config() + self._oss_config = oss_config.get_config() + return self + + def build(self): + return ServerConfig( + merge_path=self._merge_path, + wx_path=self._wx_path, + key=self._key, + my_wxid=self._my_wxid, + port=self._port, + online=self._online, + debug=self._debug, + is_open_browser=self._is_open_browser, + oss_config=self._oss_config + ) + + def oss_config_to_json(self) -> str: + return json.dumps(self.oss_config) if self.oss_config else None diff --git a/pywxdump/dbpreprocess/dbbase.py b/pywxdump/dbpreprocess/dbbase.py index 40b6ef9..b334057 100644 --- a/pywxdump/dbpreprocess/dbbase.py +++ b/pywxdump/dbpreprocess/dbbase.py @@ -38,6 +38,7 @@ class DatabaseBase: if not AttachmentContext.isLocalPath(db_path): temp_dir = tempfile.gettempdir() local_path = os.path.join(temp_dir, f"{uuid.uuid1()}.db") + logging.info(f"下载文件到本地: {db_path} -> {local_path}") AttachmentContext.download_file(db_path, local_path) else: local_path = db_path @@ -88,7 +89,23 @@ class DatabaseBase: self._db_connection = None if not AttachmentContext.isLocalPath(self._db_path): # 删除tmp目录下的db文件 - self.clearTmpDb() + self._clearTmpDb() + + @classmethod + def terminate_connection(cls, db_path: str): + """ + 关闭数据库连接 + + :param db_path: 数据库路径 + + """ + if db_path in cls._connection_pool and cls._connection_pool[db_path]: + cls._connection_pool[db_path].close() + logging.info(f"关闭数据库 - {db_path}") + cls._connection_pool[db_path] = None + if not AttachmentContext.isLocalPath(db_path): + # 删除tmp目录下的db文件 + cls._clearTmpDb() def close_all_connection(self): for db_path in self._connection_pool: @@ -97,18 +114,18 @@ class DatabaseBase: logging.info(f"关闭数据库 - {db_path}") self._connection_pool[db_path] = None # 删除tmp目录下的db文件 - self.clearTmpDb() - def clearTmpDb(self): - # 清理 tmp目录下.db文件 + self._clearTmpDb() + + @staticmethod + def _clearTmpDb(): + # 清理 tmp目录下.db文件,如果db文件存储在对象存储中,使用时需要下载到tmp目录中,且文件名是由uuid生成,为了避免文件过多,需要清理 temp_dir = tempfile.gettempdir() db_files = glob.glob(os.path.join(temp_dir, '*.db')) for db_file in db_files: try: os.remove(db_file) - print(f"Deleted: {db_file}") except Exception as e: - print(f"Error deleting {db_file}: {e}") - + logging.error(f"Error deleting {db_file}: {e}") def show__singleton_instances(self): print(self._singleton_instances) diff --git a/pywxdump/file/Attachment.py b/pywxdump/file/Attachment.py new file mode 100644 index 0000000..489ad70 --- /dev/null +++ b/pywxdump/file/Attachment.py @@ -0,0 +1,123 @@ +from typing import Protocol, IO + + +# 基类 +class Attachment(Protocol): + """ + 附件处理协议类,定义了附件处理的基本接口。 + """ + + def exists(self, path: str) -> bool: + """ + 检查文件或目录是否存在。 + + 参数: + path (str): 文件或目录路径。 + + 返回: + bool: 如果存在返回True,否则返回False。 + """ + pass + + def makedirs(self, path: str) -> bool: + """ + 创建目录,包括所有中间目录。 + + 参数: + path (str): 目录路径。 + + 返回: + bool: 总是返回True。 + """ + pass + + def open(self, path: str, mode: str) -> IO: + """ + 打开一个文件并返回文件对象。 + + 参数: + path (str): 文件路径。 + mode (str): 打开文件的模式。 + + 返回: + IO: 文件对象。 + """ + pass + + def remove(self, path: str) -> bool: + """ + 删除文件 + + 参数: + path (str): 文件路径 + + 返回: + bool: 是否删除成功 + """ + pass + + def isdir(self, path: str) -> bool: + + """ + 判断是否为目录 + + 参数: + s3_url (str): 文件路径 + + 返回: + bool: 是否为目录 + """ + pass + @classmethod + def join(cls, path: str, *paths: str) -> str: + """ + 连接一个或多个路径组件。 + + 参数: + path (str): 第一个路径组件。 + *paths (str): 其他路径组件。 + + 返回: + str: 连接后的路径。 + """ + pass + + @classmethod + def dirname(cls, path: str) -> str: + """ + 获取路径的目录名。 + + 参数: + path (str): 文件路径。 + + 返回: + str: 目录名。 + """ + pass + + @classmethod + def basename(cls, path: str) -> str: + """ + 获取路径的基本名(文件名)。 + + 参数: + path (str): 文件路径。 + + 返回: + str: 基本名(文件名)。 + """ + pass + + def getsize(self, path) -> int: + """ + 获取文件大小 + + 参数: + path (str): 文件路径 + + 返回: + int: 文件大小 + """ + pass + + diff --git a/pywxdump/file/AttachmentAbstract.py b/pywxdump/file/AttachmentAbstract.py deleted file mode 100644 index 811202a..0000000 --- a/pywxdump/file/AttachmentAbstract.py +++ /dev/null @@ -1,26 +0,0 @@ -from typing import Protocol, IO - - -# 基类 -class Attachment(Protocol): - - def exists(self, path) -> bool: - pass - - def makedirs(self, path) -> bool: - pass - - def open(self, path, param) -> IO: - pass - - @classmethod - def join(cls, __a: str, *paths: str) -> str: - pass - - @classmethod - def dirname(cls, path: str) -> str: - pass - - @classmethod - def basename(cls, path: str) -> str: - pass diff --git a/pywxdump/file/AttachmentContext.py b/pywxdump/file/AttachmentContext.py index 3556c2a..a1f3b39 100644 --- a/pywxdump/file/AttachmentContext.py +++ b/pywxdump/file/AttachmentContext.py @@ -1,41 +1,108 @@ import os from datetime import datetime -from typing import AnyStr, BinaryIO, Callable, Union, IO +from typing import AnyStr, Callable, Union, IO from flask import send_file, Response -from pywxdump.file.AttachmentAbstract import Attachment +from pywxdump.common.config.oss_config.s3_config import S3Config +from pywxdump.common.config.oss_config_manager import OSSConfigManager +from pywxdump.file.Attachment import Attachment from pywxdump.file.LocalAttachment import LocalAttachment from pywxdump.file.S3Attachment import S3Attachment def determine_strategy(file_path: str) -> Attachment: - if file_path.startswith("s3://"): - return S3Attachment() + """ + 根据文件路径确定使用的附件策略(本地或S3)。 + + 参数: + file_path (str): 文件路径。 + + 返回: + Attachment: 返回对应的附件策略类实例。 + """ + if file_path.startswith(f"s3://"): + return OSSConfigManager().get_attachment("s3", S3Attachment) else: return LocalAttachment() def exists(path: str) -> bool: + """ + 检查文件或目录是否存在。 + + 参数: + path (str): 文件或目录路径。 + + 返回: + bool: 如果存在返回True,否则返回False。 + """ return determine_strategy(path).exists(path) def open_file(path: str, mode: str) -> IO: + """ + 打开一个文件并返回文件对象。 + + 参数: + path (str): 文件路径。 + mode (str): 打开文件的模式。 + + 返回: + IO: 文件对象。 + """ return determine_strategy(path).open(path, mode) def makedirs(path: str) -> bool: + """ + 创建目录,包括所有中间目录。 + + 参数: + path (str): 目录路径。 + + 返回: + bool: 总是返回True。 + """ return determine_strategy(path).makedirs(path) -def join(__a: str, *paths: str) -> str: - return determine_strategy(__a).join(__a, *paths) +def join(path: str, *paths: str) -> str: + """ + 连接一个或多个路径组件。 + + 参数: + path (str): 第一个路径组件。 + *paths (str): 其他路径组件。 + + 返回: + str: 连接后的路径。 + """ + return determine_strategy(path).join(path, *paths) def dirname(path: str) -> str: + """ + 获取路径的目录名。 + + 参数: + path (str): 文件路径。 + + 返回: + str: 目录名。 + """ return determine_strategy(path).dirname(path) def basename(path: str) -> str: + """ + 获取路径的基本名(文件名)。 + + 参数: + path (str): 文件路径。 + + 返回: + str: 基本名(文件名)。 + """ return determine_strategy(path).basename(path) @@ -49,6 +116,22 @@ def send_attachment( last_modified: Union[datetime, int, float, None] = None, max_age: Union[None, int, Callable[[Union[str, None]], Union[int, None]]] = None, ) -> Response: + """ + 发送附件文件。 + + 参数: + path_or_file (Union[os.PathLike[AnyStr], str]): 文件路径或文件对象。 + mimetype (Union[str, None]): 文件的MIME类型。 + as_attachment (bool): 是否作为附件下载。 + download_name (Union[str, None]): 下载时的文件名。 + conditional (bool): 是否使用条件请求。 + etag (Union[bool, str]): ETag值。 + last_modified (Union[datetime, int, float, None]): 最后修改时间。 + max_age (Union[None, int, Callable[[Union[str, None]], Union[int, None]]]): 缓存最大时间。 + + 返回: + Response: Flask的响应对象。 + """ file_io = open_file(path_or_file, "rb") # 如果没有提供 download_name 或 mimetype,则从 path_or_file 中获取文件名和 MIME 类型 @@ -61,6 +144,16 @@ def send_attachment( def download_file(db_path, local_path): + """ + 从db_path下载文件到local_path。 + + 参数: + db_path (str): 数据库文件路径。 + local_path (str): 本地文件路径。 + + 返回: + str: 本地文件路径。 + """ with open(local_path, 'wb') as f: with open_file(db_path, 'rb') as r: f.write(r.read()) @@ -68,5 +161,28 @@ def download_file(db_path, local_path): def isLocalPath(path: str) -> bool: - return isinstance(determine_strategy(path), LocalAttachment) + """ + 判断路径是否为本地路径。 + + 参数: + path (str): 文件或目录路径。 + + 返回: + bool: 如果是本地路径返回True,否则返回False。 + """ + strategy = determine_strategy(path) + return isinstance(strategy, type(LocalAttachment())) + + +def getsize(path: str): + """ + 获取文件大小 + + 参数: + path (str): 文件路径 + + 返回: + int: 文件大小 + """ + return determine_strategy(path).getsize(path) diff --git a/pywxdump/file/ConfigurableAttachment.py b/pywxdump/file/ConfigurableAttachment.py new file mode 100644 index 0000000..b4be326 --- /dev/null +++ b/pywxdump/file/ConfigurableAttachment.py @@ -0,0 +1,12 @@ +from abc import ABC, abstractmethod + +from pywxdump.common.config.oss_config import storage_config +from pywxdump.file.Attachment import Attachment + + +class ConfigurableAttachment(ABC, Attachment): + @classmethod + @abstractmethod + def load_config(cls, config: storage_config) -> Attachment: + """设置配置""" + pass diff --git a/pywxdump/file/LocalAttachment.py b/pywxdump/file/LocalAttachment.py index 3705b45..4fc97cd 100644 --- a/pywxdump/file/LocalAttachment.py +++ b/pywxdump/file/LocalAttachment.py @@ -3,37 +3,139 @@ import os import sys from typing import IO +from pywxdump.file.Attachment import Attachment -class LocalAttachment: + +def singleton(cls): + instances = {} + + def create_instance(*args, **kwargs): + if cls not in instances: + instances[cls] = cls(*args, **kwargs) + return instances[cls] + + return create_instance + + +@singleton +class LocalAttachment(Attachment): def open(self, path, mode) -> IO: + """ + 打开一个文件并返回文件对象。 + + 参数: + path (str): 文件路径。 + mode (str): 打开文件的模式。 + + 返回: + IO: 文件对象。 + """ path = self.dealLocalPath(path) return open(path, mode) + def remove(self, path: str) -> bool: + """ + 删除文件 + + 参数: + path (str): 文件路径 + + 返回: + bool: 是否删除成功 + """ + path = self.dealLocalPath(path) + if not self.exists(path): + raise FileNotFoundError(f"File not found: {path}") + + if self.isdir(path): + raise ValueError(f"Path is not a file: {path}") + + os.remove(path) + return True + def exists(self, path) -> bool: + """ + 检查文件或目录是否存在。 + + 参数: + path (str): 文件或目录路径。 + + 返回: + bool: 如果存在返回True,否则返回False。 + """ path = self.dealLocalPath(path) return os.path.exists(path) def makedirs(self, path) -> bool: + """ + 创建目录,包括所有中间目录。 + + 参数: + path (str): 目录路径。 + + 返回: + bool: 总是返回True。 + """ path = self.dealLocalPath(path) os.makedirs(path) return True @classmethod - def join(cls, __a: str, *paths: str) -> str: - return os.path.join(__a, *paths) + def join(cls, path: str, *paths: str) -> str: + """ + 连接一个或多个路径组件。 + + 参数: + path (str): 第一个路径组件。 + *paths (str): 其他路径组件。 + + 返回: + str: 连接后的路径。 + """ + # 使用os.path.join连接路径 + return os.path.join(path, *paths) @classmethod def dirname(cls, path: str) -> str: + """ + 获取路径的目录名。 + + 参数: + path (str): 文件路径。 + + 返回: + str: 目录名。 + """ + # 获取路径的目录名 return os.path.dirname(path) @classmethod def basename(cls, path: str) -> str: + """ + 获取路径的基本名(文件名)。 + + 参数: + path (str): 文件路径。 + + 返回: + str: 基本名(文件名)。 + """ + # 获取路径的基本名 return os.path.basename(path) def dealLocalPath(self, path: str) -> str: - # 获取当前系统的地址分隔符 - # 将path中的 /替换为当前系统的分隔符 + """ + 处理本地路径,替换路径中的分隔符,并根据操作系统进行特殊处理。 + + 参数: + path (str): 文件路径。 + + 返回: + str: 处理后的路径。 + """ + # 获取当前系统的路径分隔符 + # 将path中的 / 替换为当前系统的分隔符 path = path.replace('/', os.sep) if sys.platform == "win32": # 如果是windows系统,且路径长度超过260个字符 @@ -44,3 +146,52 @@ class LocalAttachment: return path else: return path + + def isdir(self, path: str) -> bool: + + """ + 判断是否为目录 + + 参数: + path (str): 文件路径 + + 返回: + bool: 是否为目录 + """ + # 判断路径是否为目录 + return os.path.isdir(path) + + def getsize(self, path) -> int: + """ + 获取文件大小 + + 参数: + path (str): 文件路径 + + 返回: + int: 文件大小 + """ + if not self.exists(path): + raise FileNotFoundError(f"File not found: {path}") + + if os.path.isfile(path): + return os.path.getsize(path) + else: + return self._get_dir_size(path) + + def _get_dir_size(self, path): + """ + 计算目录大小 + + 参数: + path (str): 目录路径 + + 返回: + int: 目录大小 + """ + total_size = 0 + for firePath, surnames, filenames in os.walk(path): + for f in filenames: + fp = self.join(firePath, f) + total_size += os.path.getsize(fp) + return total_size diff --git a/pywxdump/file/S3Attachment.py b/pywxdump/file/S3Attachment.py index d117981..e121f7f 100644 --- a/pywxdump/file/S3Attachment.py +++ b/pywxdump/file/S3Attachment.py @@ -1,89 +1,165 @@ # 对象存储文件处理类(示例:假设是 AWS S3) import os from typing import IO -from urllib.parse import urlparse +from urllib.parse import urlparse, urljoin from botocore.exceptions import ClientError from smart_open import open import boto3 from botocore.client import Config -class S3Attachment: +from pywxdump.common.config.oss_config import storage_config +from pywxdump.file.Attachment import Attachment +from pywxdump.file.ConfigurableAttachment import ConfigurableAttachment - def __init__(self): - # 腾讯云 COS 配置 - self.cos_endpoint = "https://cos..myqcloud.com" # 替换 为你的 COS 区域,例如 ap-shanghai - self.access_key_id = "SecretId" # 替换为你的腾讯云 SecretId - self.secret_access_key = "SecretKey" # 替换为你的腾讯云 SecretKey + +class S3Attachment(ConfigurableAttachment): + + def __init__(self, s3_config: storage_config): + # S3 配置 + self.s3_config = s3_config + # 校验配置 + s3_config.validate_config() # 创建 S3 客户端 self.s3_client = boto3.client( 's3', - endpoint_url=self.cos_endpoint, - aws_access_key_id=self.access_key_id, - aws_secret_access_key=self.secret_access_key, + endpoint_url=s3_config.endpoint_url, + aws_access_key_id=s3_config.access_key, + aws_secret_access_key=s3_config.secret_key, config=Config(s3={"addressing_style": "virtual", "signature_version": 's3v4'}) ) - def exists(self, path) -> bool: - bucket_name, path = self.dealS3Url(path) - # 检查是否为目录 - if path.endswith('/'): - # 尝试列出该路径下的对象 - try: - response = self.s3_client.list_objects_v2(Bucket=bucket_name, Prefix=path, MaxKeys=1) - if 'Contents' in response: - return True - else: - return False - except ClientError as e: - print(f"Error: {e}") - return False - else: - # 检查是否为文件 - try: - self.s3_client.head_object(Bucket=bucket_name, Key=path) - return True - except ClientError as e: - if e.response['Error']['Code'] == '404': - return False - else: - print(f"Error: {e}") - return False + @classmethod + def load_config(cls, config: storage_config) -> Attachment: + return cls(config) - def makedirs(self, path) -> bool: - if not self.exists(path): - bucket_name, path = self.dealS3Url(path) + def exists(self, s3_url) -> bool: + """ + 检查对象是否存在 + + 参数: + s3_url (str): 对象路径 + + 返回: + bool: 是否存在 + """ + bucket_name, path = self.dealS3Url(s3_url) + # 尝试列出该路径下的对象 + try: + response = self.s3_client.list_objects_v2(Bucket=bucket_name, Prefix=path, MaxKeys=1) + if 'Contents' in response: + return True + else: + return False + except ClientError as e: + print(f"Error: {e}") + return False + + def makedirs(self, s3_url) -> bool: + """ + 创建目录 + + 参数: + s3_url (str): 目录路径 + + 返回: + bool: 是否创建成功 + """ + if not self.exists(s3_url): + bucket_name, path = self.dealS3Url(s3_url) self.s3_client.put_object(Bucket=bucket_name, Key=f'{path}/') return True - def open(self, path, mode) -> IO: - self.dealS3Url(path) - return open(uri=path, mode=mode, transport_params={'client': self.s3_client}) + def open(self, s3_url, mode) -> IO: + """ + 打开文件 + + 参数: + s3_url (str): 文件路径 + mode (str): 打开模式 + + 返回: + IO: 文件对象 + """ + return open(uri=s3_url, mode=mode, transport_params={'client': self.s3_client}) + + def remove(self, s3_url: str) -> bool: + """ + 删除文件 + + 参数: + s3_url (str): 文件路径 + + 返回: + bool: 是否删除成功 + """ + + if not self.exists(s3_url): + raise FileNotFoundError(f"File not found: {s3_url}") + + if self.isdir(s3_url): + raise ValueError(f"Path is not a file: {s3_url}") + + bucket_name, path = self.dealS3Url(s3_url) + + self.s3_client.delete_object(Bucket=bucket_name, Key=path) + return True @classmethod - def join(cls, __a: str, *paths: str) -> str: - return os.path.join(__a, *paths) + def join(cls, s3_url: str, *paths: str) -> str: + """ + 连接路径 + + 参数: + s3_url (str): 路径 + *paths (str): 路径 + + 返回: + str: 连接后的路径 + """ + # 使用os.path.join连接路径 + path = os.path.join(s3_url, *paths) + # 将所有反斜杠替换为正斜杠 + return path.replace('\\', '/') @classmethod - def dirname(cls, path: str) -> str: - return os.path.dirname(path) + def dirname(cls, s3_url: str) -> str: + """ + 返回路径的目录部分 + + 参数: + s3_url (str): 路径 + + 返回: + str: 路径的目录部分 + """ + return os.path.dirname(s3_url) @classmethod - def basename(cls, path: str) -> str: - return os.path.basename(path) + def basename(cls, s3_url: str) -> str: + """ + 返回路径的最后一个元素 - def dealS3Url(self, path: str) -> object: + 参数: + s3_url (str): 路径 + + 返回: + str: 路径的最后一个元素 + """ + return os.path.basename(s3_url) + + def dealS3Url(self, s3_url: str) -> object: """ 解析 S3 URL 并返回存储桶名称和路径 参数: - path (str): S3 URL + s3_url (str): S3 URL 返回: tuple: 包含存储桶名称和路径的元组 """ - parsed_url = urlparse(path) + parsed_url = urlparse(s3_url) # 确保URL是S3 URL if parsed_url.scheme != 's3': @@ -94,3 +170,75 @@ class S3Attachment: return bucket_name, s3_path + def isdir(self, s3_url: str) -> bool: + + """ + 判断是否为目录 + + 参数: + s3_url (str): 文件路径 + + 返回: + bool: 是否为目录 + """ + + # 确保目录路径以'/'结尾 + if not s3_url.endswith('/'): + s3_url += '/' + + bucket_name, path = self.dealS3Url(s3_url) + # 列出以该 key 为前缀的对象 + response = self.s3_client.list_objects_v2(Bucket=bucket_name, Prefix=path, MaxKeys=1) + + if 'Contents' in response: + # 存在对象,判断是否为目录 + if response['Contents'][0]['Key'] == path or not path.endswith('/'): + return False + else: + return True + else: + return False + + def getsize(self, s3_url) -> int: + """ + 获取文件大小 + + 参数: + path (str): 文件路径 + + 返回: + int: 文件大小 + """ + if not self.exists(s3_url): + raise FileNotFoundError(f"File not found: {s3_url}") + if self.isdir(s3_url): + return self._get_size_of_directory(s3_url) + else: + bucket_name, path = self.dealS3Url(s3_url) + response = self.s3_client.head_object(Bucket=bucket_name, Key=path) + return response['ContentLength'] + + def _get_size_of_directory(self, s3_url): + """ + 获取目录大小 + + 参数: + s3_url (str): 目录路径 + + 返回: + int: 目录大小 + """ + bucket_name, path = self.dealS3Url(s3_url) + total_size = 0 + + # 确保目录路径以'/'结尾 + if not path.endswith('/'): + path += '/' + + # 列出指定目录中的对象 + response = self.s3_client.list_objects_v2(Bucket=bucket_name, Prefix=path) + if 'Contents' in response: + for obj in response['Contents']: + total_size += obj['Size'] + + return total_size diff --git a/pywxdump/file/__init__.py b/pywxdump/file/__init__.py new file mode 100644 index 0000000..36213f2 --- /dev/null +++ b/pywxdump/file/__init__.py @@ -0,0 +1,4 @@ + + +if __name__ == '__main__': + pass diff --git a/pywxdump/server.py b/pywxdump/server.py index 87fd949..8dae044 100644 --- a/pywxdump/server.py +++ b/pywxdump/server.py @@ -10,9 +10,11 @@ import subprocess import sys import time +from pywxdump.common.config.oss_config.s3_config import S3Config +from pywxdump.common.config.server_config import ServerConfig -def start_falsk(merge_path="", wx_path="", key="", my_wxid="", port=5000, online=False, debug=False, - isopenBrowser=True): + +def start_falsk(server_config: ServerConfig): """ 启动flask :param merge_path: 合并后的数据库路径 @@ -38,13 +40,14 @@ def start_falsk(merge_path="", wx_path="", key="", my_wxid="", port=5000, online import logging # 检查端口是否被占用 - if online: + if server_config.online: host = '0.0.0.0' else: host = "127.0.0.1" app = Flask(__name__, template_folder='./ui/web', static_folder='./ui/web/assets/', static_url_path='/assets/') - + app.config['ENV'] = 'development' + app.config['DEBUG'] = True # 设置超时时间为 1000 秒 app.config['TIMEOUT'] = 1000 app.secret_key = 'secret_key' @@ -67,18 +70,20 @@ def start_falsk(merge_path="", wx_path="", key="", my_wxid="", port=5000, online g.tmp_path = tmp_path # 临时文件夹,用于存放图片等 g.sf = session_file # 用于存放各种基础信息 - if merge_path: save_session(session_file, "test", "merge_path", merge_path) - if wx_path: save_session(session_file, "test", "wx_path", wx_path) - if key: save_session(session_file, "test", "key", key) - if my_wxid: save_session(session_file, "test", "my_wxid", my_wxid) + wxid = server_config.my_wxid if server_config.my_wxid else "test" + if server_config.merge_path: save_session(session_file, wxid, "merge_path", server_config.merge_path) + if server_config.wx_path: save_session(session_file, wxid, "wx_path", server_config.wx_path) + if server_config.key: save_session(session_file, wxid, "key", server_config.key) + if server_config.my_wxid: save_session(session_file, wxid, "my_wxid", server_config.my_wxid) + if server_config.oss_config: save_session(session_file, wxid, "oss_config", server_config.oss_config_to_json()) if not os.path.exists(session_file): - save_session(session_file, "test", "last", my_wxid) + save_session(session_file, wxid, "last", server_config.my_wxid) app.register_blueprint(api) - if isopenBrowser: + if server_config.is_open_browser: try: # 自动打开浏览器 - url = f"http://127.0.0.1:{port}/" + url = f"http://127.0.0.1:{server_config.port}/" # 根据操作系统使用不同的命令打开默认浏览器 if sys.platform.startswith('darwin'): # macOS subprocess.call(['open', url]) @@ -100,13 +105,13 @@ def start_falsk(merge_path="", wx_path="", key="", my_wxid="", port=5000, online return True return False - if is_port_in_use(host, port): - print(f"Port {port} is already in use. Choose a different port.") + if is_port_in_use(host, server_config.port): + print(f"Port {server_config.port} is already in use. Choose a different port.") input("Press Enter to exit...") else: time.sleep(1) print("[+] 请使用浏览器访问 http://127.0.0.1:5000/ 查看聊天记录") - app.run(host=host, port=port, debug=debug) + app.run(host=host, port=server_config.port, debug=server_config.debug) if __name__ == '__main__': @@ -114,6 +119,16 @@ if __name__ == '__main__': wx_path = r"****" my_wxid = "****" + server_config = ServerConfig.builder() + server_config.merge_path("s3://*********-1256220500/*********/merge_all.db") + server_config.wx_path("s3://*********-1256220500/*********") + server_config.my_wxid("test") + server_config.port(9000) + server_config.online(True) + server_config.is_open_browser(False) + + s3Config = S3Config("AKIDaAjA*********I1kR4gFdv67v", "wlT2ldSBk*********Qh4fEev47", + "https://cos.ap-nanjing.myqcloud.com") + server_config.oss_config(s3Config) + start_falsk(server_config.build()) - start_falsk(merge_path=merge_path, wx_path=wx_path, my_wxid=my_wxid, - port=5000, online=False, debug=False, isopenBrowser=False) diff --git a/pywxdump/wx_info/get_wx_info.py b/pywxdump/wx_info/get_wx_info.py index 26c3d78..765b2ed 100644 --- a/pywxdump/wx_info/get_wx_info.py +++ b/pywxdump/wx_info/get_wx_info.py @@ -395,7 +395,7 @@ def get_wechat_db(require_list: Union[List[str], str] = "all", msg_dir: str = No def get_core_db(wx_path: str, db_type: list = None) -> [str]: """ 获取聊天消息核心数据库路径 - :param wx_path: 微信文件夹路径 eg:C:\*****\WeChat Files\wxid******* + :param wx_path: 微信文件夹路径 eg:C:\\*****\\WeChat Files\\wxid******* :param db_type: 数据库类型 eg: ["MSG", "MediaMSG", "MicroMsg"],三个中选择一个或多个 :return: 返回数据库路径 eg:["",""] """ diff --git a/pywxdump/wx_info/merge_db.py b/pywxdump/wx_info/merge_db.py index 4da0f57..40c214b 100644 --- a/pywxdump/wx_info/merge_db.py +++ b/pywxdump/wx_info/merge_db.py @@ -304,7 +304,7 @@ def decrypt_merge(wx_path, key, outpath="", CreateTime: int = 0, endCreateTime: bool, str): """ 解密合并数据库 msg.db, microMsg.db, media.db,注意:会删除原数据库 - :param wx_path: 微信路径 eg: C:\*******\WeChat Files\wxid_********* + :param wx_path: 微信路径 eg: C:\\*******\\WeChat Files\\wxid_********* :param key: 解密密钥 :return: (true,解密后的数据库路径) or (false,错误信息) """ @@ -430,7 +430,7 @@ def all_merge_real_time_db(key, wx_path, merge_path): 注:这是全量合并,会有可能产生重复数据,需要自行去重 :param key: 解密密钥 :param wx_path: 微信路径 - :param merge_path: 合并后的数据库路径 eg: C:\*******\WeChat Files\wxid_*********\merge.db + :param merge_path: 合并后的数据库路径 eg: C:\\*******\\WeChat Files\\wxid_*********\\merge.db :return: """ if not merge_path or not key or not wx_path or not wx_path: diff --git a/requirements.txt b/requirements.txt index 0d3c38f..ec782af 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,4 +16,5 @@ lxml flask_cors pandas smart_open[s3] -boto3 \ No newline at end of file +boto3 +botocore \ No newline at end of file diff --git a/tests/file/TestLocalAttachment.py b/tests/file/TestLocalAttachment.py new file mode 100644 index 0000000..eca71c5 --- /dev/null +++ b/tests/file/TestLocalAttachment.py @@ -0,0 +1,63 @@ +import os +import unittest +from unittest.mock import patch +from pywxdump.file.LocalAttachment import LocalAttachment +import tempfile + + +class TestLocalAttachment(unittest.TestCase): + + def setUp(self): + self.attachment = LocalAttachment() + self.test_file_path = self.attachment.join(tempfile.gettempdir(),"test.txt") + self.test_dir_path = self.attachment.join(tempfile.gettempdir(),"test_dir") + + def tearDown(self): + if os.path.exists(self.test_file_path): + os.remove(self.test_file_path) + if os.path.exists(self.test_dir_path): + for dirpath, dirnames, filenames in os.walk(self.test_dir_path): + for f in filenames: + fp = os.path.join(dirpath, f) + os.remove(fp) + os.rmdir(self.test_dir_path) + + def test_remove_existing_file(self): + with open(self.test_file_path, 'w') as f: + f.write("test") + self.assertTrue(self.attachment.remove(self.test_file_path)) + self.assertFalse(os.path.exists(self.test_file_path)) + + def test_remove_non_existing_file(self): + with self.assertRaises(FileNotFoundError): + self.attachment.remove(self.test_file_path) + + @patch('os.remove') + def test_remove_os_error(self, mock_remove): + mock_remove.side_effect = OSError + with self.assertRaises(OSError): + self.attachment.remove(self.test_file_path) + + def test_getsize_existing_file(self): + # 清理测试环境 + self.tearDown() + with open(self.test_file_path, "w") as f: + f.write("Hello, World!") + self.assertEqual(self.attachment.getsize(self.test_file_path), 13) + + def test_getsize_non_existing_file(self): + with self.assertRaises(FileNotFoundError): + self.attachment.getsize('non_existing_file') + + def test_getsize_existing_folder(self): + # 清理测试环境 + self.tearDown() + os.mkdir(self.test_dir_path) + with open(os.path.join(self.test_dir_path, "file1.txt"), "w") as f: + f.write("Hello, World!") + with open(os.path.join(self.test_dir_path, "file2.txt"), "w") as f: + f.write("Hello, World!") + self.assertEqual(self.attachment.getsize(self.test_dir_path), 26) + +if __name__ == '__main__': + unittest.main() diff --git a/tests/file/TestS3Attachment.py b/tests/file/TestS3Attachment.py new file mode 100644 index 0000000..015987c --- /dev/null +++ b/tests/file/TestS3Attachment.py @@ -0,0 +1,71 @@ +import unittest +from unittest.mock import patch, MagicMock +from botocore.exceptions import ClientError +import boto3 +from pywxdump.file.S3Attachment import S3Attachment + + +class TestS3Attachment(unittest.TestCase): + + @patch('boto3.client') + def setUp(self, mock_client): + self.mock_client = MagicMock() + mock_client.return_value = self.mock_client + s3_config = MagicMock() + self.attachment = S3Attachment(s3_config) + self.test_s3_url = "s3://test_bucket/test_file" + self.test_s3_dir = "s3://test_bucket/test_folder/" + + @patch.object(S3Attachment, 'exists') + @patch.object(S3Attachment, 'isFolder') + @patch('boto3.client') + def test_removal_of_existing_file(self, mock_client, mock_isFolder, mock_exists): + mock_exists.return_value = True + mock_isFolder.return_value = False + mock_client.return_value = MagicMock() + self.assertTrue(self.attachment.remove(self.test_s3_url)) + + @patch.object(S3Attachment, 'exists') + def test_removal_of_non_existing_file(self, mock_exists): + mock_exists.return_value = False + with self.assertRaises(FileNotFoundError): + self.attachment.remove(self.test_s3_url) + + @patch.object(S3Attachment, 'exists') + @patch.object(S3Attachment, 'isFolder') + def test_removal_of_folder_instead_of_file(self, mock_isFolder, mock_exists): + mock_exists.return_value = True + mock_isFolder.return_value = True + with self.assertRaises(ValueError): + self.attachment.remove(self.test_s3_url) + + @patch.object(S3Attachment, 'exists') + @patch.object(S3Attachment, 'isdir') + def test_removal_with_s3_error(self, mock_isdir, mock_exists): + mock_exists.return_value = True + mock_isdir.return_value = False + # 模拟 ClientError + error_response = {'Error': {'Code': 'InvalidRequest', 'Message': 'Some error message'}} + self.mock_client.delete_object.side_effect = ClientError(error_response, 'delete_object') + with self.assertRaises(ClientError): + self.attachment.remove(self.test_s3_url) + + @patch.object(S3Attachment, 'exists') + def test_getsize_existing_file(self, mock_exists): + mock_exists.return_value = True + self.mock_client.head_object.return_value = {'ContentLength': 100} + self.assertEqual(self.attachment.getsize(self.test_s3_url), 100) + + def test_getsize_non_existing_file(self): + with self.assertRaises(FileNotFoundError): + self.attachment.getsize('non_existing_file') + + @patch.object(S3Attachment, 'isdir') + def test_getsize_existing_folder(self, mock_isdir): + mock_isdir.return_value = True + self.mock_client.list_objects_v2.return_value = {'Contents': [{'Size': 100}, {'Size': 300}]} + self.assertEqual(self.attachment.getsize(self.test_s3_dir), 400) + + +if __name__ == '__main__': + unittest.main()