pytorch bugs(Updating)

一. KeyError: ‘missing keys in state_dict:

修改了网络的输入层,将原来的 6 channel 输入变成了13 channel 输入层,那么预训练模型的参数第一层卷积层肯定是对不上的啊。于是乎就把第一层的卷积参数del 掉。如下:

checkpoint = torch.load(args.resume)
 if checkpoint['state_dict']['conv1.0.weight'].shape[1] == 6:
    del checkpoint['state_dict']['conv1.0.bias']
    del checkpoint['state_dict']['conv1.0.weight']

结果,可爱的bug接踵而至。详细原因解析链接。

https://discuss.pytorch.org/t/pytorch-pretrained-vgg19-keyerror/4140/3

但是并不能解决我的问题,于是手动改成如下代码:

if checkpoint['state_dict']['conv1.0.weight'].shape[1] == 6:
     checkpoint['state_dict']['conv1.0.bias'] =    model.state_dict()['conv1.0.bias']
     checkpoint['state_dict']['conv1.0.weight'] =  model.state_dict()['conv1.0.weight']

问题解决!

二.RuntimeError: bool value of Variable objects containing non-empty torch.cuda.ByteTensor is ambiguous

想记录所有不为 0 的 loss 到一个 list。代码如下:

val = (self.weights[i] * self.cretion( output[i], torch.max(target[i], 1)[1])) if torch.sum(target[i]) != 0 else None

于是就报错,想了一下,应该是类型不匹配,修改代码:

val = (self.weights[i] * self.cretion( output[i], torch.max(target[i], 1)[1])) if torch.sum(target[i]).cpu().data[0] != 0 else None
Share this to:

发表评论

电子邮件地址不会被公开。 必填项已用*标注