pytorch项目导出torch.jit.script经验
一般是用torch.jit.script导出.pt文件使用分为script和save两个步骤,主要思路是借助大模型进行改错。pytorch mobile不支持任何complex的操作,所以说如果输入参数有complex,就分别输入实部和虚部,然后内部再拼接,比如:原生:改为:script部分:1. 分模块编译一定要分模块编译,比如说你要编译的model分为ABC三个模块,那么就如下操作:这样方便灵
一般是用torch.jit.script导出.pt文件使用
分为script和save两个步骤,主要思路是借助大模型进行改错。
预定义的坑:
pytorch mobile不支持任何complex的操作,所以说如果输入参数有complex,就分别输入实部和虚部,然后内部再拼接,比如:
原生:
def forward(
self,
input: torch.Tensor,
ilens: torch.Tensor,
lip_emb: torch.Tensor,
tpd: torch.Tensor,
states = None,
additional: Optional[Dict] = None,
) -> Tuple[List[torch.Tensor], torch.Tensor, OrderedDict]:
batch = input
batch0 = batch.transpose(1, 2) # [B, M, T, F]
batch = torch.cat((batch0.real, batch0.imag), dim=1)
改为:
def forward(
self,
input_real: torch.Tensor,
input_imag: torch.Tensor,
ilens: torch.Tensor,
lip_emb: torch.Tensor,
tpd: torch.Tensor,
):
batch0_real = input_real.transpose(1, 2)
batch0_imag = input_imag.transpose(1, 2)
batch = torch.cat((batch0_real, batch0_imag), dim=1)
script部分:
1. 分模块编译
def try_script(name, module):
# 分模块编译
try:
torch.jit.script(module)
print(f"✅ {name} scripted successfully")
except Exception as e:
print(f"❌ {name} failed: {e}")
一定要分模块编译,比如说你要编译的model分为ABC三个模块,那么就如下操作:
try_script("A", model.A)
try_script("B", model.B)
try_script("C", model.C)
这样方便灵活的排查问题,找到了模块的问题还可以继续深入子模块进行排查。
2. 避免接口复用
因为jit.script要求静态的计算图,所以不能出现一个参数有多个属性的可能,比如说,在某个forward函数如下定义:
def forward(x, state=None)
这个接口兼容了流式推理和非流式推理,流式推理两个参数都会传入,非流式推理只会传入x,那么这个时候state的属性本身就是不固定的,这个对于jit.script来说是无法接受的,jit.trace或许可以,但是我没用过。
把错误粘贴给大模型,他可能会告诉你,给每个输入参数预先定义好输入类型,如下:
def forward(x: torch.Tensor, state: torch.Tensor=None)-> Tuple[torch.Tensor, Optional[torch.Tensor]]:
经过笔者的实践,发现这样并不好用,Optional也可能会引入问题,最佳方法还是不要接口复用,流式一套,非流式一套,然后这样可以把所有定义类型全部去掉。
3. 其他细碎的点
有些上古时期的代码,做了兼容性设计,比如说torch老版本和新版本兼容,这个对jit.script是不可行的,必须完全静态,以及说不要出现输入但是未使用的参数(改接口复用就可能出现这个问题),代码中的assert也要去掉,都是潜在的坑。
save部分
这一块坑非常恶心,因为涉及到把python转成c++底层,所以python自己的traceback.print_exc()追踪不到问题,只会最后给你返回因为什么失败的,所以就不知道问题出在哪里,这就体现出分模块的重要性。
def try_save(name, module):
# 分模块编译
try:
mod = torch.jit.script(module)
mod.save(name)
print(f"✅ {name} saveed successfully")
except Exception as e:
print(f"❌ {name} failed: {e}")
traceback.print_exc()
def find_undefined_tensors(module, name="root", depth=0):
indent = " " * depth
try:
# 只 script + save,不访问 .strides()
scripted = torch.jit.script(module)
scripted.save(f"/tmp/debug_{name.replace('.', '_')}.pt")
# print(f"{indent}✅ {name} saved OK")
except Exception as e:
print(f"{indent}🔍 Trying to save: {name} ({type(module).__name__})")
print(f"{indent}❌ {name} save failed: {e}")
# 不 exit,继续找更深的模块
# 递归子模块
for child_name, child in module.named_children():
find_undefined_tensors(child, f"{name}.{child_name}", depth + 1)
当时遇到最多也是最恶心的错误是:
strides() called on an undefined Tensor
出现这个的原因大概率是,输入的参数默认值给了None,最好不要给默认值。
验证部分
通过了上述两个部分,就可以进行验证:
try:
print("Verifying exported model...")
loaded_model = torch.jit.load(args.output_file)
with torch.no_grad():
output = loaded_model(real_part, imag_part, example_ilens_cur, example_lip_cur, example_tpd_cur)
print(f"Verification successful! Output shape: {output.shape}")
except Exception as e:
print(f"Verification failed: {e}")
火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。
更多推荐
所有评论(0)