1.增加获取文件大小、文件删除方法,并补充注释

2.增加对象存储的配置实现
This commit is contained in:
cllcode 2024-07-13 21:56:42 +08:00
parent 2b3c3f9d1f
commit 4de1c26f25
22 changed files with 1222 additions and 150 deletions

View File

@ -5,14 +5,14 @@
# Author: xaoyaoo # Author: xaoyaoo
# Date: 2024/01/02 # Date: 2024/01/02
# ------------------------------------------------------------------------------- # -------------------------------------------------------------------------------
import base64
import json
import logging import logging
import os
import re
import time import time
import shutil import shutil
import sys import sys
from pywxdump.common.config.oss_config.storage_config_factory import StorageConfigFactory
from pywxdump.common.config.oss_config_manager import OSSConfigManager
pythoncom = __import__('pythoncom') if sys.platform == "win32" else None pythoncom = __import__('pythoncom') if sys.platform == "win32" else None
import pywxdump import pywxdump
from pywxdump.file import AttachmentContext from pywxdump.file import AttachmentContext
@ -55,8 +55,8 @@ def init_last():
是否初始化 是否初始化
:return: :return:
""" """
my_wxid = request.json.get("my_wxid", "") my_wxid_dict = request.json.get("my_wxid", {})
my_wxid = my_wxid.strip().strip("'").strip('"') if isinstance(my_wxid, str) else "" my_wxid = my_wxid_dict.get("wxid", "")
if not my_wxid: if not my_wxid:
my_wxid = read_session(g.sf, "test", "last") my_wxid = read_session(g.sf, "test", "last")
if my_wxid: if my_wxid:
@ -64,6 +64,9 @@ def init_last():
merge_path = read_session(g.sf, my_wxid, "merge_path") merge_path = read_session(g.sf, my_wxid, "merge_path")
wx_path = read_session(g.sf, my_wxid, "wx_path") wx_path = read_session(g.sf, my_wxid, "wx_path")
key = read_session(g.sf, my_wxid, "key") key = read_session(g.sf, my_wxid, "key")
# 如果有oss_config则设置对象存储配置
oss_config = read_session(g.sf, my_wxid, "oss_config")
ossConfig(oss_config)
rdata = { rdata = {
"merge_path": merge_path, "merge_path": merge_path,
"wx_path": wx_path, "wx_path": wx_path,
@ -96,11 +99,15 @@ def init_key():
return ReJson(1002, body=f"my_wxid is required: {my_wxid}") return ReJson(1002, body=f"my_wxid is required: {my_wxid}")
old_merge_save_path = read_session(g.sf, my_wxid, "merge_path") old_merge_save_path = read_session(g.sf, my_wxid, "merge_path")
if isinstance(old_merge_save_path, str) and old_merge_save_path and AttachmentContext.exists(old_merge_save_path): # 如果有oss_config则设置对象存储配置
pmsg = ParsingMSG(old_merge_save_path) oss_config = read_session(g.sf, my_wxid, "oss_config")
pmsg.close_all_connection() ossConfig(oss_config)
# 如果存在旧的连接则关闭连接
if isinstance(old_merge_save_path, str) and old_merge_save_path:
ParsingMSG.terminate_connection(old_merge_save_path)
out_path = AttachmentContext.join(g.tmp_path, "decrypted", my_wxid) if my_wxid else AttachmentContext.join(g.tmp_path, "decrypted") out_path = AttachmentContext.join(g.tmp_path, "decrypted", my_wxid) if my_wxid else AttachmentContext.join(
g.tmp_path, "decrypted")
# 检查文件夹中文件是否被占用 # 检查文件夹中文件是否被占用
if AttachmentContext.exists(out_path): if AttachmentContext.exists(out_path):
try: try:
@ -115,7 +122,7 @@ def init_key():
if code: if code:
# 移动merge_save_path到g.tmp_path/my_wxid # 移动merge_save_path到g.tmp_path/my_wxid
if not AttachmentContext.exists(AttachmentContext.join(g.tmp_path, my_wxid)): if not AttachmentContext.exists(AttachmentContext.join(g.tmp_path, my_wxid)):
os.makedirs(AttachmentContext.join(g.tmp_path, my_wxid)) AttachmentContext.makedirs(AttachmentContext.join(g.tmp_path, my_wxid))
merge_save_path_new = AttachmentContext.join(g.tmp_path, my_wxid, "merge_all.db") merge_save_path_new = AttachmentContext.join(g.tmp_path, my_wxid, "merge_all.db")
shutil.move(merge_save_path, str(merge_save_path_new)) shutil.move(merge_save_path, str(merge_save_path_new))
@ -144,6 +151,19 @@ def init_key():
return ReJson(2001, body=merge_save_path) return ReJson(2001, body=merge_save_path)
def ossConfig(oss_config: str):
    """
    Apply an object-storage configuration, if one was supplied.

    :param oss_config: object-storage configuration, passed unchanged to
        ``StorageConfigFactory.create``; a falsy value (empty string, None)
        means there is nothing to configure and the call is a no-op
    :return: None
    """
    if not oss_config:
        return
    # Build the typed configuration first, then hand it to the shared manager.
    parsed_config = StorageConfigFactory.create(oss_config)
    OSSConfigManager().load_config(parsed_config)
@api.route('/api/init_nokey', methods=["GET", 'POST']) @api.route('/api/init_nokey', methods=["GET", 'POST'])
@error9999 @error9999
def init_nokey(): def init_nokey():
@ -154,6 +174,9 @@ def init_nokey():
merge_path = request.json.get("merge_path", "").strip().strip("'").strip('"') merge_path = request.json.get("merge_path", "").strip().strip("'").strip('"')
wx_path = request.json.get("wx_path", "").strip().strip("'").strip('"') wx_path = request.json.get("wx_path", "").strip().strip("'").strip('"')
my_wxid = request.json.get("my_wxid", "").strip().strip("'").strip('"') my_wxid = request.json.get("my_wxid", "").strip().strip("'").strip('"')
# 如果有oss_config则设置对象存储配置
oss_config = request.json.get("oss_config", "").strip().strip("'").strip('"')
ossConfig(oss_config)
if not wx_path: if not wx_path:
return ReJson(1002, body=f"wx_path is required: {wx_path}") return ReJson(1002, body=f"wx_path is required: {wx_path}")
@ -170,6 +193,7 @@ def init_nokey():
save_session(g.sf, my_wxid, "wx_path", wx_path) save_session(g.sf, my_wxid, "wx_path", wx_path)
save_session(g.sf, my_wxid, "key", key) save_session(g.sf, my_wxid, "key", key)
save_session(g.sf, my_wxid, "my_wxid", my_wxid) save_session(g.sf, my_wxid, "my_wxid", my_wxid)
save_session(g.sf, my_wxid, "oss_config", oss_config)
save_session(g.sf, "test", "last", my_wxid) save_session(g.sf, "test", "last", my_wxid)
rdata = { rdata = {
"merge_path": merge_path, "merge_path": merge_path,
@ -181,6 +205,25 @@ def init_nokey():
return ReJson(0, rdata) return ReJson(0, rdata)
# Report whether a path belongs to a supported object store, and which settings that store needs
@api.route('/api/check_storage_type', methods=["GET", 'POST'])
@error9999
def check_storage_type():
    """
    Look up the registered storage configuration class that supports a path.

    Reads ``path`` from the JSON request body (surrounding whitespace and
    quotes are stripped). Responds with code 1002 when the path is empty.
    Otherwise responds with code 0 and either ``{"is_supported": False}`` or
    the matching storage type together with its described config items.
    """
    path = request.json.get("path", "").strip().strip("'").strip('"')
    if not path:
        return ReJson(1002, body=f"path is required: {path}")

    matched = None
    for candidate in StorageConfigFactory.registry.values():
        if candidate.isSupported(path):
            # No early exit: when several types match, the last one registered wins.
            matched = candidate

    if not matched:
        return ReJson(0, {"is_supported": False})
    return ReJson(0, {"is_supported": True, "storage_type": matched.type(),
                      "config_items": matched.describe()})
# END 以上为初始化相关 *************************************************************************************************** # END 以上为初始化相关 ***************************************************************************************************
@ -338,7 +381,7 @@ def get_imgsrc(imgsrc):
img_tmp_path = AttachmentContext.join(g.tmp_path, my_wxid, "imgsrc") img_tmp_path = AttachmentContext.join(g.tmp_path, my_wxid, "imgsrc")
if not AttachmentContext.exists(img_tmp_path): if not AttachmentContext.exists(img_tmp_path):
os.makedirs(img_tmp_path) AttachmentContext.makedirs(img_tmp_path)
file_name = imgsrc.replace("http://", "").replace("https://", "").replace("/", "_").replace("?", "_") file_name = imgsrc.replace("http://", "").replace("https://", "").replace("/", "_").replace("?", "_")
file_name = file_name + ".jpg" file_name = file_name + ".jpg"
# 如果文件名过长,则将文件明分为目录和文件名 # 如果文件名过长,则将文件明分为目录和文件名
@ -434,10 +477,10 @@ def get_video(videoPath):
# 复制文件到临时文件夹 # 复制文件到临时文件夹
video_save_path = AttachmentContext.join(video_tmp_path, videoPath) video_save_path = AttachmentContext.join(video_tmp_path, videoPath)
if not AttachmentContext.exists(AttachmentContext.dirname(video_save_path)): if not AttachmentContext.exists(AttachmentContext.dirname(video_save_path)):
os.makedirs(AttachmentContext.dirname(video_save_path)) AttachmentContext.makedirs(AttachmentContext.dirname(video_save_path))
if AttachmentContext.exists(video_save_path): if AttachmentContext.exists(video_save_path):
return AttachmentContext.send_attachment(video_save_path) return AttachmentContext.send_attachment(video_save_path)
AttachmentContext.download_file(original_img_path,video_save_path) AttachmentContext.download_file(original_img_path, video_save_path)
return AttachmentContext.send_attachment(original_img_path) return AttachmentContext.send_attachment(original_img_path)
@ -457,7 +500,7 @@ def get_audio(savePath):
# 判断savePath路径的文件夹是否存在 # 判断savePath路径的文件夹是否存在
if not AttachmentContext.exists(AttachmentContext.dirname(savePath)): if not AttachmentContext.exists(AttachmentContext.dirname(savePath)):
os.makedirs(AttachmentContext.dirname(savePath)) AttachmentContext.makedirs(AttachmentContext.dirname(savePath))
parsing_media_msg = ParsingMediaMSG(merge_path) parsing_media_msg = ParsingMediaMSG(merge_path)
wave_data = parsing_media_msg.get_audio(MsgSvrID, is_play=False, is_wave=True, save_path=savePath, rate=24000) wave_data = parsing_media_msg.get_audio(MsgSvrID, is_play=False, is_wave=True, save_path=savePath, rate=24000)
@ -484,8 +527,8 @@ def get_file_info():
all_file_path = AttachmentContext.join(wx_path, file_path) all_file_path = AttachmentContext.join(wx_path, file_path)
if not AttachmentContext.exists(all_file_path): if not AttachmentContext.exists(all_file_path):
return ReJson(5002) return ReJson(5002)
file_name = os.path.basename(all_file_path) file_name = AttachmentContext.basename(all_file_path)
file_size = os.path.getsize(all_file_path) file_size = AttachmentContext.getsize(all_file_path)
return ReJson(0, {"file_name": file_name, "file_size": str(file_size)}) return ReJson(0, {"file_name": file_name, "file_size": str(file_size)})
@ -528,12 +571,12 @@ def get_export_endb():
outpath = AttachmentContext.join(g.tmp_path, "export", my_wxid, "endb") outpath = AttachmentContext.join(g.tmp_path, "export", my_wxid, "endb")
if not AttachmentContext.exists(outpath): if not AttachmentContext.exists(outpath):
os.makedirs(outpath) AttachmentContext.makedirs(outpath)
for wxdb in wxdbpaths: for wxdb in wxdbpaths:
# 复制wxdb->outpath, os.path.basename(wxdb) # 复制wxdb->outpath, os.path.basename(wxdb)
assert isinstance(outpath, str) # 为了解决pycharm的警告, 无实际意义 assert isinstance(outpath, str) # 为了解决pycharm的警告, 无实际意义
shutil.copy(wxdb, AttachmentContext.join(outpath, os.path.basename(wxdb))) shutil.copy(wxdb, AttachmentContext.join(outpath, AttachmentContext.basename(wxdb)))
return ReJson(0, body=outpath) return ReJson(0, body=outpath)
@ -558,7 +601,7 @@ def get_export_dedb():
outpath = AttachmentContext.join(g.tmp_path, "export", my_wxid, "dedb") outpath = AttachmentContext.join(g.tmp_path, "export", my_wxid, "dedb")
if not AttachmentContext.exists(outpath): if not AttachmentContext.exists(outpath):
os.makedirs(outpath) AttachmentContext.makedirs(outpath)
code, merge_save_path = decrypt_merge(wx_path=wx_path, key=key, outpath=outpath) code, merge_save_path = decrypt_merge(wx_path=wx_path, key=key, outpath=outpath)
time.sleep(1) time.sleep(1)
@ -589,7 +632,7 @@ def get_export_csv():
outpath = AttachmentContext.join(g.tmp_path, "export", my_wxid, "csv", wxid) outpath = AttachmentContext.join(g.tmp_path, "export", my_wxid, "csv", wxid)
if not AttachmentContext.exists(outpath): if not AttachmentContext.exists(outpath):
os.makedirs(outpath) AttachmentContext.makedirs(outpath)
code, ret = export_csv(wxid, outpath, read_session(g.sf, my_wxid, "merge_path")) code, ret = export_csv(wxid, outpath, read_session(g.sf, my_wxid, "merge_path"))
if code: if code:
@ -613,7 +656,7 @@ def get_export_json():
outpath = AttachmentContext.join(g.tmp_path, "export", my_wxid, "json", wxid) outpath = AttachmentContext.join(g.tmp_path, "export", my_wxid, "json", wxid)
if not AttachmentContext.exists(outpath): if not AttachmentContext.exists(outpath):
os.makedirs(outpath) AttachmentContext.makedirs(outpath)
code, ret = export_json(wxid, outpath, read_session(g.sf, my_wxid, "merge_path")) code, ret = export_json(wxid, outpath, read_session(g.sf, my_wxid, "merge_path"))
if code: if code:

View File

@ -6,11 +6,18 @@
# Date: 2023/10/14 # Date: 2023/10/14
# ------------------------------------------------------------------------------- # -------------------------------------------------------------------------------
import argparse import argparse
import json
import os
import sys import sys
import time
from pywxdump import * from pywxdump import *
import pywxdump import pywxdump
from pywxdump.common.config.oss_config.storage_config import DescriptionBuilder
from pywxdump.common.config.oss_config.storage_config_factory import StorageConfigFactory
from pywxdump.common.config.oss_config_manager import OSSConfigManager
from pywxdump.common.config.server_config import ServerConfig
from pywxdump.common.constants import VERSION_LIST_PATH
from pywxdump.file import AttachmentContext
wxdump_ascii = r""" wxdump_ascii = r"""
@ -264,31 +271,63 @@ class MainShowChatRecords(BaseSubMainClass):
default="", metavar="") default="", metavar="")
parser.add_argument("--online", action='store_true', help="(可选)是否在线查看(局域网查看)", required=False, parser.add_argument("--online", action='store_true', help="(可选)是否在线查看(局域网查看)", required=False,
default=False) default=False)
for key, storageConfig in StorageConfigFactory.registry.items():
storageConfigDescribe = storageConfig.describe()
for description in storageConfigDescribe:
parser.add_argument(
f"--{key}_{description[DescriptionBuilder.KEY]}",
help=description[DescriptionBuilder.PLACEHOLDER],
default=False, required=False, type=str, metavar=""
)
# parser.add_argument("-k", "--key", type=str, help="(可选)密钥", required=False, metavar="") # parser.add_argument("-k", "--key", type=str, help="(可选)密钥", required=False, metavar="")
return parser return parser
def run(self, args): def run(self, args):
print(f"[*] PyWxDump v{pywxdump.__version__}") print(f"[*] PyWxDump v{pywxdump.__version__}")
server_config = ServerConfig.builder()
server_config.merge_path(args.merge_path)
server_config.wx_path(args.wx_path)
server_config.my_wxid(args.my_wxid)
server_config.online(args.online)
server_config.port(9000)
for key, storageConfig in StorageConfigFactory.registry.items():
storageConfigDescribe = storageConfig.describe()
config = {}
for description in storageConfigDescribe:
value = getattr(args, f"{key}_{description[DescriptionBuilder.KEY]}")
if value:
config[description[DescriptionBuilder.KEY]] = value
if config:
oss_config = storageConfig.value_of(config)
server_config.oss_config(oss_config)
# 加载对象存储配置
OSSConfigManager().load_config(oss_config)
start_config = server_config.build()
# 参数检查
if not self._dbshow_parameter_check(start_config):
return
start_falsk(start_config)
def _dbshow_parameter_check(self, server_config) -> bool:
# (merge)和(msg_path,micro_path,media_path) 二选一 # (merge)和(msg_path,micro_path,media_path) 二选一
# if not args.merge_path and not (args.msg_path and args.micro_path and args.media_path): # if not args.merge_path and not (args.msg_path and args.micro_path and args.media_path):
# print("[-] 请输入数据库路径([merge_path] or [msg_path, micro_path, media_path]") # print("[-] 请输入数据库路径([merge_path] or [msg_path, micro_path, media_path]")
# return # return
# 目前仅能支持merge database # 目前仅能支持merge database
if not args.merge_path: if not server_config.merge_path:
print("[-] 请输入数据库路径([merge_path]") print("[-] 请输入数据库路径([merge_path]")
return return False
# 从命令行参数获取值 # 从命令行参数获取值
merge_path = args.merge_path if not AttachmentContext.exists(server_config.merge_path):
online = args.online
if not os.path.exists(merge_path):
print("[-] 输入数据库路径不存在") print("[-] 输入数据库路径不存在")
return return False
return True
start_falsk(merge_path=merge_path, wx_path=args.wx_path, key="", my_wxid=args.my_wxid, online=online)
class MainExportChatRecords(BaseSubMainClass): class MainExportChatRecords(BaseSubMainClass):
@ -338,7 +377,13 @@ class MainUi(BaseSubMainClass):
debug = args.debug debug = args.debug
isopenBrowser = args.isOpenBrowser isopenBrowser = args.isOpenBrowser
start_falsk(port=port, online=online, debug=debug, isopenBrowser=isopenBrowser) server_config = ServerConfig.builder()
server_config.is_open_browser(isopenBrowser)
server_config.debug(debug)
server_config.port(port)
server_config.online(online)
start_falsk(server_config.build())
class MainApi(BaseSubMainClass): class MainApi(BaseSubMainClass):
@ -359,7 +404,13 @@ class MainApi(BaseSubMainClass):
port = args.port port = args.port
debug = args.debug debug = args.debug
start_falsk(port=port, online=online, debug=debug, isopenBrowser=False) server_config = ServerConfig.builder()
server_config.debug(debug)
server_config.port(port)
server_config.is_open_browser(False)
server_config.online(online)
start_falsk(server_config.build())
def console_run(): def console_run():

View File

@ -0,0 +1,4 @@
if __name__ == "__main__":
    # Intentionally empty: running this module directly does nothing.
    pass

View File

@ -0,0 +1,58 @@
from pywxdump.common.config.oss_config.storage_config import StorageConfig, TYPE_KEY, DescriptionBuilder
from pywxdump.common.config.oss_config.storage_config_factory import StorageConfigFactory
S3 = "s3"
ENDPOINT_URL = "endpoint_url"
SECRET_KEY = "secret_key"
ACCESS_KEY = "access_key"
@StorageConfigFactory.register(S3)
class S3Config(StorageConfig):
    """
    Settings for an S3-compatible object store.

    Holds an access key, a secret key and an endpoint URL, and is registered
    with ``StorageConfigFactory`` under the ``S3`` type name.
    """

    def __init__(self, access_key, secret_key, endpoint_url):
        """Store the three connection settings without validating them."""
        self.access_key = access_key
        self.secret_key = secret_key
        self.endpoint_url = endpoint_url

    @classmethod
    def type(cls) -> str:
        """Return the storage type identifier (``S3``)."""
        return S3

    @classmethod
    def describe(cls):
        """
        Return the list of field descriptions for this storage type.

        Each entry is built with ``add_description(label, placeholder, key)``;
        the key is also used as the label.
        """
        builder = DescriptionBuilder()
        # NOTE(review): "regin" in the placeholder looks like a typo for "region";
        # left untouched here because it is a user-facing runtime string.
        builder.add_description(ENDPOINT_URL, "https://cos.<your-regin>.myqcloud.com", ENDPOINT_URL)
        builder.add_description(ACCESS_KEY, "腾讯云的SecretId", ACCESS_KEY)
        builder.add_description(SECRET_KEY, "腾讯云的SecretKey", SECRET_KEY)
        return builder.build()

    def get_config(self):
        """Return the settings as a dict that also carries the type marker."""
        return {
            TYPE_KEY: S3,
            ACCESS_KEY: self.access_key,
            SECRET_KEY: self.secret_key,
            ENDPOINT_URL: self.endpoint_url
        }

    def validate_config(self):
        """
        Check that every setting is present.

        :raises ValueError: if the access key, secret key or endpoint URL is falsy
        """
        if not self.access_key or not self.secret_key or not self.endpoint_url:
            raise ValueError("S3 configuration is not valid")

    @classmethod
    def value_of(cls, config: dict):
        """
        Build and validate an instance from a settings dict.

        :param config: dict that may contain the access key, secret key and
            endpoint URL entries; missing entries become None
        :return: the validated ``S3Config``
        :raises ValueError: if validation fails
        """
        access_key = config.get(ACCESS_KEY)
        secret_key = config.get(SECRET_KEY)
        endpoint_url = config.get(ENDPOINT_URL)
        s3_config = cls(access_key, secret_key, endpoint_url)
        s3_config.validate_config()
        return s3_config

    @classmethod
    def isSupported(cls, path: str) -> bool:
        """
        Tell whether ``path`` uses this storage type's URL scheme.

        :param path: path to inspect; empty or None yields False
        :return: True only when the path starts with ``"<S3>://"``
        """
        # Always return a real bool: the earlier version fell off the end and
        # returned None for a non-empty path with a different scheme.
        return bool(path) and path.startswith(f"{S3}://")

View File

@ -0,0 +1,60 @@
from abc import ABC, abstractmethod
TYPE_KEY = "oss_type_key"
class StorageConfig(ABC):
    """Abstract contract that every object-storage configuration class implements."""

    @classmethod
    @abstractmethod
    def type(cls):
        """Return the type identifier of the implementing class."""

    @classmethod
    @abstractmethod
    def describe(cls):
        """Return the description of this object-storage type."""

    @abstractmethod
    def get_config(self):
        """Return the storage configuration as a dict."""

    @abstractmethod
    def validate_config(self):
        """Verify that the configuration is valid."""

    @classmethod
    @abstractmethod
    def value_of(cls, config: dict):
        """Build a configuration instance from the given dict."""

    @classmethod
    @abstractmethod
    def isSupported(cls, path: str) -> bool:
        """Tell whether the given path is handled by this storage type."""
class DescriptionBuilder:
    """Fluent collector of field descriptions, each a label / placeholder / key dict."""

    LABEL = "label"
    PLACEHOLDER = "placeholder"
    KEY = "key"

    def __init__(self):
        # Descriptions accumulate here in the order they were added.
        self.description = []

    def add_description(self, label, placeholder, key):
        """
        Append one field description.

        :param label: value stored under ``LABEL``
        :param placeholder: value stored under ``PLACEHOLDER``
        :param key: value stored under ``KEY``
        :return: this builder, so calls can be chained
        """
        entry = {}
        entry[self.LABEL] = label
        entry[self.PLACEHOLDER] = placeholder
        entry[self.KEY] = key
        self.description.append(entry)
        return self

    def build(self):
        """Return the accumulated list of descriptions (the live list, not a copy)."""
        return self.description

View File

@ -0,0 +1,25 @@
import json
from abc import ABC
from pywxdump.common.config.oss_config.storage_config import TYPE_KEY, StorageConfig
class StorageConfigFactory(ABC):
    """Registry of storage configuration classes plus a JSON-driven constructor."""

    # Maps a storage type name to the class registered for it.
    registry = {}

    @classmethod
    def register(cls, type_name):
        """
        Return a class decorator that registers the decorated class.

        :param type_name: key under which the class is stored in ``registry``
        :return: decorator that records the class and returns it unchanged
        """
        def inner_wrapper(subclass):
            cls.registry[type_name] = subclass
            return subclass

        return inner_wrapper

    @staticmethod
    def create(json_str) -> StorageConfig:
        """
        Build a storage configuration from its JSON representation.

        :param json_str: JSON text of an object whose ``TYPE_KEY`` entry names
            a registered storage type
        :return: whatever the registered class's ``value_of`` returns
        :raises ValueError: if the JSON is not an object, or names an unknown type
        """
        config_dict = json.loads(json_str)
        # Valid JSON may still be a list, string or number; reject it explicitly
        # instead of failing later with an AttributeError on ``.get``.
        if not isinstance(config_dict, dict):
            raise ValueError(
                f'Storage config must be a JSON object, got {type(config_dict).__name__}')
        config_type = config_dict.get(TYPE_KEY)
        subclass = StorageConfigFactory.registry.get(config_type)
        if subclass is None:
            raise ValueError(f'Unknown config type: {config_type}')
        return subclass.value_of(config_dict)

View File

@ -0,0 +1,39 @@
from typing import Type
from pywxdump.common.config.oss_config.storage_config import StorageConfig
from pywxdump.file.ConfigurableAttachment import ConfigurableAttachment
def singleton(cls):
    """
    Class decorator that makes every call return one shared instance.

    The first call constructs the instance with the arguments it was given;
    arguments passed to later calls are ignored.
    """
    instances = {}

    def get_instance(*args, **kwargs):
        if cls in instances:
            return instances[cls]
        instances[cls] = cls(*args, **kwargs)
        return instances[cls]

    return get_instance
@singleton
class OSSConfigManager:
    """Holds the loaded storage configurations and the attachment objects built from them."""

    def __init__(self):
        # Both dicts are keyed by the storage type name.
        self._config_instances = {}
        self._attachment_instance = {}

    def load_config(self, config: StorageConfig):
        """
        Validate and store a configuration, replacing any earlier one of the same type.

        :param config: configuration to register
        """
        config.validate_config()
        storage_type = config.type()
        self._config_instances[storage_type] = config
        # Drop the attachment built from the previous settings so it is rebuilt on demand.
        self._attachment_instance[storage_type] = None

    def get_config(self, config_type: str) -> StorageConfig:
        """Return the configuration stored for ``config_type``, or None if there is none."""
        return self._config_instances.get(config_type)

    def get_attachment(self, config_type: str, instance_class: Type[ConfigurableAttachment]):
        """
        Return the attachment object for a storage type, building it on first use.

        :param config_type: storage type name
        :param instance_class: class whose ``load_config`` builds the attachment
        :raises ValueError: if no configuration was loaded for ``config_type``
        """
        if config_type not in self._config_instances:
            raise ValueError(f"Config not found: {config_type}")

        cached = self._attachment_instance[config_type]
        if cached:
            return cached

        config = self._config_instances[config_type]
        built = instance_class.load_config(config)
        self._attachment_instance[config_type] = built
        return built

View File

@ -0,0 +1,97 @@
import json
from dataclasses import dataclass
from pywxdump.common.config.oss_config.storage_config import StorageConfig
@dataclass
class ServerConfig:
    """
    Start-up settings for the server.

    :param merge_path: path of the merged database, default ""
    :param wx_path: path of the WeChat folder, used to display images, default ""
    :param key: key, default ""
    :param my_wxid: WeChat account (the user's own wxid), default ""
    :param port: port number, default 5000
    :param online: whether online (LAN) viewing is enabled, default False
    :param debug: whether debug mode is enabled, default False
    :param is_open_browser: whether to open the browser automatically, default True
    :param oss_config: object-storage configuration, default None
    """
    merge_path: str = ""
    wx_path: str = ""
    key: str = ""
    my_wxid: str = ""
    port: int = 5000
    online: bool = False
    debug: bool = False
    is_open_browser: bool = True
    oss_config: dict = None

    @classmethod
    def builder(cls):
        """Return a new ``ServerConfig.Builder`` holding the default values."""
        return ServerConfig.Builder()

    class Builder:
        """Fluent builder: every setter stores one value and returns the builder."""

        def __init__(self):
            # Same defaults as the dataclass fields.
            self._merge_path = ""
            self._wx_path = ""
            self._key = ""
            self._my_wxid = ""
            self._port = 5000
            self._online = False
            self._debug = False
            self._is_open_browser = True
            self._oss_config = None

        def merge_path(self, merge_path: str):
            """Set the merged database path."""
            self._merge_path = merge_path
            return self

        def wx_path(self, wx_path: str):
            """Set the WeChat folder path."""
            self._wx_path = wx_path
            return self

        def key(self, key: str):
            """Set the key."""
            self._key = key
            return self

        def my_wxid(self, my_wxid: str):
            """Set the user's own wxid."""
            self._my_wxid = my_wxid
            return self

        def port(self, port: int):
            """Set the port number."""
            self._port = port
            return self

        def online(self, online: bool):
            """Set the online (LAN) viewing flag."""
            self._online = online
            return self

        def debug(self, debug: bool):
            """Set the debug flag."""
            self._debug = debug
            return self

        def is_open_browser(self, is_open_browser: bool):
            """Set whether the browser is opened automatically."""
            self._is_open_browser = is_open_browser
            return self

        def oss_config(self, oss_config: StorageConfig):
            """Validate a storage configuration and keep its ``get_config()`` result."""
            oss_config.validate_config()
            self._oss_config = oss_config.get_config()
            return self

        def build(self):
            """Create the ``ServerConfig`` from the values collected so far."""
            settings = {
                "merge_path": self._merge_path,
                "wx_path": self._wx_path,
                "key": self._key,
                "my_wxid": self._my_wxid,
                "port": self._port,
                "online": self._online,
                "debug": self._debug,
                "is_open_browser": self._is_open_browser,
                "oss_config": self._oss_config,
            }
            return ServerConfig(**settings)

    def oss_config_to_json(self) -> str:
        """Return ``oss_config`` as a JSON string, or None when it is unset or empty."""
        if self.oss_config:
            return json.dumps(self.oss_config)
        return None

View File

@ -38,6 +38,7 @@ class DatabaseBase:
if not AttachmentContext.isLocalPath(db_path): if not AttachmentContext.isLocalPath(db_path):
temp_dir = tempfile.gettempdir() temp_dir = tempfile.gettempdir()
local_path = os.path.join(temp_dir, f"{uuid.uuid1()}.db") local_path = os.path.join(temp_dir, f"{uuid.uuid1()}.db")
logging.info(f"下载文件到本地: {db_path} -> {local_path}")
AttachmentContext.download_file(db_path, local_path) AttachmentContext.download_file(db_path, local_path)
else: else:
local_path = db_path local_path = db_path
@ -88,7 +89,23 @@ class DatabaseBase:
self._db_connection = None self._db_connection = None
if not AttachmentContext.isLocalPath(self._db_path): if not AttachmentContext.isLocalPath(self._db_path):
# 删除tmp目录下的db文件 # 删除tmp目录下的db文件
self.clearTmpDb() self._clearTmpDb()
@classmethod
def terminate_connection(cls, db_path: str):
"""
Close the pooled connection for one database path.
If ``cls._connection_pool`` holds an open connection for ``db_path`` it is
closed and its pool slot is reset to None. For a path that
``AttachmentContext.isLocalPath`` reports as non-local, ``cls._clearTmpDb()``
is called as well.
:param db_path: database path
"""
if db_path in cls._connection_pool and cls._connection_pool[db_path]:
cls._connection_pool[db_path].close()
logging.info(f"关闭数据库 - {db_path}")
cls._connection_pool[db_path] = None
if not AttachmentContext.isLocalPath(db_path):
# Remove the .db files in the temp directory
cls._clearTmpDb()
def close_all_connection(self): def close_all_connection(self):
for db_path in self._connection_pool: for db_path in self._connection_pool:
@ -97,18 +114,18 @@ class DatabaseBase:
logging.info(f"关闭数据库 - {db_path}") logging.info(f"关闭数据库 - {db_path}")
self._connection_pool[db_path] = None self._connection_pool[db_path] = None
# 删除tmp目录下的db文件 # 删除tmp目录下的db文件
self.clearTmpDb() self._clearTmpDb()
def clearTmpDb(self):
# 清理 tmp目录下.db文件 @staticmethod
def _clearTmpDb():
# 清理 tmp目录下.db文件如果db文件存储在对象存储中使用时需要下载到tmp目录中且文件名是由uuid生成为了避免文件过多需要清理
temp_dir = tempfile.gettempdir() temp_dir = tempfile.gettempdir()
db_files = glob.glob(os.path.join(temp_dir, '*.db')) db_files = glob.glob(os.path.join(temp_dir, '*.db'))
for db_file in db_files: for db_file in db_files:
try: try:
os.remove(db_file) os.remove(db_file)
print(f"Deleted: {db_file}")
except Exception as e: except Exception as e:
print(f"Error deleting {db_file}: {e}") logging.error(f"Error deleting {db_file}: {e}")
def show__singleton_instances(self): def show__singleton_instances(self):
print(self._singleton_instances) print(self._singleton_instances)

123
pywxdump/file/Attachment.py Normal file
View File

@ -0,0 +1,123 @@
from typing import Protocol, IO
# 基类
class Attachment(Protocol):
    """
    Protocol describing the basic operations an attachment backend provides.
    """

    def exists(self, path: str) -> bool:
        """
        Check whether a file or directory exists.

        Args:
            path (str): file or directory path

        Returns:
            bool: True if it exists, otherwise False
        """
        ...

    def makedirs(self, path: str) -> bool:
        """
        Create a directory, including all intermediate directories.

        Args:
            path (str): directory path

        Returns:
            bool: always True
        """
        ...

    def open(self, path: str, mode: str) -> IO:
        """
        Open a file and return the file object.

        Args:
            path (str): file path
            mode (str): mode in which the file is opened

        Returns:
            IO: file object
        """
        ...

    def remove(self, path: str) -> bool:
        """
        Delete a file.

        Args:
            path (str): file path

        Returns:
            bool: whether the deletion succeeded
        """
        ...

    def isdir(self, path: str) -> bool:
        """
        Tell whether a path is a directory.

        Args:
            path (str): path to inspect

        Returns:
            bool: whether the path is a directory
        """
        ...

    @classmethod
    def join(cls, path: str, *paths: str) -> str:
        """
        Join one or more path components.

        Args:
            path (str): first path component
            *paths (str): remaining path components

        Returns:
            str: joined path
        """
        ...

    @classmethod
    def dirname(cls, path: str) -> str:
        """
        Return the directory part of a path.

        Args:
            path (str): file path

        Returns:
            str: directory name
        """
        ...

    @classmethod
    def basename(cls, path: str) -> str:
        """
        Return the base name (file name) of a path.

        Args:
            path (str): file path

        Returns:
            str: base name (file name)
        """
        ...

    def getsize(self, path) -> int:
        """
        Return the size of a file.

        Args:
            path (str): file path

        Returns:
            int: file size
        """
        ...

View File

@ -1,26 +0,0 @@
from typing import Protocol, IO
# 基类
class Attachment(Protocol):
def exists(self, path) -> bool:
pass
def makedirs(self, path) -> bool:
pass
def open(self, path, param) -> IO:
pass
@classmethod
def join(cls, __a: str, *paths: str) -> str:
pass
@classmethod
def dirname(cls, path: str) -> str:
pass
@classmethod
def basename(cls, path: str) -> str:
pass

View File

@ -1,41 +1,108 @@
import os import os
from datetime import datetime from datetime import datetime
from typing import AnyStr, BinaryIO, Callable, Union, IO from typing import AnyStr, Callable, Union, IO
from flask import send_file, Response from flask import send_file, Response
from pywxdump.file.AttachmentAbstract import Attachment from pywxdump.common.config.oss_config.s3_config import S3Config
from pywxdump.common.config.oss_config_manager import OSSConfigManager
from pywxdump.file.Attachment import Attachment
from pywxdump.file.LocalAttachment import LocalAttachment from pywxdump.file.LocalAttachment import LocalAttachment
from pywxdump.file.S3Attachment import S3Attachment from pywxdump.file.S3Attachment import S3Attachment
def determine_strategy(file_path: str) -> Attachment: def determine_strategy(file_path: str) -> Attachment:
if file_path.startswith("s3://"): """
return S3Attachment() 根据文件路径确定使用的附件策略本地或S3
参数:
file_path (str): 文件路径
返回:
Attachment: 返回对应的附件策略类实例
"""
if file_path.startswith(f"s3://"):
return OSSConfigManager().get_attachment("s3", S3Attachment)
else: else:
return LocalAttachment() return LocalAttachment()
def exists(path: str) -> bool: def exists(path: str) -> bool:
"""
检查文件或目录是否存在
参数:
path (str): 文件或目录路径
返回:
bool: 如果存在返回True否则返回False
"""
return determine_strategy(path).exists(path) return determine_strategy(path).exists(path)
def open_file(path: str, mode: str) -> IO: def open_file(path: str, mode: str) -> IO:
"""
打开一个文件并返回文件对象
参数:
path (str): 文件路径
mode (str): 打开文件的模式
返回:
IO: 文件对象
"""
return determine_strategy(path).open(path, mode) return determine_strategy(path).open(path, mode)
def makedirs(path: str) -> bool: def makedirs(path: str) -> bool:
"""
创建目录包括所有中间目录
参数:
path (str): 目录路径
返回:
bool: 总是返回True
"""
return determine_strategy(path).makedirs(path) return determine_strategy(path).makedirs(path)
def join(__a: str, *paths: str) -> str: def join(path: str, *paths: str) -> str:
return determine_strategy(__a).join(__a, *paths) """
连接一个或多个路径组件
参数:
path (str): 第一个路径组件
*paths (str): 其他路径组件
返回:
str: 连接后的路径
"""
return determine_strategy(path).join(path, *paths)
def dirname(path: str) -> str: def dirname(path: str) -> str:
"""
获取路径的目录名
参数:
path (str): 文件路径
返回:
str: 目录名
"""
return determine_strategy(path).dirname(path) return determine_strategy(path).dirname(path)
def basename(path: str) -> str: def basename(path: str) -> str:
"""
获取路径的基本名文件名
参数:
path (str): 文件路径
返回:
str: 基本名文件名
"""
return determine_strategy(path).basename(path) return determine_strategy(path).basename(path)
@ -49,6 +116,22 @@ def send_attachment(
last_modified: Union[datetime, int, float, None] = None, last_modified: Union[datetime, int, float, None] = None,
max_age: Union[None, int, Callable[[Union[str, None]], Union[int, None]]] = None, max_age: Union[None, int, Callable[[Union[str, None]], Union[int, None]]] = None,
) -> Response: ) -> Response:
"""
发送附件文件
参数:
path_or_file (Union[os.PathLike[AnyStr], str]): 文件路径或文件对象
mimetype (Union[str, None]): 文件的MIME类型
as_attachment (bool): 是否作为附件下载
download_name (Union[str, None]): 下载时的文件名
conditional (bool): 是否使用条件请求
etag (Union[bool, str]): ETag值
last_modified (Union[datetime, int, float, None]): 最后修改时间
max_age (Union[None, int, Callable[[Union[str, None]], Union[int, None]]]): 缓存最大时间
返回:
Response: Flask的响应对象
"""
file_io = open_file(path_or_file, "rb") file_io = open_file(path_or_file, "rb")
# 如果没有提供 download_name 或 mimetype则从 path_or_file 中获取文件名和 MIME 类型 # 如果没有提供 download_name 或 mimetype则从 path_or_file 中获取文件名和 MIME 类型
@ -61,6 +144,16 @@ def send_attachment(
def download_file(db_path, local_path): def download_file(db_path, local_path):
"""
从db_path下载文件到local_path
参数:
db_path (str): 数据库文件路径
local_path (str): 本地文件路径
返回:
str: 本地文件路径
"""
with open(local_path, 'wb') as f: with open(local_path, 'wb') as f:
with open_file(db_path, 'rb') as r: with open_file(db_path, 'rb') as r:
f.write(r.read()) f.write(r.read())
@ -68,5 +161,28 @@ def download_file(db_path, local_path):
def isLocalPath(path: str) -> bool: def isLocalPath(path: str) -> bool:
return isinstance(determine_strategy(path), LocalAttachment) """
判断路径是否为本地路径
参数:
path (str): 文件或目录路径
返回:
bool: 如果是本地路径返回True否则返回False
"""
strategy = determine_strategy(path)
return isinstance(strategy, type(LocalAttachment()))
def getsize(path: str):
    """
    Return the size of a file.

    Args:
        path (str): file path

    Returns:
        int: file size, as reported by the strategy chosen for this path
    """
    strategy = determine_strategy(path)
    return strategy.getsize(path)

View File

@ -0,0 +1,12 @@
from abc import ABC, abstractmethod
from pywxdump.common.config.oss_config import storage_config
from pywxdump.file.Attachment import Attachment
class ConfigurableAttachment(ABC, Attachment):
    """Attachment backend that is constructed from a storage configuration."""

    @classmethod
    @abstractmethod
    def load_config(cls, config: storage_config.StorageConfig) -> Attachment:
        """
        Build an attachment from the given configuration.

        The parameter is annotated with the ``StorageConfig`` class; the
        earlier annotation named the ``storage_config`` module itself.

        :param config: storage configuration to apply
        :return: the configured attachment
        """
        pass

View File

@ -3,37 +3,139 @@ import os
import sys import sys
from typing import IO from typing import IO
from pywxdump.file.Attachment import Attachment
class LocalAttachment:
def singleton(cls):
    """
    Class decorator that makes every call return one shared instance.

    The first call constructs the instance with the arguments it was given;
    arguments passed to later calls are ignored.
    """
    instances = {}

    def create_instance(*args, **kwargs):
        if cls in instances:
            return instances[cls]
        instances[cls] = cls(*args, **kwargs)
        return instances[cls]

    return create_instance
@singleton
class LocalAttachment(Attachment):
def open(self, path, mode) -> IO: def open(self, path, mode) -> IO:
"""
打开一个文件并返回文件对象
参数:
path (str): 文件路径
mode (str): 打开文件的模式
返回:
IO: 文件对象
"""
path = self.dealLocalPath(path) path = self.dealLocalPath(path)
return open(path, mode) return open(path, mode)
def remove(self, path: str) -> bool:
    """
    Delete a file.

    Args:
        path (str): file path

    Returns:
        bool: True once the file has been deleted

    Raises:
        FileNotFoundError: if the path does not exist
        ValueError: if the path is a directory
    """
    target = self.dealLocalPath(path)
    if self.exists(target):
        if self.isdir(target):
            raise ValueError(f"Path is not a file: {target}")
        os.remove(target)
        return True
    raise FileNotFoundError(f"File not found: {target}")
def exists(self, path) -> bool:
    """
    Check whether a file or directory exists.

    Args:
        path (str): File or directory path.

    Returns:
        bool: True if the normalised path exists, otherwise False.
    """
    normalized = self.dealLocalPath(path)
    return os.path.exists(normalized)
def makedirs(self, path) -> bool:
    """
    Create a directory together with any missing parent directories.

    The call is idempotent: an already existing directory is not an error.

    Args:
        path (str): Directory path.

    Returns:
        bool: Always True.
    """
    path = self.dealLocalPath(path)
    # exist_ok=True keeps the documented "always returns True" contract when
    # the directory is already present.
    os.makedirs(path, exist_ok=True)
    return True
@classmethod
def join(cls, path: str, *paths: str) -> str:
    """
    Join one or more path components with the local OS separator.

    Args:
        path (str): First path component.
        *paths (str): Further components.

    Returns:
        str: The joined path.
    """
    components = (path, *paths)
    return os.path.join(*components)
@classmethod
def dirname(cls, path: str) -> str:
    """
    Return the directory part of ``path``.

    Args:
        path (str): File path.

    Returns:
        str: Everything before the final path component.
    """
    head, _tail = os.path.split(path)
    return head
@classmethod
def basename(cls, path: str) -> str:
    """
    Return the final component (file name) of ``path``.

    Args:
        path (str): File path.

    Returns:
        str: The last path component.
    """
    _head, tail = os.path.split(path)
    return tail
def dealLocalPath(self, path: str) -> str: def dealLocalPath(self, path: str) -> str:
# 获取当前系统的地址分隔符 """
# 将path中的 /替换为当前系统的分隔符 处理本地路径替换路径中的分隔符并根据操作系统进行特殊处理
参数:
path (str): 文件路径
返回:
str: 处理后的路径
"""
# 获取当前系统的路径分隔符
# 将path中的 / 替换为当前系统的分隔符
path = path.replace('/', os.sep) path = path.replace('/', os.sep)
if sys.platform == "win32": if sys.platform == "win32":
# 如果是windows系统且路径长度超过260个字符 # 如果是windows系统且路径长度超过260个字符
@ -44,3 +146,52 @@ class LocalAttachment:
return path return path
else: else:
return path return path
def isdir(self, path: str) -> bool:
    """
    Tell whether ``path`` is an existing directory.

    Args:
        path (str): Path to test; it is used as given, without normalisation.

    Returns:
        bool: True if the path is a directory.
    """
    is_directory = os.path.isdir(path)
    return is_directory
def getsize(self, path) -> int:
    """
    Return the size in bytes of a file, or the total size of a directory tree.

    Args:
        path (str): File or directory path.

    Returns:
        int: Size in bytes.

    Raises:
        FileNotFoundError: The path does not exist.
    """
    if not self.exists(path):
        raise FileNotFoundError(f"File not found: {path}")
    if not os.path.isfile(path):
        # Anything that exists but is not a regular file is sized as a tree.
        return self._get_dir_size(path)
    return os.path.getsize(path)
def _get_dir_size(self, path):
"""
计算目录大小
参数:
path (str): 目录路径
返回:
int: 目录大小
"""
total_size = 0
for firePath, surnames, filenames in os.walk(path):
for f in filenames:
fp = self.join(firePath, f)
total_size += os.path.getsize(fp)
return total_size

View File

@ -1,89 +1,165 @@
# 对象存储文件处理类(示例:假设是 AWS S3 # 对象存储文件处理类(示例:假设是 AWS S3
import os import os
from typing import IO from typing import IO
from urllib.parse import urlparse from urllib.parse import urlparse, urljoin
from botocore.exceptions import ClientError from botocore.exceptions import ClientError
from smart_open import open from smart_open import open
import boto3 import boto3
from botocore.client import Config from botocore.client import Config
class S3Attachment: from pywxdump.common.config.oss_config import storage_config
from pywxdump.file.Attachment import Attachment
from pywxdump.file.ConfigurableAttachment import ConfigurableAttachment
def __init__(self):
# 腾讯云 COS 配置 class S3Attachment(ConfigurableAttachment):
self.cos_endpoint = "https://cos.<your-region>.myqcloud.com" # 替换 <your-region> 为你的 COS 区域,例如 ap-shanghai
self.access_key_id = "SecretId" # 替换为你的腾讯云 SecretId def __init__(self, s3_config: storage_config):
self.secret_access_key = "SecretKey" # 替换为你的腾讯云 SecretKey # S3 配置
self.s3_config = s3_config
# 校验配置
s3_config.validate_config()
# 创建 S3 客户端 # 创建 S3 客户端
self.s3_client = boto3.client( self.s3_client = boto3.client(
's3', 's3',
endpoint_url=self.cos_endpoint, endpoint_url=s3_config.endpoint_url,
aws_access_key_id=self.access_key_id, aws_access_key_id=s3_config.access_key,
aws_secret_access_key=self.secret_access_key, aws_secret_access_key=s3_config.secret_key,
config=Config(s3={"addressing_style": "virtual", "signature_version": 's3v4'}) config=Config(s3={"addressing_style": "virtual", "signature_version": 's3v4'})
) )
def exists(self, path) -> bool: @classmethod
bucket_name, path = self.dealS3Url(path) def load_config(cls, config: storage_config) -> Attachment:
# 检查是否为目录 return cls(config)
if path.endswith('/'):
# 尝试列出该路径下的对象
try:
response = self.s3_client.list_objects_v2(Bucket=bucket_name, Prefix=path, MaxKeys=1)
if 'Contents' in response:
return True
else:
return False
except ClientError as e:
print(f"Error: {e}")
return False
else:
# 检查是否为文件
try:
self.s3_client.head_object(Bucket=bucket_name, Key=path)
return True
except ClientError as e:
if e.response['Error']['Code'] == '404':
return False
else:
print(f"Error: {e}")
return False
def exists(self, s3_url) -> bool:
    """
    Check whether any object exists under the given S3 URL.

    Args:
        s3_url (str): Object URL.

    Returns:
        bool: True if listing with the URL's key as prefix returns at least
        one object. False if nothing is found or the listing fails with a
        ClientError, which is printed.
    """
    bucket, prefix = self.dealS3Url(s3_url)
    # NOTE(review): this is a prefix match, so a key such as "foo" would also
    # match an object named "foobar". Confirm that this is intended.
    try:
        listing = self.s3_client.list_objects_v2(Bucket=bucket, Prefix=prefix, MaxKeys=1)
    except ClientError as e:
        print(f"Error: {e}")
        return False
    return 'Contents' in listing
def makedirs(self, s3_url) -> bool:
    """
    Create a directory marker object if nothing exists at the URL yet.

    Args:
        s3_url (str): Directory URL, with or without a trailing slash.

    Returns:
        bool: Always True.
    """
    if not self.exists(s3_url):
        bucket_name, path = self.dealS3Url(s3_url)
        # Normalise to exactly one trailing slash, so a URL that already ends
        # with '/' does not yield a marker key with a doubled slash.
        marker_key = f"{path.rstrip('/')}/"
        self.s3_client.put_object(Bucket=bucket_name, Key=marker_key)
    return True
def open(self, s3_url, mode) -> IO:
    """
    Open an S3 object as a file-like stream.

    Args:
        s3_url (str): Object URL.
        mode (str): Open mode.

    Returns:
        IO: File-like object backed by this attachment's S3 client.
    """
    transport = {'client': self.s3_client}
    return open(uri=s3_url, mode=mode, transport_params=transport)
def remove(self, s3_url: str) -> bool:
    """
    Delete a single object.

    Args:
        s3_url (str): Object URL.

    Returns:
        bool: True once the delete request has been issued.

    Raises:
        FileNotFoundError: Nothing exists at the URL.
        ValueError: The URL refers to a directory.
    """
    if not self.exists(s3_url):
        raise FileNotFoundError(f"File not found: {s3_url}")
    if self.isdir(s3_url):
        raise ValueError(f"Path is not a file: {s3_url}")
    bucket, key = self.dealS3Url(s3_url)
    self.s3_client.delete_object(Bucket=bucket, Key=key)
    return True
@classmethod
def join(cls, s3_url: str, *paths: str) -> str:
    """
    Join path components onto an S3 URL using forward slashes.

    Args:
        s3_url (str): Base URL or path.
        *paths (str): Components to append.

    Returns:
        str: The joined path with every backslash replaced by '/'.
    """
    joined = os.path.join(s3_url, *paths)
    # os.path.join may insert backslashes on Windows; S3 keys use '/'.
    return '/'.join(joined.split('\\'))
@classmethod
def dirname(cls, s3_url: str) -> str:
    """
    Return the directory part of a path.

    Args:
        s3_url (str): Path or URL.

    Returns:
        str: Everything before the final component.
    """
    head, _tail = os.path.split(s3_url)
    return head
@classmethod
def basename(cls, s3_url: str) -> str:
    """
    Return the last component of a path.

    Args:
        s3_url (str): Path or URL.

    Returns:
        str: The final path component.
    """
    _head, tail = os.path.split(s3_url)
    return tail
def dealS3Url(self, s3_url: str) -> object:
""" """
解析 S3 URL 并返回存储桶名称和路径 解析 S3 URL 并返回存储桶名称和路径
参数: 参数:
path (str): S3 URL s3_url (str): S3 URL
返回: 返回:
tuple: 包含存储桶名称和路径的元组 tuple: 包含存储桶名称和路径的元组
""" """
parsed_url = urlparse(path) parsed_url = urlparse(s3_url)
# 确保URL是S3 URL # 确保URL是S3 URL
if parsed_url.scheme != 's3': if parsed_url.scheme != 's3':
@ -94,3 +170,75 @@ class S3Attachment:
return bucket_name, s3_path return bucket_name, s3_path
def isdir(self, s3_url: str) -> bool:
    """
    Tell whether the URL denotes a directory (a key prefix).

    A URL is a directory when at least one object exists under ``<key>/``.
    That includes the zero-byte ``<key>/`` marker object written by
    ``makedirs``, so an empty directory created that way is recognised too.

    Args:
        s3_url (str): Path to test, with or without a trailing slash.

    Returns:
        bool: True if the URL is a directory, otherwise False.
    """
    # Directory prefixes always end with '/'.
    if not s3_url.endswith('/'):
        s3_url += '/'
    bucket_name, path = self.dealS3Url(s3_url)
    if not path.endswith('/'):
        # No usable directory prefix, e.g. an empty key for the bucket root.
        return False
    response = self.s3_client.list_objects_v2(Bucket=bucket_name, Prefix=path, MaxKeys=1)
    return 'Contents' in response
def getsize(self, s3_url) -> int:
    """
    Return the size in bytes of an object, or the total size of a directory.

    Args:
        s3_url (str): Object or directory URL.

    Returns:
        int: Size in bytes.

    Raises:
        FileNotFoundError: Nothing exists at the URL.
    """
    if not self.exists(s3_url):
        raise FileNotFoundError(f"File not found: {s3_url}")
    if self.isdir(s3_url):
        return self._get_size_of_directory(s3_url)
    bucket, key = self.dealS3Url(s3_url)
    head = self.s3_client.head_object(Bucket=bucket, Key=key)
    return head['ContentLength']
def _get_size_of_directory(self, s3_url):
"""
获取目录大小
参数:
s3_url (str): 目录路径
返回:
int: 目录大小
"""
bucket_name, path = self.dealS3Url(s3_url)
total_size = 0
# 确保目录路径以'/'结尾
if not path.endswith('/'):
path += '/'
# 列出指定目录中的对象
response = self.s3_client.list_objects_v2(Bucket=bucket_name, Prefix=path)
if 'Contents' in response:
for obj in response['Contents']:
total_size += obj['Size']
return total_size

View File

@ -0,0 +1,4 @@
# Placeholder entry point: running this module directly does nothing.
if __name__ == '__main__':
    pass

View File

@ -10,9 +10,11 @@ import subprocess
import sys import sys
import time import time
from pywxdump.common.config.oss_config.s3_config import S3Config
from pywxdump.common.config.server_config import ServerConfig
def start_falsk(merge_path="", wx_path="", key="", my_wxid="", port=5000, online=False, debug=False,
isopenBrowser=True): def start_falsk(server_config: ServerConfig):
""" """
启动flask 启动flask
:param merge_path: 合并后的数据库路径 :param merge_path: 合并后的数据库路径
@ -38,13 +40,14 @@ def start_falsk(merge_path="", wx_path="", key="", my_wxid="", port=5000, online
import logging import logging
# 检查端口是否被占用 # 检查端口是否被占用
if online: if server_config.online:
host = '0.0.0.0' host = '0.0.0.0'
else: else:
host = "127.0.0.1" host = "127.0.0.1"
app = Flask(__name__, template_folder='./ui/web', static_folder='./ui/web/assets/', static_url_path='/assets/') app = Flask(__name__, template_folder='./ui/web', static_folder='./ui/web/assets/', static_url_path='/assets/')
app.config['ENV'] = 'development'
app.config['DEBUG'] = True
# 设置超时时间为 1000 秒 # 设置超时时间为 1000 秒
app.config['TIMEOUT'] = 1000 app.config['TIMEOUT'] = 1000
app.secret_key = 'secret_key' app.secret_key = 'secret_key'
@ -67,18 +70,20 @@ def start_falsk(merge_path="", wx_path="", key="", my_wxid="", port=5000, online
g.tmp_path = tmp_path # 临时文件夹,用于存放图片等 g.tmp_path = tmp_path # 临时文件夹,用于存放图片等
g.sf = session_file # 用于存放各种基础信息 g.sf = session_file # 用于存放各种基础信息
if merge_path: save_session(session_file, "test", "merge_path", merge_path) wxid = server_config.my_wxid if server_config.my_wxid else "test"
if wx_path: save_session(session_file, "test", "wx_path", wx_path) if server_config.merge_path: save_session(session_file, wxid, "merge_path", server_config.merge_path)
if key: save_session(session_file, "test", "key", key) if server_config.wx_path: save_session(session_file, wxid, "wx_path", server_config.wx_path)
if my_wxid: save_session(session_file, "test", "my_wxid", my_wxid) if server_config.key: save_session(session_file, wxid, "key", server_config.key)
if server_config.my_wxid: save_session(session_file, wxid, "my_wxid", server_config.my_wxid)
if server_config.oss_config: save_session(session_file, wxid, "oss_config", server_config.oss_config_to_json())
if not os.path.exists(session_file): if not os.path.exists(session_file):
save_session(session_file, "test", "last", my_wxid) save_session(session_file, wxid, "last", server_config.my_wxid)
app.register_blueprint(api) app.register_blueprint(api)
if isopenBrowser: if server_config.is_open_browser:
try: try:
# 自动打开浏览器 # 自动打开浏览器
url = f"http://127.0.0.1:{port}/" url = f"http://127.0.0.1:{server_config.port}/"
# 根据操作系统使用不同的命令打开默认浏览器 # 根据操作系统使用不同的命令打开默认浏览器
if sys.platform.startswith('darwin'): # macOS if sys.platform.startswith('darwin'): # macOS
subprocess.call(['open', url]) subprocess.call(['open', url])
@ -100,13 +105,13 @@ def start_falsk(merge_path="", wx_path="", key="", my_wxid="", port=5000, online
return True return True
return False return False
if is_port_in_use(host, port): if is_port_in_use(host, server_config.port):
print(f"Port {port} is already in use. Choose a different port.") print(f"Port {server_config.port} is already in use. Choose a different port.")
input("Press Enter to exit...") input("Press Enter to exit...")
else: else:
time.sleep(1) time.sleep(1)
print("[+] 请使用浏览器访问 http://127.0.0.1:5000/ 查看聊天记录") print("[+] 请使用浏览器访问 http://127.0.0.1:5000/ 查看聊天记录")
app.run(host=host, port=port, debug=debug) app.run(host=host, port=server_config.port, debug=server_config.debug)
if __name__ == '__main__': if __name__ == '__main__':
@ -114,6 +119,16 @@ if __name__ == '__main__':
wx_path = r"****" wx_path = r"****"
my_wxid = "****" my_wxid = "****"
server_config = ServerConfig.builder()
server_config.merge_path("s3://*********-1256220500/*********/merge_all.db")
server_config.wx_path("s3://*********-1256220500/*********")
server_config.my_wxid("test")
server_config.port(9000)
server_config.online(True)
server_config.is_open_browser(False)
s3Config = S3Config("AKIDaAjA*********I1kR4gFdv67v", "wlT2ldSBk*********Qh4fEev47",
"https://cos.ap-nanjing.myqcloud.com")
server_config.oss_config(s3Config)
start_falsk(server_config.build())
start_falsk(merge_path=merge_path, wx_path=wx_path, my_wxid=my_wxid,
port=5000, online=False, debug=False, isopenBrowser=False)

View File

@ -395,7 +395,7 @@ def get_wechat_db(require_list: Union[List[str], str] = "all", msg_dir: str = No
def get_core_db(wx_path: str, db_type: list = None) -> [str]: def get_core_db(wx_path: str, db_type: list = None) -> [str]:
""" """
获取聊天消息核心数据库路径 获取聊天消息核心数据库路径
:param wx_path: 微信文件夹路径 egC:\*****\WeChat Files\wxid******* :param wx_path: 微信文件夹路径 egC:\\*****\\WeChat Files\\wxid*******
:param db_type: 数据库类型 eg: ["MSG", "MediaMSG", "MicroMsg"]三个中选择一个或多个 :param db_type: 数据库类型 eg: ["MSG", "MediaMSG", "MicroMsg"]三个中选择一个或多个
:return: 返回数据库路径 eg:["",""] :return: 返回数据库路径 eg:["",""]
""" """

View File

@ -304,7 +304,7 @@ def decrypt_merge(wx_path, key, outpath="", CreateTime: int = 0, endCreateTime:
bool, str): bool, str):
""" """
解密合并数据库 msg.db, microMsg.db, media.db,注意会删除原数据库 解密合并数据库 msg.db, microMsg.db, media.db,注意会删除原数据库
:param wx_path: 微信路径 eg: C:\*******\WeChat Files\wxid_********* :param wx_path: 微信路径 eg: C:\\*******\\WeChat Files\\wxid_*********
:param key: 解密密钥 :param key: 解密密钥
:return: (true,解密后的数据库路径) or (false,错误信息) :return: (true,解密后的数据库路径) or (false,错误信息)
""" """
@ -430,7 +430,7 @@ def all_merge_real_time_db(key, wx_path, merge_path):
这是全量合并会有可能产生重复数据需要自行去重 这是全量合并会有可能产生重复数据需要自行去重
:param key: 解密密钥 :param key: 解密密钥
:param wx_path: 微信路径 :param wx_path: 微信路径
:param merge_path: 合并后的数据库路径 eg: C:\*******\WeChat Files\wxid_*********\merge.db :param merge_path: 合并后的数据库路径 eg: C:\\*******\\WeChat Files\\wxid_*********\\merge.db
:return: :return:
""" """
if not merge_path or not key or not wx_path or not wx_path: if not merge_path or not key or not wx_path or not wx_path:

View File

@ -16,4 +16,5 @@ lxml
flask_cors flask_cors
pandas pandas
smart_open[s3] smart_open[s3]
boto3 boto3
botocore

View File

@ -0,0 +1,63 @@
import os
import unittest
from unittest.mock import patch
from pywxdump.file.LocalAttachment import LocalAttachment
import tempfile
class TestLocalAttachment(unittest.TestCase):
    """Tests for LocalAttachment.remove and LocalAttachment.getsize."""

    def setUp(self):
        # Only the fixture paths are computed here; tests create what they need.
        self.attachment = LocalAttachment()
        self.test_file_path = self.attachment.join(tempfile.gettempdir(), "test.txt")
        self.test_dir_path = self.attachment.join(tempfile.gettempdir(), "test_dir")

    def tearDown(self):
        if os.path.exists(self.test_file_path):
            os.remove(self.test_file_path)
        if os.path.exists(self.test_dir_path):
            # Walk bottom-up so nested directories are emptied before removal.
            for dirpath, dirnames, filenames in os.walk(self.test_dir_path, topdown=False):
                for f in filenames:
                    os.remove(os.path.join(dirpath, f))
                for d in dirnames:
                    os.rmdir(os.path.join(dirpath, d))
            os.rmdir(self.test_dir_path)

    def test_remove_existing_file(self):
        with open(self.test_file_path, 'w') as f:
            f.write("test")
        self.assertTrue(self.attachment.remove(self.test_file_path))
        self.assertFalse(os.path.exists(self.test_file_path))

    def test_remove_non_existing_file(self):
        with self.assertRaises(FileNotFoundError):
            self.attachment.remove(self.test_file_path)

    @patch('os.remove')
    def test_remove_os_error(self, mock_remove):
        # The file must exist, otherwise remove() raises FileNotFoundError
        # (itself an OSError) before the mocked os.remove is ever called.
        with open(self.test_file_path, 'w') as f:
            f.write("test")
        mock_remove.side_effect = OSError
        with self.assertRaises(OSError):
            self.attachment.remove(self.test_file_path)
        mock_remove.assert_called_once()

    def test_getsize_existing_file(self):
        # Clear leftovers from earlier runs.
        self.tearDown()
        with open(self.test_file_path, "w") as f:
            f.write("Hello, World!")
        self.assertEqual(self.attachment.getsize(self.test_file_path), 13)

    def test_getsize_non_existing_file(self):
        with self.assertRaises(FileNotFoundError):
            self.attachment.getsize('non_existing_file')

    def test_getsize_existing_folder(self):
        # Clear leftovers from earlier runs.
        self.tearDown()
        os.mkdir(self.test_dir_path)
        with open(os.path.join(self.test_dir_path, "file1.txt"), "w") as f:
            f.write("Hello, World!")
        with open(os.path.join(self.test_dir_path, "file2.txt"), "w") as f:
            f.write("Hello, World!")
        self.assertEqual(self.attachment.getsize(self.test_dir_path), 26)
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,71 @@
import unittest
from unittest.mock import patch, MagicMock
from botocore.exceptions import ClientError
import boto3
from pywxdump.file.S3Attachment import S3Attachment
class TestS3Attachment(unittest.TestCase):
    """Tests for S3Attachment.remove and S3Attachment.getsize with a mocked client."""

    @patch('boto3.client')
    def setUp(self, mock_client):
        # boto3.client is patched only while the attachment is constructed;
        # the instance keeps the mocked client afterwards.
        self.mock_client = MagicMock()
        mock_client.return_value = self.mock_client
        s3_config = MagicMock()
        self.attachment = S3Attachment(s3_config)
        self.test_s3_url = "s3://test_bucket/test_file"
        self.test_s3_dir = "s3://test_bucket/test_folder/"

    # The directory check is S3Attachment.isdir. Patching a missing attribute
    # (previously 'isFolder') makes patch.object raise AttributeError.
    @patch.object(S3Attachment, 'exists')
    @patch.object(S3Attachment, 'isdir')
    def test_removal_of_existing_file(self, mock_isdir, mock_exists):
        mock_exists.return_value = True
        mock_isdir.return_value = False
        self.assertTrue(self.attachment.remove(self.test_s3_url))
        self.mock_client.delete_object.assert_called_once()

    @patch.object(S3Attachment, 'exists')
    def test_removal_of_non_existing_file(self, mock_exists):
        mock_exists.return_value = False
        with self.assertRaises(FileNotFoundError):
            self.attachment.remove(self.test_s3_url)

    @patch.object(S3Attachment, 'exists')
    @patch.object(S3Attachment, 'isdir')
    def test_removal_of_folder_instead_of_file(self, mock_isdir, mock_exists):
        mock_exists.return_value = True
        mock_isdir.return_value = True
        with self.assertRaises(ValueError):
            self.attachment.remove(self.test_s3_url)

    @patch.object(S3Attachment, 'exists')
    @patch.object(S3Attachment, 'isdir')
    def test_removal_with_s3_error(self, mock_isdir, mock_exists):
        mock_exists.return_value = True
        mock_isdir.return_value = False
        # Simulate a ClientError raised by the delete call.
        error_response = {'Error': {'Code': 'InvalidRequest', 'Message': 'Some error message'}}
        self.mock_client.delete_object.side_effect = ClientError(error_response, 'delete_object')
        with self.assertRaises(ClientError):
            self.attachment.remove(self.test_s3_url)

    @patch.object(S3Attachment, 'exists')
    def test_getsize_existing_file(self, mock_exists):
        mock_exists.return_value = True
        self.mock_client.head_object.return_value = {'ContentLength': 100}
        self.assertEqual(self.attachment.getsize(self.test_s3_url), 100)

    def test_getsize_non_existing_file(self):
        with self.assertRaises(FileNotFoundError):
            self.attachment.getsize('non_existing_file')

    @patch.object(S3Attachment, 'isdir')
    def test_getsize_existing_folder(self, mock_isdir):
        mock_isdir.return_value = True
        self.mock_client.list_objects_v2.return_value = {'Contents': [{'Size': 100}, {'Size': 300}]}
        self.assertEqual(self.attachment.getsize(self.test_s3_dir), 400)
if __name__ == '__main__':
unittest.main()