在Easy-Wav2Lip项目中,我遇到了典型的设备不匹配问题。
RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same
它表明模型权重(weight)和输入数据(input)不在同一个设备上,一个在CPU,另一个在GPU。

🔧 问题排查

检查inference.py中哪里可能导致权重加载错误。

  1. 修改_load 函数
def _load(checkpoint_path):
    print(f"[DEBUG] 当前设备设置: {device}")
    print(f"[DEBUG] GPU ID: {gpu_id}")
    
    if device != "cpu":
        print(f"[DEBUG] 尝试加载到GPU/MPS设备")
        # 明确指定设备映射
        if device == 'cuda':
            checkpoint = torch.load(checkpoint_path, map_location='cuda')
        elif device == 'mps':
            checkpoint = torch.load(checkpoint_path, map_location='mps')
        else:
            checkpoint = torch.load(checkpoint_path)
    else:
        print(f"[DEBUG] 加载到CPU设备")
        checkpoint = torch.load(
            checkpoint_path, map_location=lambda storage, loc: storage
        )
    
    print(f"[DEBUG] 加载的checkpoint设备信息: {next(iter(checkpoint['state_dict'].values())).device if 'state_dict' in checkpoint else '未知'}")
    return checkpoint
  1. 修改 do_load 函数:
def do_load(checkpoint_path):
    global model, detector, detector_model
    
    print(f"[DEBUG] === 开始加载模型 ===")
    print(f"[DEBUG] 目标设备: {device}")
    
    model = load_model(checkpoint_path)
    
    # 添加模型设备检查
    print(f"[DEBUG] 主模型加载完成,检查设备:")
    if hasattr(model, 'parameters') and len(list(model.parameters())) > 0:
        first_param = next(model.parameters())
        print(f"[DEBUG] 模型参数设备: {first_param.device}")
    else:
        print(f"[DEBUG] 模型参数设备: 无法检测")
    
    detector = RetinaFace(
        gpu_id=gpu_id, model_path="checkpoints/mobilenet.pth", network="mobilenet"
    )
    detector_model = detector.model
    
    print(f"[DEBUG] === 模型加载完成 ===\n")
  1. 在 main 函数开始处添加设备信息:
def main():
    print(f"[SYSTEM] 最终使用的设备: {device}")
    print(f"[SYSTEM] CUDA可用: {torch.cuda.is_available()}")
    print(f"[SYSTEM] MPS可用: {torch.backends.mps.is_available() if hasattr(torch.backends, 'mps') else 'N/A'}")
    print(f"[SYSTEM] GPU ID: {gpu_id}")
    
    # 原有的main函数代码...

问题定位

运行代码,定位到函数do_load 。

(easy_wav) D:\work\easy-Wav2Lip\Easy-Wav2Lip>call run_loop.bat
opening GUI
Saving config
starting Easy-Wav2Lip...
Processing full.mp4 using playlist-file.wav for audio
imports loaded!
[DEBUG] === 开始加载模型 ===
[DEBUG] 目标设备: cuda
[DEBUG] 主模型加载完成,检查设备:
[DEBUG] 模型参数设备: cpu
[DEBUG] === 模型加载完成 ===

[SYSTEM] 最终使用的设备: cuda
[SYSTEM] CUDA可用: True
[SYSTEM] MPS可用: False
[SYSTEM] GPU ID: 0

解决办法

def do_load(checkpoint_path):
    global model, detector, detector_model
    
    # 获取当前设备配置
    device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
    gpu_id = 0 if torch.cuda.is_available() else -1
    
    print(f"[DEBUG] 当前设备: {device}, GPU ID: {gpu_id}")
    
    # 修改_load函数以正确处理设备映射
    def _load(checkpoint_path):
        if device == 'cuda' and torch.cuda.is_available():
            map_location = f'cuda:{gpu_id}' if gpu_id >= 0 else 'cuda'
        elif device == 'mps' and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            map_location = 'mps'
        else:
            map_location = 'cpu'
        
        print(f"[DEBUG] 使用设备映射: {map_location}")
        return torch.load(checkpoint_path, map_location=map_location)
    
    # 加载主模型
    checkpoint = _load(checkpoint_path)
    model.load_state_dict(checkpoint)
    
    # 确保模型在正确的设备上
    if device == 'cuda' and torch.cuda.is_available():
        model = model.cuda(gpu_id if gpu_id >= 0 else None)
    elif device == 'mps' and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        model = model.to('mps')
    
    print(f"[DEBUG] 主模型设备: {next(model.parameters()).device}")
    
    return model

这个解决方案的关键点在于:

  • 动态设备检测​:自动识别可用的计算设备
  • 正确的map_location设置​:确保权重加载到目标设备
  • ​设备一致性检查​:验证模型和数据的设备一致性
Logo

火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。

更多推荐