pytorch 实现自定义操作及反向传导

why:某一层(称作 rebalance layer)需要用到一个 rebalance 操作,需要在某个层的梯度乘上一个 平衡矩阵(feed-forward 时不用)

how:拓展 torch.autograd 函数

  • forward() – 自定义操作的feed-forward 代码部分。 参数个数随意任何 Python 类型都可以。 Variable 会在调用前转换为 Tensor s ,它们的使用也会注册在图中。 Note that this logic won’t traverse(遍历) lists/dicts/any other data structures and will only consider Variables that are direct arguments to the call. 此处可以返回单个 Tensor output, or a tuple of Tensor s 。
  • backward() – 梯度计算部分。它将被返回与 farward() 相对应的outputVariable个数相同的梯度。对于不想回传的参数返回一个 None 即可。

实现历程,首先想到的是直接“篡改” rebalance layer 的梯度:

loss.backward()
network.fc1.weight.grad *= rebalance_matrix 
optimizer.step()

这种套路简单粗暴,如果打印出来新的权重,你会发现 rebalance layer 的权重果然有变化(不是简单的学习率带来的变化,更多的是 rebalance_matrix 的影响 )。但是前面层都没有变!!点解???

原来 pytorch 在 .backward() 这一步已经把所有层参数的梯度都计算好了,单纯的覆盖某一层的梯度并不会把影响链式传递下去!

可是我们需要把这个 rebalance_matrix 的影响传播下去啊。

于是得手写一个自定义 op 出来。

class Rebalance_Op(Function):

    @staticmethod
    defforward(ctx, input, factors):
        ctx.save_for_backward(input, factors)
        #return 不能仅返回 input 否则 不会执行 backward 操作,
        returninput*1.

    @staticmethod
    defbackward(ctx, grad_output):
        input, factors = ctx.saved_variables
        grad_input = grad_factors = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output * factors
        return grad_input, None


打印出 梯度,会发现梯度完美更新。

建议测试网络:

class My_net(nn.Module):

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4,5, bias=False)
        self.fc2 = nn.Linear(5,3, bias=False)
        self.mc = Rebalance_Op.apply
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                init.constant(m.weight, 1.)
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self,input):
        fc1 = self.fc1(input)
        out = self.fc2(fc1)
        out = self.mc(out, fact)
        # out = out * 6
        return out

#fix input&target for the debug convience

target = Variable(torch.zeros(2, 3))
input = Variable(torch.ones(2, 4))

Variable(torch.from_numpy(np.arange(0,4).astype('float32')))
network = My_net()
Share this to:

发表评论

电子邮件地址不会被公开。 必填项已用*标注