1.增加获取文件大小、文件删除方法,并补充注释

2.增加对象存储的配置实现
This commit is contained in:
cllcode 2024-07-13 21:56:42 +08:00
parent 2b3c3f9d1f
commit 4de1c26f25
22 changed files with 1222 additions and 150 deletions

View File

@ -5,14 +5,14 @@
# Author: xaoyaoo
# Date: 2024/01/02
# -------------------------------------------------------------------------------
import base64
import json
import logging
import os
import re
import time
import shutil
import sys
from pywxdump.common.config.oss_config.storage_config_factory import StorageConfigFactory
from pywxdump.common.config.oss_config_manager import OSSConfigManager
pythoncom = __import__('pythoncom') if sys.platform == "win32" else None
import pywxdump
from pywxdump.file import AttachmentContext
@ -55,8 +55,8 @@ def init_last():
是否初始化
:return:
"""
my_wxid = request.json.get("my_wxid", "")
my_wxid = my_wxid.strip().strip("'").strip('"') if isinstance(my_wxid, str) else ""
my_wxid_dict = request.json.get("my_wxid", {})
my_wxid = my_wxid_dict.get("wxid", "")
if not my_wxid:
my_wxid = read_session(g.sf, "test", "last")
if my_wxid:
@ -64,6 +64,9 @@ def init_last():
merge_path = read_session(g.sf, my_wxid, "merge_path")
wx_path = read_session(g.sf, my_wxid, "wx_path")
key = read_session(g.sf, my_wxid, "key")
# 如果有oss_config则设置对象存储配置
oss_config = read_session(g.sf, my_wxid, "oss_config")
ossConfig(oss_config)
rdata = {
"merge_path": merge_path,
"wx_path": wx_path,
@ -96,11 +99,15 @@ def init_key():
return ReJson(1002, body=f"my_wxid is required: {my_wxid}")
old_merge_save_path = read_session(g.sf, my_wxid, "merge_path")
if isinstance(old_merge_save_path, str) and old_merge_save_path and AttachmentContext.exists(old_merge_save_path):
pmsg = ParsingMSG(old_merge_save_path)
pmsg.close_all_connection()
# 如果有oss_config则设置对象存储配置
oss_config = read_session(g.sf, my_wxid, "oss_config")
ossConfig(oss_config)
# 如果存在旧地连接则关闭连接
if isinstance(old_merge_save_path, str) and old_merge_save_path:
ParsingMSG.terminate_connection(old_merge_save_path)
out_path = AttachmentContext.join(g.tmp_path, "decrypted", my_wxid) if my_wxid else AttachmentContext.join(g.tmp_path, "decrypted")
out_path = AttachmentContext.join(g.tmp_path, "decrypted", my_wxid) if my_wxid else AttachmentContext.join(
g.tmp_path, "decrypted")
# 检查文件夹中文件是否被占用
if AttachmentContext.exists(out_path):
try:
@ -115,7 +122,7 @@ def init_key():
if code:
# 移动merge_save_path到g.tmp_path/my_wxid
if not AttachmentContext.exists(AttachmentContext.join(g.tmp_path, my_wxid)):
os.makedirs(AttachmentContext.join(g.tmp_path, my_wxid))
AttachmentContext.makedirs(AttachmentContext.join(g.tmp_path, my_wxid))
merge_save_path_new = AttachmentContext.join(g.tmp_path, my_wxid, "merge_all.db")
shutil.move(merge_save_path, str(merge_save_path_new))
@ -144,6 +151,19 @@ def init_key():
return ReJson(2001, body=merge_save_path)
def ossConfig(oss_config: str):
    """
    Apply an object-storage configuration, if one is given.

    :param oss_config: serialized storage configuration; a falsy value
        (empty string, None) means "nothing to configure" and is ignored
    :return: None
    """
    if not oss_config:
        return
    # Parse the serialized config, then hand it to the shared config manager.
    OSSConfigManager().load_config(StorageConfigFactory.create(oss_config))
@api.route('/api/init_nokey', methods=["GET", 'POST'])
@error9999
def init_nokey():
@ -154,6 +174,9 @@ def init_nokey():
merge_path = request.json.get("merge_path", "").strip().strip("'").strip('"')
wx_path = request.json.get("wx_path", "").strip().strip("'").strip('"')
my_wxid = request.json.get("my_wxid", "").strip().strip("'").strip('"')
# 如果有oss_config则设置对象存储配置
oss_config = request.json.get("oss_config", "").strip().strip("'").strip('"')
ossConfig(oss_config)
if not wx_path:
return ReJson(1002, body=f"wx_path is required: {wx_path}")
@ -170,6 +193,7 @@ def init_nokey():
save_session(g.sf, my_wxid, "wx_path", wx_path)
save_session(g.sf, my_wxid, "key", key)
save_session(g.sf, my_wxid, "my_wxid", my_wxid)
save_session(g.sf, my_wxid, "oss_config", oss_config)
save_session(g.sf, "test", "last", my_wxid)
rdata = {
"merge_path": merge_path,
@ -181,6 +205,25 @@ def init_nokey():
return ReJson(0, rdata)
# 查询支持的对象存储,及配置
@api.route('/api/check_storage_type', methods=["GET", 'POST'])
@error9999
def check_storage_type():
    """
    Report whether a registered storage type supports the requested path.

    Reads ``path`` from the JSON request body (surrounding whitespace and
    quotes are stripped). Responds with code 1002 when the path is empty.
    Otherwise responds with code 0 and ``is_supported``; when supported, the
    storage type and its configurable items are included as well.
    """
    path = request.json.get("path", "").strip().strip("'").strip('"')
    if not path:
        return ReJson(1002, body=f"path is required: {path}")

    # Every registered config is asked; the last one that accepts the path wins.
    matches = [candidate for candidate in StorageConfigFactory.registry.values()
               if candidate.isSupported(path)]
    if not matches:
        return ReJson(0, {"is_supported": False})

    chosen = matches[-1]
    return ReJson(0, {"is_supported": True, "storage_type": chosen.type(),
                      "config_items": chosen.describe()})
# END 以上为初始化相关 ***************************************************************************************************
@ -338,7 +381,7 @@ def get_imgsrc(imgsrc):
img_tmp_path = AttachmentContext.join(g.tmp_path, my_wxid, "imgsrc")
if not AttachmentContext.exists(img_tmp_path):
os.makedirs(img_tmp_path)
AttachmentContext.makedirs(img_tmp_path)
file_name = imgsrc.replace("http://", "").replace("https://", "").replace("/", "_").replace("?", "_")
file_name = file_name + ".jpg"
# 如果文件名过长,则将文件明分为目录和文件名
@ -434,7 +477,7 @@ def get_video(videoPath):
# 复制文件到临时文件夹
video_save_path = AttachmentContext.join(video_tmp_path, videoPath)
if not AttachmentContext.exists(AttachmentContext.dirname(video_save_path)):
os.makedirs(AttachmentContext.dirname(video_save_path))
AttachmentContext.makedirs(AttachmentContext.dirname(video_save_path))
if AttachmentContext.exists(video_save_path):
return AttachmentContext.send_attachment(video_save_path)
AttachmentContext.download_file(original_img_path, video_save_path)
@ -457,7 +500,7 @@ def get_audio(savePath):
# 判断savePath路径的文件夹是否存在
if not AttachmentContext.exists(AttachmentContext.dirname(savePath)):
os.makedirs(AttachmentContext.dirname(savePath))
AttachmentContext.makedirs(AttachmentContext.dirname(savePath))
parsing_media_msg = ParsingMediaMSG(merge_path)
wave_data = parsing_media_msg.get_audio(MsgSvrID, is_play=False, is_wave=True, save_path=savePath, rate=24000)
@ -484,8 +527,8 @@ def get_file_info():
all_file_path = AttachmentContext.join(wx_path, file_path)
if not AttachmentContext.exists(all_file_path):
return ReJson(5002)
file_name = os.path.basename(all_file_path)
file_size = os.path.getsize(all_file_path)
file_name = AttachmentContext.basename(all_file_path)
file_size = AttachmentContext.getsize(all_file_path)
return ReJson(0, {"file_name": file_name, "file_size": str(file_size)})
@ -528,12 +571,12 @@ def get_export_endb():
outpath = AttachmentContext.join(g.tmp_path, "export", my_wxid, "endb")
if not AttachmentContext.exists(outpath):
os.makedirs(outpath)
AttachmentContext.makedirs(outpath)
for wxdb in wxdbpaths:
# 复制wxdb->outpath, os.path.basename(wxdb)
assert isinstance(outpath, str) # 为了解决pycharm的警告, 无实际意义
shutil.copy(wxdb, AttachmentContext.join(outpath, os.path.basename(wxdb)))
shutil.copy(wxdb, AttachmentContext.join(outpath, AttachmentContext.basename(wxdb)))
return ReJson(0, body=outpath)
@ -558,7 +601,7 @@ def get_export_dedb():
outpath = AttachmentContext.join(g.tmp_path, "export", my_wxid, "dedb")
if not AttachmentContext.exists(outpath):
os.makedirs(outpath)
AttachmentContext.makedirs(outpath)
code, merge_save_path = decrypt_merge(wx_path=wx_path, key=key, outpath=outpath)
time.sleep(1)
@ -589,7 +632,7 @@ def get_export_csv():
outpath = AttachmentContext.join(g.tmp_path, "export", my_wxid, "csv", wxid)
if not AttachmentContext.exists(outpath):
os.makedirs(outpath)
AttachmentContext.makedirs(outpath)
code, ret = export_csv(wxid, outpath, read_session(g.sf, my_wxid, "merge_path"))
if code:
@ -613,7 +656,7 @@ def get_export_json():
outpath = AttachmentContext.join(g.tmp_path, "export", my_wxid, "json", wxid)
if not AttachmentContext.exists(outpath):
os.makedirs(outpath)
AttachmentContext.makedirs(outpath)
code, ret = export_json(wxid, outpath, read_session(g.sf, my_wxid, "merge_path"))
if code:

View File

@ -6,11 +6,18 @@
# Date: 2023/10/14
# -------------------------------------------------------------------------------
import argparse
import json
import os
import sys
import time
from pywxdump import *
import pywxdump
from pywxdump.common.config.oss_config.storage_config import DescriptionBuilder
from pywxdump.common.config.oss_config.storage_config_factory import StorageConfigFactory
from pywxdump.common.config.oss_config_manager import OSSConfigManager
from pywxdump.common.config.server_config import ServerConfig
from pywxdump.common.constants import VERSION_LIST_PATH
from pywxdump.file import AttachmentContext
wxdump_ascii = r"""
@ -264,31 +271,63 @@ class MainShowChatRecords(BaseSubMainClass):
default="", metavar="")
parser.add_argument("--online", action='store_true', help="(可选)是否在线查看(局域网查看)", required=False,
default=False)
for key, storageConfig in StorageConfigFactory.registry.items():
storageConfigDescribe = storageConfig.describe()
for description in storageConfigDescribe:
parser.add_argument(
f"--{key}_{description[DescriptionBuilder.KEY]}",
help=description[DescriptionBuilder.PLACEHOLDER],
default=False, required=False, type=str, metavar=""
)
# parser.add_argument("-k", "--key", type=str, help="(可选)密钥", required=False, metavar="")
return parser
def run(self, args):
print(f"[*] PyWxDump v{pywxdump.__version__}")
server_config = ServerConfig.builder()
server_config.merge_path(args.merge_path)
server_config.wx_path(args.wx_path)
server_config.my_wxid(args.my_wxid)
server_config.online(args.online)
server_config.port(9000)
for key, storageConfig in StorageConfigFactory.registry.items():
storageConfigDescribe = storageConfig.describe()
config = {}
for description in storageConfigDescribe:
value = getattr(args, f"{key}_{description[DescriptionBuilder.KEY]}")
if value:
config[description[DescriptionBuilder.KEY]] = value
if config:
oss_config = storageConfig.value_of(config)
server_config.oss_config(oss_config)
# 加载对象存储配置
OSSConfigManager().load_config(oss_config)
start_config = server_config.build()
# 参数检查
if not self._dbshow_parameter_check(start_config):
return
start_falsk(start_config)
def _dbshow_parameter_check(self, server_config) -> bool:
# (merge)和(msg_path,micro_path,media_path) 二选一
# if not args.merge_path and not (args.msg_path and args.micro_path and args.media_path):
# print("[-] 请输入数据库路径([merge_path] or [msg_path, micro_path, media_path]")
# return
# 目前仅能支持merge database
if not args.merge_path:
if not server_config.merge_path:
print("[-] 请输入数据库路径([merge_path]")
return
return False
# 从命令行参数获取值
merge_path = args.merge_path
online = args.online
if not os.path.exists(merge_path):
if not AttachmentContext.exists(server_config.merge_path):
print("[-] 输入数据库路径不存在")
return
start_falsk(merge_path=merge_path, wx_path=args.wx_path, key="", my_wxid=args.my_wxid, online=online)
return False
return True
class MainExportChatRecords(BaseSubMainClass):
@ -338,7 +377,13 @@ class MainUi(BaseSubMainClass):
debug = args.debug
isopenBrowser = args.isOpenBrowser
start_falsk(port=port, online=online, debug=debug, isopenBrowser=isopenBrowser)
server_config = ServerConfig.builder()
server_config.is_open_browser(isopenBrowser)
server_config.debug(debug)
server_config.port(port)
server_config.online(online)
start_falsk(server_config.build())
class MainApi(BaseSubMainClass):
@ -359,7 +404,13 @@ class MainApi(BaseSubMainClass):
port = args.port
debug = args.debug
start_falsk(port=port, online=online, debug=debug, isopenBrowser=False)
server_config = ServerConfig.builder()
server_config.debug(debug)
server_config.port(port)
server_config.is_open_browser(False)
server_config.online(online)
start_falsk(server_config.build())
def console_run():

