手动实现计算图
之前打过一次手动求导,但是学了pytorch之后发现自己并没有打backward之类的操作,不能偏导
计算图的思路都很显然,问题是要怎样实现才好
- 我们要定义一个类:点类(Node)
- 因为实际计算中可能会有多个图,所以我们还要定义一个图类(Graph)
- 为了长得像我们再定义一个tensor类
- 还要定义一些计算类,来封装对于某个计算中自己的计算和求导方式
点
- 对于图来说,我们还是要建立图的
- 不过python的列表是动态的,所以肯定是邻接表会好一点
- 对于每一个点,记录他的儿子和父亲
- 记录值和导数,属于哪个图
封装的方法
backward_single,backward
1 | def backward_single(self, y): |
Copy
- backward_single是沿着儿子的,毕竟是递归方法,所以是单链的。后面的backward就是多链的
求导和计算
- 为什么要封装求导和计算呢
- 因为在各种继承类中有各种各样的求导和计算
重载运算符
1 | def __add__(self, other): |
Copy
运算
1 | class Add(Node): |
Copy
上面的各种运算类封装了自己的对应的运算和求导
图
图的方法就很简单
建立一个节点(没了)
1
2
3
4
5class Graph(object):
def __init__(self):
self.nodes = []
def add_node(self, node):
self.nodes.append(node)Copy
最后就能搞出一个简单的计算图