why:某一层(称作 rebalance layer)需要用到一个 rebalance 操作,需要在某个层的梯度乘上一个 平衡矩阵(feed-forward 时不用)
how:拓展 torch.autograd 函数
forward()
– 自定义操作的feed-forward 代码部分。 参数个数随意任何 Python 类型都可以。Variable
会在调用前转换为Tensor
s ,它们的使用也会注册在图中。 Note that this logic won’t traverse(遍历) lists/dicts/any other data structures and will only consider Variables that are direct arguments to the call. 此处可以返回单个Tensor
output, or atuple
ofTensor
s 。backward()
– 梯度计算部分。它将被返回与 farward() 相对应的outputVariable
个数相同的梯度。对于不想回传的参数返回一个 None 即可。
实现历程,首先想到的是直接“篡改” rebalance layer 的梯度:
loss.backward() network.fc1.weight.grad *= rebalance_matrix optimizer.step()
这种套路简单粗暴,如果打印出来新的权重,你会发现 rebalance layer 的权重果然有变化(不是简单的学习率带来的变化,更多的是 rebalance_matrix 的影响 )。但是前面层都没有变!!点解???
原来 pytorch 在 .backward() 这一步已经把所有层参数的梯度都计算好了,单纯的覆盖某一层的梯度并不会把影响链式传递下去!
可是我们需要把这个 rebalance_matrix 的影响传播下去啊。
于是得手写一个自定义 op 出来。
class Rebalance_Op(Function): @staticmethod defforward(ctx, input, factors): ctx.save_for_backward(input, factors) #return 不能仅返回 input 否则 不会执行 backward 操作, returninput*1. @staticmethod defbackward(ctx, grad_output): input, factors = ctx.saved_variables grad_input = grad_factors = None if ctx.needs_input_grad[0]: grad_input = grad_output * factors return grad_input, None
打印出 梯度,会发现梯度完美更新。
建议测试网络:
class My_net(nn.Module): def __init__(self): super().__init__() self.fc1 = nn.Linear(4,5, bias=False) self.fc2 = nn.Linear(5,3, bias=False) self.mc = Rebalance_Op.apply self._initialize_weights() def _initialize_weights(self): for m in self.modules(): if isinstance(m, nn.Linear): init.constant(m.weight, 1.) if m.bias is not None: m.bias.data.zero_() def forward(self,input): fc1 = self.fc1(input) out = self.fc2(fc1) out = self.mc(out, fact) # out = out * 6 return out #fix input&target for the debug convience target = Variable(torch.zeros(2, 3)) input = Variable(torch.ones(2, 4)) Variable(torch.from_numpy(np.arange(0,4).astype('float32'))) network = My_net()