View File

@ -0,0 +1,4 @@
if __name__ == '__main__':
pass

View File

@ -0,0 +1,58 @@
from pywxdump.common.config.oss_config.storage_config import StorageConfig, TYPE_KEY, DescriptionBuilder
from pywxdump.common.config.oss_config.storage_config_factory import StorageConfigFactory
S3 = "s3"
ENDPOINT_URL = "endpoint_url"
SECRET_KEY = "secret_key"
ACCESS_KEY = "access_key"
@StorageConfigFactory.register(S3)
class S3Config(StorageConfig):
    """
    Settings needed to reach an S3-compatible object store.

    The class is registered with ``StorageConfigFactory`` under the ``S3``
    type name.
    """

    def __init__(self, access_key, secret_key, endpoint_url):
        self.access_key = access_key
        self.secret_key = secret_key
        self.endpoint_url = endpoint_url

    @classmethod
    def type(cls) -> str:
        """Return the storage type identifier of this config (``S3``)."""
        return S3

    @classmethod
    def describe(cls):
        """
        Return the description entries for the three configurable settings.

        Each entry is produced by ``DescriptionBuilder.add_description`` with
        the setting name as label and key, plus a placeholder text.
        """
        builder = DescriptionBuilder()
        builder.add_description(ENDPOINT_URL, "https://cos.<your-regin>.myqcloud.com", ENDPOINT_URL)
        builder.add_description(ACCESS_KEY, "腾讯云的SecretId", ACCESS_KEY)
        builder.add_description(SECRET_KEY, "腾讯云的SecretKey", SECRET_KEY)
        return builder.build()

    def get_config(self):
        """Return the settings as a dict, tagged with the storage type under ``TYPE_KEY``."""
        return {
            TYPE_KEY: S3,
            ACCESS_KEY: self.access_key,
            SECRET_KEY: self.secret_key,
            ENDPOINT_URL: self.endpoint_url
        }

    def validate_config(self):
        """
        Check that all three settings are present.

        :raises ValueError: if the access key, secret key or endpoint URL is empty
        """
        if not self.access_key or not self.secret_key or not self.endpoint_url:
            raise ValueError("S3 configuration is not valid")

    @classmethod
    def value_of(cls, config: dict):
        """
        Build a validated instance from a settings dict.

        :param config: dict holding the access key, secret key and endpoint URL
        :return: the new, validated ``S3Config``
        :raises ValueError: if any of the settings is missing or empty
        """
        access_key = config.get(ACCESS_KEY)
        secret_key = config.get(SECRET_KEY)
        endpoint_url = config.get(ENDPOINT_URL)
        s3_config = cls(access_key, secret_key, endpoint_url)
        s3_config.validate_config()
        return s3_config

    @classmethod
    def isSupported(cls, path: str) -> bool:
        """
        Tell whether ``path`` is an ``s3://`` URL.

        :param path: path or URL to inspect
        :return: True for a non-empty path starting with ``s3://``, else False
        """
        # Always return a real bool; the previous version fell through and
        # returned None for non-S3 paths despite the ``-> bool`` annotation.
        return bool(path) and path.startswith(f"{S3}://")

View File

@ -0,0 +1,60 @@
from abc import ABC, abstractmethod
TYPE_KEY = "oss_type_key"
class StorageConfig(ABC):
    """Abstract interface that every object-storage configuration implements."""

    @classmethod
    @abstractmethod
    def type(cls):
        """Return the type identifier of the implementing class."""

    @classmethod
    @abstractmethod
    def describe(cls):
        """Return the description of this object-storage configuration."""

    @abstractmethod
    def get_config(self):
        """Return the storage configuration as a dict."""

    @abstractmethod
    def validate_config(self):
        """Check whether the configuration is valid."""

    @classmethod
    @abstractmethod
    def value_of(cls, config: dict):
        """Create a configuration object from the given dict."""

    @classmethod
    @abstractmethod
    def isSupported(cls, path: str) -> bool:
        """Tell whether this storage type handles the given path."""
class DescriptionBuilder:
    """Fluent collector of description entries (label, placeholder, key)."""

    # Dictionary keys used in every description entry.
    LABEL = "label"
    PLACEHOLDER = "placeholder"
    KEY = "key"

    def __init__(self):
        self.description = []

    def add_description(self, label, placeholder, key):
        """
        Append one description entry.

        :return: the builder itself, so calls can be chained
        """
        entry = dict(zip(
            (self.LABEL, self.PLACEHOLDER, self.KEY),
            (label, placeholder, key),
        ))
        self.description.append(entry)
        return self

    def build(self):
        """Return the list of entries collected so far."""
        return self.description

View File

@ -0,0 +1,25 @@
import json
from abc import ABC
from pywxdump.common.config.oss_config.storage_config import TYPE_KEY, StorageConfig
class StorageConfigFactory(ABC):
    """Registry of storage config classes, with a factory that builds them from JSON."""

    # Maps a type name to the class registered for it via ``register``.
    registry = {}

    @classmethod
    def register(cls, type_name):
        """
        Class decorator that registers the decorated class under ``type_name``.

        :param type_name: key to store the class under in ``registry``
        :return: a decorator that records the class and returns it unchanged
        """
        def inner_wrapper(subclass):
            cls.registry[type_name] = subclass
            return subclass

        return inner_wrapper

    @staticmethod
    def create(json_str) -> StorageConfig:
        """
        Build a storage config from its JSON representation.

        :param json_str: JSON object text; the entry under ``TYPE_KEY`` selects
            the registered class
        :return: the object returned by the selected class's ``value_of``
        :raises ValueError: if the text is not valid JSON, is not a JSON
            object, or names a type that is not registered
        """
        config_dict = json.loads(json_str)
        # Valid JSON may still be a list, number or string; reject it with a
        # clear error instead of failing on ``.get`` with an AttributeError.
        if not isinstance(config_dict, dict):
            raise ValueError(f'Config must be a JSON object, got: {type(config_dict).__name__}')
        config_type = config_dict.get(TYPE_KEY)
        subclass = StorageConfigFactory.registry.get(config_type)
        if subclass is None:
            raise ValueError(f'Unknown config type: {config_type}')
        return subclass.value_of(config_dict)

View File

@ -0,0 +1,39 @@
from typing import Type
from pywxdump.common.config.oss_config.storage_config import StorageConfig
from pywxdump.file.ConfigurableAttachment import ConfigurableAttachment
def singleton(cls):
    """
    Class decorator that creates at most one instance of ``cls``.

    The decorated name becomes a factory function: the first call constructs
    the instance with the given arguments, later calls return that same
    instance and ignore their arguments.
    """
    instances = {}

    def get_instance(*args, **kwargs):
        if cls in instances:
            return instances[cls]
        created = cls(*args, **kwargs)
        instances[cls] = created
        return created

    return get_instance
@singleton
class OSSConfigManager:
    """Keeps the loaded storage configs and the attachment handlers built from them."""

    def __init__(self):
        # Both dicts are keyed by the storage type name.
        self._config_instances = {}
        self._attachment_instance = {}

    def load_config(self, config: StorageConfig):
        """
        Validate and store a config, replacing any previous one of the same type.

        The attachment handler cached for that type is reset so the next
        ``get_attachment`` call builds it from the new config.
        """
        config.validate_config()
        type_name = config.type()
        self._config_instances[type_name] = config
        # Drop the handler that was built from the previous config.
        self._attachment_instance[type_name] = None

    def get_config(self, config_type: str) -> StorageConfig:
        """Return the config stored for ``config_type``, or None if none was loaded."""
        return self._config_instances.get(config_type)

    def get_attachment(self, config_type: str, instance_class: Type[ConfigurableAttachment]):
        """
        Return the attachment handler for ``config_type``, building it on first use.

        :param config_type: storage type name
        :param instance_class: class whose ``load_config`` builds the handler
        :raises ValueError: if no config was loaded for ``config_type``
        """
        if config_type not in self._config_instances:
            raise ValueError(f"Config not found: {config_type}")

        cached = self._attachment_instance[config_type]
        if cached:
            return cached

        built = instance_class.load_config(self._config_instances[config_type])
        self._attachment_instance[config_type] = built
        return built

View File

@ -0,0 +1,97 @@
import json
from dataclasses import dataclass
from pywxdump.common.config.oss_config.storage_config import StorageConfig
@dataclass
class ServerConfig:
    """
    Startup settings for the server.

    :param merge_path: path of the merged database, default ""
    :param wx_path: path of the WeChat folder, used to display images, default ""
    :param key: key, default ""
    :param my_wxid: WeChat account (the user's own wxid), default ""
    :param port: port number, default 5000
    :param online: whether to serve on the local network, default False
    :param debug: whether debug mode is on, default False
    :param is_open_browser: whether to open the browser automatically, default True
    :param oss_config: object-storage configuration, default None
    """
    merge_path: str = ""
    wx_path: str = ""
    key: str = ""
    my_wxid: str = ""
    port: int = 5000
    online: bool = False
    debug: bool = False
    is_open_browser: bool = True
    oss_config: dict = None

    @classmethod
    def builder(cls):
        """Return a new ``ServerConfig.Builder``."""
        return ServerConfig.Builder()

    class Builder:
        """Fluent builder: every setter stores its value and returns the builder."""

        def __init__(self):
            # Start from the same defaults as the dataclass fields.
            self._merge_path = ""
            self._wx_path = ""
            self._key = ""
            self._my_wxid = ""
            self._port = 5000
            self._online = False
            self._debug = False
            self._is_open_browser = True
            self._oss_config = None

        def merge_path(self, merge_path: str):
            """Set the merged database path."""
            self._merge_path = merge_path
            return self

        def wx_path(self, wx_path: str):
            """Set the WeChat folder path."""
            self._wx_path = wx_path
            return self

        def key(self, key: str):
            """Set the key."""
            self._key = key
            return self

        def my_wxid(self, my_wxid: str):
            """Set the user's own wxid."""
            self._my_wxid = my_wxid
            return self

        def port(self, port: int):
            """Set the port number."""
            self._port = port
            return self

        def online(self, online: bool):
            """Set the online flag."""
            self._online = online
            return self

        def debug(self, debug: bool):
            """Set the debug flag."""
            self._debug = debug
            return self

        def is_open_browser(self, is_open_browser: bool):
            """Set whether the browser is opened automatically."""
            self._is_open_browser = is_open_browser
            return self

        def oss_config(self, oss_config: StorageConfig):
            """Validate the storage config and keep its dict form (``get_config()``)."""
            oss_config.validate_config()
            self._oss_config = oss_config.get_config()
            return self

        def build(self):
            """Create the ``ServerConfig`` from the values collected so far."""
            settings = {
                "merge_path": self._merge_path,
                "wx_path": self._wx_path,
                "key": self._key,
                "my_wxid": self._my_wxid,
                "port": self._port,
                "online": self._online,
                "debug": self._debug,
                "is_open_browser": self._is_open_browser,
                "oss_config": self._oss_config,
            }
            return ServerConfig(**settings)

    def oss_config_to_json(self) -> str:
        """Return ``oss_config`` as a JSON string, or None when it is empty or unset."""
        if not self.oss_config:
            return None
        return json.dumps(self.oss_config)

View File

@ -38,6 +38,7 @@ class DatabaseBase:
if not AttachmentContext.isLocalPath(db_path):
temp_dir = tempfile.gettempdir()
local_path = os.path.join(temp_dir, f"{uuid.uuid1()}.db")
logging.info(f"下载文件到本地: {db_path} -> {local_path}")
AttachmentContext.download_file(db_path, local_path)
else:
local_path = db_path
@ -88,7 +89,23 @@ class DatabaseBase:
self._db_connection = None
if not AttachmentContext.isLocalPath(self._db_path):
# 删除tmp目录下的db文件
self.clearTmpDb()
self._clearTmpDb()
@classmethod
def terminate_connection(cls, db_path: str):
"""
关闭数据库连接
:param db_path: 数据库路径
"""
if db_path in cls._connection_pool and cls._connection_pool[db_path]:
cls._connection_pool[db_path].close()
logging.info(f"关闭数据库 - {db_path}")
cls._connection_pool[db_path] = None
if not AttachmentContext.isLocalPath(db_path):
# 删除tmp目录下的db文件
cls._clearTmpDb()
def close_all_connection(self):
for db_path in self._connection_pool:
@ -97,18 +114,18 @@ class DatabaseBase:
logging.info(f"关闭数据库 - {db_path}")
self._connection_pool[db_path] = None
# 删除tmp目录下的db文件
self.clearTmpDb()
def clearTmpDb(self):
# 清理 tmp目录下.db文件
self._clearTmpDb()
@staticmethod
def _clearTmpDb():
# 清理 tmp目录下.db文件如果db文件存储在对象存储中使用时需要下载到tmp目录中且文件名是由uuid生成为了避免文件过多需要清理
temp_dir = tempfile.gettempdir()
db_files = glob.glob(os.path.join(temp_dir, '*.db'))
for db_file in db_files:
try:
os.remove(db_file)
print(f"Deleted: {db_file}")
except Exception as e:
print(f"Error deleting {db_file}: {e}")
logging.error(f"Error deleting {db_file}: {e}")
def show__singleton_instances(self):
print(self._singleton_instances)

123
pywxdump/file/Attachment.py Normal file
View File

@ -0,0 +1,123 @@
from typing import Protocol, IO
# 基类
class Attachment(Protocol):
    """
    Protocol describing the basic operations of an attachment handler.
    """

    def exists(self, path: str) -> bool:
        """
        Check whether a file or directory exists.

        Args:
            path (str): file or directory path

        Returns:
            bool: True if it exists, False otherwise
        """
        ...

    def makedirs(self, path: str) -> bool:
        """
        Create a directory, including all intermediate directories.

        Args:
            path (str): directory path

        Returns:
            bool: always True
        """
        ...

    def open(self, path: str, mode: str) -> IO:
        """
        Open a file and return the file object.

        Args:
            path (str): file path
            mode (str): mode the file is opened in

        Returns:
            IO: file object
        """
        ...

    def remove(self, path: str) -> bool:
        """
        Delete a file.

        Args:
            path (str): file path

        Returns:
            bool: whether the deletion succeeded
        """
        ...

    def isdir(self, path: str) -> bool:
        """
        Tell whether the path is a directory.

        Args:
            path (str): file path

        Returns:
            bool: whether the path is a directory
        """
        ...

    @classmethod
    def join(cls, path: str, *paths: str) -> str:
        """
        Join one or more path components.

        Args:
            path (str): first path component
            *paths (str): further path components

        Returns:
            str: the joined path
        """
        ...

    @classmethod
    def dirname(cls, path: str) -> str:
        """
        Get the directory part of a path.

        Args:
            path (str): file path

        Returns:
            str: directory name
        """
        ...

    @classmethod
    def basename(cls, path: str) -> str:
        """
        Get the base name (file name) of a path.

        Args:
            path (str): file path

        Returns:
            str: base name (file name)
        """
        ...

    def getsize(self, path) -> int:
        """
        Get the size of a file.

        Args:
            path (str): file path

        Returns:
            int: file size
        """
        ...

View File

@ -1,26 +0,0 @@
from typing import Protocol, IO
# 基类
class Attachment(Protocol):
def exists(self, path) -> bool:
pass
def makedirs(self, path) -> bool:
pass
def open(self, path, param) -> IO:
pass
@classmethod
def join(cls, __a: str, *paths: str) -> str:
pass
@classmethod
def dirname(cls, path: str) -> str:
pass
@classmethod
def basename(cls, path: str) -> str:
pass

View File

@ -1,41 +1,108 @@
import os
from datetime import datetime
from typing import AnyStr, BinaryIO, Callable, Union, IO
from typing import AnyStr, Callable, Union, IO
from flask import send_file, Response
from pywxdump.file.AttachmentAbstract import Attachment
from pywxdump.common.config.oss_config.s3_config import S3Config
from pywxdump.common.config.oss_config_manager import OSSConfigManager
from pywxdump.file.Attachment import Attachment
from pywxdump.file.LocalAttachment import LocalAttachment
from pywxdump.file.S3Attachment import S3Attachment
def determine_strategy(file_path: str) -> Attachment:
if file_path.startswith("s3://"):
return S3Attachment()
"""
根据文件路径确定使用的附件策略本地或S3
参数:
file_path (str): 文件路径
返回:
Attachment: 返回对应的附件策略类实例
"""
if file_path.startswith(f"s3://"):
return OSSConfigManager().get_attachment("s3", S3Attachment)
else:
return LocalAttachment()
def exists(path: str) -> bool:
"""
检查文件或目录是否存在
参数:
path (str): 文件或目录路径
返回:
bool: 如果存在返回True否则返回False
"""
return determine_strategy(path).exists(path)
def open_file(path: str, mode: str) -> IO:
"""
打开一个文件并返回文件对象
参数:
path (str): 文件路径
mode (str): 打开文件的模式
返回:
IO: 文件对象
"""
return determine_strategy(path).open(path, mode)
def makedirs(path: str) -> bool:
"""
创建目录包括所有中间目录
参数:
path (str): 目录路径
返回:
bool: 总是返回True
"""
return determine_strategy(path).makedirs(path)
def join(__a: str, *paths: str) -> str:
return determine_strategy(__a).join(__a, *paths)
def join(path: str, *paths: str) -> str:
"""
连接一个或多个路径组件
参数:
path (str): 第一个路径组件
*paths (str): 其他路径组件
返回:
str: 连接后的路径
"""
return determine_strategy(path).join(path, *paths)
def dirname(path: str) -> str:
"""
获取路径的目录名
参数:
path (str): 文件路径
返回:
str: 目录名
"""
return determine_strategy(path).dirname(path)
def basename(path: str) -> str:
"""
获取路径的基本名文件名
参数:
path (str): 文件路径
返回:
str: 基本名文件名
"""
return determine_strategy(path).basename(path)
@ -49,6 +116,22 @@ def send_attachment(
last_modified: Union[datetime, int, float, None] = None,
max_age: Union[None, int, Callable[[Union[str, None]], Union[int, None]]] = None,
) -> Response:
"""
发送附件文件
参数:
path_or_file (Union[os.PathLike[AnyStr], str]): 文件路径或文件对象
mimetype (Union[str, None]): 文件的MIME类型
as_attachment (bool): 是否作为附件下载
download_name (Union[str, None]): 下载时的文件名
conditional (bool): 是否使用条件请求
etag (Union[bool, str]): ETag值
last_modified (Union[datetime, int, float, None]): 最后修改时间
max_age (Union[None, int, Callable[[Union[str, None]], Union[int, None]]]): 缓存最大时间
返回:
Response: Flask的响应对象
"""
file_io = open_file(path_or_file, "rb")
# 如果没有提供 download_name 或 mimetype则从 path_or_file 中获取文件名和 MIME 类型
@ -61,6 +144,16 @@ def send_attachment(
def download_file(db_path, local_path):
"""
从db_path下载文件到local_path
参数:
db_path (str): 数据库文件路径
local_path (str): 本地文件路径
返回:
str: 本地文件路径
"""
with open(local_path, 'wb') as f:
with open_file(db_path, 'rb') as r:
f.write(r.read())
@ -68,5 +161,28 @@ def download_file(db_path, local_path):
def isLocalPath(path: str) -> bool:
return isinstance(determine_strategy(path), LocalAttachment)
"""
判断路径是否为本地路径
参数:
path (str): 文件或目录路径
返回:
bool: 如果是本地路径返回True否则返回False
"""
strategy = determine_strategy(path)
return isinstance(strategy, type(LocalAttachment()))
def getsize(path: str):
"""
获取文件大小
参数:
path (str): 文件路径
返回:
int: 文件大小
"""
return determine_strategy(path).getsize(path)

View File

@ -0,0 +1,12 @@
from abc import ABC, abstractmethod
from pywxdump.common.config.oss_config import storage_config
from pywxdump.file.Attachment import Attachment
class ConfigurableAttachment(ABC, Attachment):
@classmethod
@abstractmethod
def load_config(cls, config: storage_config) -> Attachment:
"""设置配置"""
pass

View File

@ -3,36 +3,138 @@ import os
import sys
from typing import IO
from pywxdump.file.Attachment import Attachment
class LocalAttachment:
def singleton(cls):
    """
    Class decorator that hands out one shared instance of ``cls``.

    The returned factory constructs the instance on its first call and
    returns that same object on every later call, whatever the arguments.
    """
    instances = {}

    def create_instance(*args, **kwargs):
        already_built = cls in instances
        if not already_built:
            instances[cls] = cls(*args, **kwargs)
        return instances[cls]

    return create_instance
@singleton
class LocalAttachment(Attachment):
def open(self, path, mode) -> IO:
"""
打开一个文件并返回文件对象
参数:
path (str): 文件路径
mode (str): 打开文件的模式
返回:
IO: 文件对象
"""
path = self.dealLocalPath(path)
return open(path, mode)
def remove(self, path: str) -> bool:
"""
删除文件
参数:
path (str): 文件路径
返回:
bool: 是否删除成功
"""
path = self.dealLocalPath(path)
if not self.exists(path):
raise FileNotFoundError(f"File not found: {path}")
if self.isdir(path):
raise ValueError(f"Path is not a file: {path}")
os.remove(path)
return True
def exists(self, path) -> bool:
"""
检查文件或目录是否存在
参数:
path (str): 文件或目录路径
返回:
bool: 如果存在返回True否则返回False
"""
path = self.dealLocalPath(path)
return os.path.exists(path)
def makedirs(self, path) -> bool:
"""
创建目录包括所有中间目录
参数:
path (str): 目录路径
返回:
bool: 总是返回True
"""
path = self.dealLocalPath(path)
os.makedirs(path)
return True
@classmethod
def join(cls, __a: str, *paths: str) -> str:
return os.path.join(__a, *paths)
def join(cls, path: str, *paths: str) -> str:
"""
连接一个或多个路径组件
参数:
path (str): 第一个路径组件
*paths (str): 其他路径组件
返回:
str: 连接后的路径
"""
# 使用os.path.join连接路径
return os.path.join(path, *paths)
@classmethod
def dirname(cls, path: str) -> str:
"""
获取路径的目录名
参数:
path (str): 文件路径
返回:
str: 目录名
"""
# 获取路径的目录名
return os.path.dirname(path)
@classmethod
def basename(cls, path: str) -> str:
"""
获取路径的基本名文件名
参数:
path (str): 文件路径
返回:
str: 基本名文件名
"""
# 获取路径的基本名
return os.path.basename(path)
def dealLocalPath(self, path: str) -> str:
# 获取当前系统的地址分隔符
"""
处理本地路径替换路径中的分隔符并根据操作系统进行特殊处理
参数:
path (str): 文件路径
返回:
str: 处理后的路径
"""
# 获取当前系统的路径分隔符
# 将path中的 / 替换为当前系统的分隔符
path = path.replace('/', os.sep)
if sys.platform == "win32":
@ -44,3 +146,52 @@ class LocalAttachment:
return path
else:
return path
def isdir(self, path: str) -> bool:
"""
判断是否为目录
参数:
path (str): 文件路径
返回:
bool: 是否为目录
"""
# 判断路径是否为目录
return os.path.isdir(path)
def getsize(self, path) -> int:
"""
获取文件大小
参数:
path (str): 文件路径
返回:
int: 文件大小
"""
if not self.exists(path):
raise FileNotFoundError(f"File not found: {path}")
if os.path.isfile(path):
return os.path.getsize(path)
else:
return self._get_dir_size(path)
def _get_dir_size(self, path):
"""
计算目录大小
参数:
path (str): 目录路径
返回:
int: 目录大小
"""
total_size = 0
for firePath, surnames, filenames in os.walk(path):
for f in filenames:
fp = self.join(firePath, f)
total_size += os.path.getsize(fp)
return total_size

View File

@ -1,34 +1,50 @@
# 对象存储文件处理类(示例:假设是 AWS S3
import os
from typing import IO
from urllib.parse import urlparse
from urllib.parse import urlparse, urljoin
from botocore.exceptions import ClientError
from smart_open import open
import boto3
from botocore.client import Config
class S3Attachment:
from pywxdump.common.config.oss_config import storage_config
from pywxdump.file.Attachment import Attachment
from pywxdump.file.ConfigurableAttachment import ConfigurableAttachment
def __init__(self):
# 腾讯云 COS 配置
self.cos_endpoint = "https://cos.<your-region>.myqcloud.com" # 替换 <your-region> 为你的 COS 区域,例如 ap-shanghai
self.access_key_id = "SecretId" # 替换为你的腾讯云 SecretId
self.secret_access_key = "SecretKey" # 替换为你的腾讯云 SecretKey
class S3Attachment(ConfigurableAttachment):
def __init__(self, s3_config: storage_config):
# S3 配置
self.s3_config = s3_config
# 校验配置
s3_config.validate_config()
# 创建 S3 客户端
self.s3_client = boto3.client(
's3',
endpoint_url=self.cos_endpoint,
aws_access_key_id=self.access_key_id,
aws_secret_access_key=self.secret_access_key,
endpoint_url=s3_config.endpoint_url,
aws_access_key_id=s3_config.access_key,
aws_secret_access_key=s3_config.secret_key,
config=Config(s3={"addressing_style": "virtual", "signature_version": 's3v4'})
)
def exists(self, path) -> bool:
bucket_name, path = self.dealS3Url(path)
# 检查是否为目录
if path.endswith('/'):
@classmethod
def load_config(cls, config: storage_config) -> Attachment:
return cls(config)
def exists(self, s3_url) -> bool:
"""
检查对象是否存在
参数:
s3_url (str): 对象路径
返回:
bool: 是否存在
"""
bucket_name, path = self.dealS3Url(s3_url)
# 尝试列出该路径下的对象
try:
response = self.s3_client.list_objects_v2(Bucket=bucket_name, Prefix=path, MaxKeys=1)
@ -39,51 +55,111 @@ class S3Attachment:
except ClientError as e:
print(f"Error: {e}")
return False
else:
# 检查是否为文件
try:
self.s3_client.head_object(Bucket=bucket_name, Key=path)
return True
except ClientError as e:
if e.response['Error']['Code'] == '404':
return False
else:
print(f"Error: {e}")
return False
def makedirs(self, path) -> bool:
if not self.exists(path):
bucket_name, path = self.dealS3Url(path)
def makedirs(self, s3_url) -> bool:
"""
创建目录
参数:
s3_url (str): 目录路径
返回:
bool: 是否创建成功
"""
if not self.exists(s3_url):
bucket_name, path = self.dealS3Url(s3_url)
self.s3_client.put_object(Bucket=bucket_name, Key=f'{path}/')
return True
def open(self, path, mode) -> IO:
self.dealS3Url(path)
return open(uri=path, mode=mode, transport_params={'client': self.s3_client})
def open(self, s3_url, mode) -> IO:
"""
打开文件
参数:
s3_url (str): 文件路径
mode (str): 打开模式
返回:
IO: 文件对象
"""
return open(uri=s3_url, mode=mode, transport_params={'client': self.s3_client})
def remove(self, s3_url: str) -> bool:
"""
删除文件
参数:
s3_url (str): 文件路径
返回:
bool: 是否删除成功
"""
if not self.exists(s3_url):
raise FileNotFoundError(f"File not found: {s3_url}")
if self.isdir(s3_url):
raise ValueError(f"Path is not a file: {s3_url}")
bucket_name, path = self.dealS3Url(s3_url)
self.s3_client.delete_object(Bucket=bucket_name, Key=path)
return True
@classmethod
def join(cls, __a: str, *paths: str) -> str:
return os.path.join(__a, *paths)
def join(cls, s3_url: str, *paths: str) -> str:
"""
连接路径
参数:
s3_url (str): 路径
*paths (str): 路径
返回:
str: 连接后的路径
"""
# 使用os.path.join连接路径
path = os.path.join(s3_url, *paths)
# 将所有反斜杠替换为正斜杠
return path.replace('\\', '/')
@classmethod
def dirname(cls, path: str) -> str:
return os.path.dirname(path)
def dirname(cls, s3_url: str) -> str:
"""
返回路径的目录部分
参数:
s3_url (str): 路径
返回:
str: 路径的目录部分
"""
return os.path.dirname(s3_url)
@classmethod
def basename(cls, path: str) -> str:
return os.path.basename(path)
def basename(cls, s3_url: str) -> str:
"""
返回路径的最后一个元素
def dealS3Url(self, path: str) -> object:
参数:
s3_url (str): 路径
返回:
str: 路径的最后一个元素
"""
return os.path.basename(s3_url)
def dealS3Url(self, s3_url: str) -> object:
"""
解析 S3 URL 并返回存储桶名称和路径
参数:
path (str): S3 URL
s3_url (str): S3 URL
返回:
tuple: 包含存储桶名称和路径的元组
"""
parsed_url = urlparse(path)
parsed_url = urlparse(s3_url)
# 确保URL是S3 URL
if parsed_url.scheme != 's3':
@ -94,3 +170,75 @@ class S3Attachment:
return bucket_name, s3_path
def isdir(self, s3_url: str) -> bool:
"""
判断是否为目录
参数:
s3_url (str): 文件路径
返回:
bool: 是否为目录
"""
# 确保目录路径以'/'结尾
if not s3_url.endswith('/'):
s3_url += '/'
bucket_name, path = self.dealS3Url(s3_url)
# 列出以该 key 为前缀的对象
response = self.s3_client.list_objects_v2(Bucket=bucket_name, Prefix=path, MaxKeys=1)
if 'Contents' in response:
# 存在对象,判断是否为目录
if response['Contents'][0]['Key'] == path or not path.endswith('/'):
return False
else:
return True
else:
return False
def getsize(self, s3_url) -> int:
"""
获取文件大小
参数:
path (str): 文件路径
返回:
int: 文件大小
"""
if not self.exists(s3_url):
raise FileNotFoundError(f"File not found: {s3_url}")
if self.isdir(s3_url):
return self._get_size_of_directory(s3_url)
else:
bucket_name, path = self.dealS3Url(s3_url)
response = self.s3_client.head_object(Bucket=bucket_name, Key=path)
return response['ContentLength']
def _get_size_of_directory(self, s3_url):
"""
获取目录大小
参数:
s3_url (str): 目录路径
返回:
int: 目录大小
"""
bucket_name, path = self.dealS3Url(s3_url)
total_size = 0
# 确保目录路径以'/'结尾
if not path.endswith('/'):
path += '/'
# 列出指定目录中的对象
response = self.s3_client.list_objects_v2(Bucket=bucket_name, Prefix=path)
if 'Contents' in response:
for obj in response['Contents']:
total_size += obj['Size']
return total_size

View File

@ -0,0 +1,4 @@
if __name__ == '__main__':
pass

View File

@ -10,9 +10,11 @@ import subprocess
import sys
import time
from pywxdump.common.config.oss_config.s3_config import S3Config
from pywxdump.common.config.server_config import ServerConfig
def start_falsk(merge_path="", wx_path="", key="", my_wxid="", port=5000, online=False, debug=False,
isopenBrowser=True):
def start_falsk(server_config: ServerConfig):
"""
启动flask
:param merge_path: 合并后的数据库路径
@ -38,13 +40,14 @@ def start_falsk(merge_path="", wx_path="", key="", my_wxid="", port=5000, online
import logging
# 检查端口是否被占用
if online:
if server_config.online:
host = '0.0.0.0'
else:
host = "127.0.0.1"
app = Flask(__name__, template_folder='./ui/web', static_folder='./ui/web/assets/', static_url_path='/assets/')
app.config['ENV'] = 'development'
app.config['DEBUG'] = True
# 设置超时时间为 1000 秒
app.config['TIMEOUT'] = 1000
app.secret_key = 'secret_key'
@ -67,18 +70,20 @@ def start_falsk(merge_path="", wx_path="", key="", my_wxid="", port=5000, online
g.tmp_path = tmp_path # 临时文件夹,用于存放图片等
g.sf = session_file # 用于存放各种基础信息
if merge_path: save_session(session_file, "test", "merge_path", merge_path)
if wx_path: save_session(session_file, "test", "wx_path", wx_path)
if key: save_session(session_file, "test", "key", key)
if my_wxid: save_session(session_file, "test", "my_wxid", my_wxid)
wxid = server_config.my_wxid if server_config.my_wxid else "test"
if server_config.merge_path: save_session(session_file, wxid, "merge_path", server_config.merge_path)
if server_config.wx_path: save_session(session_file, wxid, "wx_path", server_config.wx_path)
if server_config.key: save_session(session_file, wxid, "key", server_config.key)
if server_config.my_wxid: save_session(session_file, wxid, "my_wxid", server_config.my_wxid)
if server_config.oss_config: save_session(session_file, wxid, "oss_config", server_config.oss_config_to_json())
if not os.path.exists(session_file):
save_session(session_file, "test", "last", my_wxid)
save_session(session_file, wxid, "last", server_config.my_wxid)
app.register_blueprint(api)
if isopenBrowser:
if server_config.is_open_browser:
try:
# 自动打开浏览器
url = f"http://127.0.0.1:{port}/"
url = f"http://127.0.0.1:{server_config.port}/"
# 根据操作系统使用不同的命令打开默认浏览器
if sys.platform.startswith('darwin'): # macOS
subprocess.call(['open', url])
@ -100,13 +105,13 @@ def start_falsk(merge_path="", wx_path="", key="", my_wxid="", port=5000, online
return True
return False
if is_port_in_use(host, port):
print(f"Port {port} is already in use. Choose a different port.")
if is_port_in_use(host, server_config.port):
print(f"Port {server_config.port} is already in use. Choose a different port.")
input("Press Enter to exit...")
else:
time.sleep(1)
print("[+] 请使用浏览器访问 http://127.0.0.1:5000/ 查看聊天记录")
app.run(host=host, port=port, debug=debug)
app.run(host=host, port=server_config.port, debug=server_config.debug)
if __name__ == '__main__':
@ -114,6 +119,16 @@ if __name__ == '__main__':
wx_path = r"****"
my_wxid = "****"
server_config = ServerConfig.builder()
server_config.merge_path("s3://*********-1256220500/*********/merge_all.db")
server_config.wx_path("s3://*********-1256220500/*********")
server_config.my_wxid("test")
server_config.port(9000)
server_config.online(True)
server_config.is_open_browser(False)
s3Config = S3Config("AKIDaAjA*********I1kR4gFdv67v", "wlT2ldSBk*********Qh4fEev47",
"https://cos.ap-nanjing.myqcloud.com")
server_config.oss_config(s3Config)
start_falsk(server_config.build())
start_falsk(merge_path=merge_path, wx_path=wx_path, my_wxid=my_wxid,
port=5000, online=False, debug=False, isopenBrowser=False)

View File

@ -395,7 +395,7 @@ def get_wechat_db(require_list: Union[List[str], str] = "all", msg_dir: str = No
def get_core_db(wx_path: str, db_type: list = None) -> [str]:
"""
获取聊天消息核心数据库路径
:param wx_path: 微信文件夹路径 egC:\*****\WeChat Files\wxid*******
:param wx_path: 微信文件夹路径 egC:\\*****\\WeChat Files\\wxid*******
:param db_type: 数据库类型 eg: ["MSG", "MediaMSG", "MicroMsg"]三个中选择一个或多个
:return: 返回数据库路径 eg:["",""]
"""

View File

@ -304,7 +304,7 @@ def decrypt_merge(wx_path, key, outpath="", CreateTime: int = 0, endCreateTime:
bool, str):
"""
解密合并数据库 msg.db, microMsg.db, media.db,注意会删除原数据库
:param wx_path: 微信路径 eg: C:\*******\WeChat Files\wxid_*********
:param wx_path: 微信路径 eg: C:\\*******\\WeChat Files\\wxid_*********
:param key: 解密密钥
:return: (true,解密后的数据库路径) or (false,错误信息)
"""
@ -430,7 +430,7 @@ def all_merge_real_time_db(key, wx_path, merge_path):
这是全量合并会有可能产生重复数据需要自行去重
:param key: 解密密钥
:param wx_path: 微信路径
:param merge_path: 合并后的数据库路径 eg: C:\*******\WeChat Files\wxid_*********\merge.db
:param merge_path: 合并后的数据库路径 eg: C:\\*******\\WeChat Files\\wxid_*********\\merge.db
:return:
"""
if not merge_path or not key or not wx_path or not wx_path:

View File

@ -17,3 +17,4 @@ flask_cors
pandas
smart_open[s3]
boto3
botocore

View File

@ -0,0 +1,63 @@
import os
import shutil
import tempfile
import unittest
from unittest.mock import patch

from pywxdump.file.LocalAttachment import LocalAttachment
class TestLocalAttachment(unittest.TestCase):
    """Tests for LocalAttachment.remove and LocalAttachment.getsize on local files."""

    def setUp(self):
        """Create the attachment and a private scratch directory for this test."""
        self.attachment = LocalAttachment()
        # A fresh directory per test keeps tests independent of each other and
        # of whatever already sits in the shared system temp directory, so no
        # test has to clean up by hand before it starts.
        self.scratch_dir = tempfile.mkdtemp(prefix="test_local_attachment_")
        self.test_file_path = self.attachment.join(self.scratch_dir, "test.txt")
        self.test_dir_path = self.attachment.join(self.scratch_dir, "test_dir")

    def tearDown(self):
        """Remove the scratch directory and everything created inside it."""
        # rmtree also copes with nested directories, which the former
        # walk-and-rmdir cleanup could not remove.
        shutil.rmtree(self.scratch_dir, ignore_errors=True)

    def test_remove_existing_file(self):
        """remove() deletes an existing file and returns True."""
        with open(self.test_file_path, 'w') as f:
            f.write("test")
        self.assertTrue(self.attachment.remove(self.test_file_path))
        self.assertFalse(os.path.exists(self.test_file_path))

    def test_remove_non_existing_file(self):
        """remove() raises FileNotFoundError for a missing file."""
        with self.assertRaises(FileNotFoundError):
            self.attachment.remove(self.test_file_path)

    @patch('os.remove')
    def test_remove_os_error(self, mock_remove):
        """remove() lets an OSError surface to the caller."""
        mock_remove.side_effect = OSError
        with self.assertRaises(OSError):
            self.attachment.remove(self.test_file_path)

    def test_getsize_existing_file(self):
        """getsize() returns the byte length of a file."""
        with open(self.test_file_path, "w") as f:
            f.write("Hello, World!")
        self.assertEqual(self.attachment.getsize(self.test_file_path), 13)

    def test_getsize_non_existing_file(self):
        """getsize() raises FileNotFoundError for a missing path."""
        with self.assertRaises(FileNotFoundError):
            self.attachment.getsize('non_existing_file')

    def test_getsize_existing_folder(self):
        """getsize() sums the sizes of the files inside a directory."""
        os.mkdir(self.test_dir_path)
        with open(os.path.join(self.test_dir_path, "file1.txt"), "w") as f:
            f.write("Hello, World!")
        with open(os.path.join(self.test_dir_path, "file2.txt"), "w") as f:
            f.write("Hello, World!")
        self.assertEqual(self.attachment.getsize(self.test_dir_path), 26)
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,71 @@
import unittest
from unittest.mock import patch, MagicMock
from botocore.exceptions import ClientError
import boto3
from pywxdump.file.S3Attachment import S3Attachment
class TestS3Attachment(unittest.TestCase):
    """Tests for S3Attachment.remove and S3Attachment.getsize against a mocked S3 client."""

    @patch('boto3.client')
    def setUp(self, mock_client):
        """Build an S3Attachment whose boto3 client is a MagicMock."""
        self.mock_client = MagicMock()
        mock_client.return_value = self.mock_client
        s3_config = MagicMock()
        self.attachment = S3Attachment(s3_config)
        self.test_s3_url = "s3://test_bucket/test_file"
        self.test_s3_dir = "s3://test_bucket/test_folder/"

    # The directory check is patched under the name `isdir`, the method the
    # class defines; patch.object on a non-existent attribute raises
    # AttributeError before the test body runs.
    # NOTE(review): assumes remove() consults isdir() -- confirm against remove().
    @patch.object(S3Attachment, 'exists')
    @patch.object(S3Attachment, 'isdir')
    def test_removal_of_existing_file(self, mock_isdir, mock_exists):
        """remove() returns True for an existing plain object."""
        mock_exists.return_value = True
        mock_isdir.return_value = False
        self.assertTrue(self.attachment.remove(self.test_s3_url))

    @patch.object(S3Attachment, 'exists')
    def test_removal_of_non_existing_file(self, mock_exists):
        """remove() raises FileNotFoundError when the object does not exist."""
        mock_exists.return_value = False
        with self.assertRaises(FileNotFoundError):
            self.attachment.remove(self.test_s3_url)

    @patch.object(S3Attachment, 'exists')
    @patch.object(S3Attachment, 'isdir')
    def test_removal_of_folder_instead_of_file(self, mock_isdir, mock_exists):
        """remove() raises ValueError when the target is a directory."""
        mock_exists.return_value = True
        mock_isdir.return_value = True
        with self.assertRaises(ValueError):
            self.attachment.remove(self.test_s3_url)

    @patch.object(S3Attachment, 'exists')
    @patch.object(S3Attachment, 'isdir')
    def test_removal_with_s3_error(self, mock_isdir, mock_exists):
        """remove() propagates a ClientError raised by delete_object."""
        mock_exists.return_value = True
        mock_isdir.return_value = False
        # Simulate a ClientError coming back from S3.
        error_response = {'Error': {'Code': 'InvalidRequest', 'Message': 'Some error message'}}
        self.mock_client.delete_object.side_effect = ClientError(error_response, 'delete_object')
        with self.assertRaises(ClientError):
            self.attachment.remove(self.test_s3_url)

    @patch.object(S3Attachment, 'exists')
    def test_getsize_existing_file(self, mock_exists):
        """getsize() returns the ContentLength reported by head_object."""
        mock_exists.return_value = True
        self.mock_client.head_object.return_value = {'ContentLength': 100}
        self.assertEqual(self.attachment.getsize(self.test_s3_url), 100)

    def test_getsize_non_existing_file(self):
        """getsize() raises FileNotFoundError for a path that does not exist."""
        with self.assertRaises(FileNotFoundError):
            self.attachment.getsize('non_existing_file')

    @patch.object(S3Attachment, 'isdir')
    def test_getsize_existing_folder(self, mock_isdir):
        """getsize() sums the sizes of the objects listed under a directory."""
        mock_isdir.return_value = True
        self.mock_client.list_objects_v2.return_value = {'Contents': [{'Size': 100}, {'Size': 300}]}
        self.assertEqual(self.attachment.getsize(self.test_s3_dir), 400)
if __name__ == '__main__':
unittest.main()