@@ -100,10 +100,10 @@ The pattern to calculate higher order gradients is the following:
 ``` python
 from mxnet import ndarray as nd
 from mxnet import autograd as ag
-x= nd.array([1, 2, 3])
+x = nd.array([1, 2, 3])
 x.attach_grad()
 def f(x):
-    # A function which supports higher oder gradients
+    # Any function which supports higher order gradients
     return nd.log(x)
 ```
@@ -117,28 +117,28 @@ Using mxnet.autograd.grad multiple times:
 ``` python
 with ag.record():
     y = f(x)
-    x_grad = ag.grad(y, x, create_graph=True, retain_graph=True)[0]
-    x_grad_grad = ag.grad(x_grad, x, create_graph=False, retain_graph=True)[0]
-print(f"dy/dx: {x_grad}")
-print(f"d2y/dx2: {x_grad_grad}")
+    x_grad = ag.grad(heads=y, variables=x, create_graph=True, retain_graph=True)[0]
+    x_grad_grad = ag.grad(heads=x_grad, variables=x, create_graph=False, retain_graph=False)[0]
+print(f"dL/dx: {x_grad}")
+print(f"d2L/dx2: {x_grad_grad}")
 ```

 Running backward on the backward graph:

 ``` python
 with ag.record():
     y = f(x)
-    x_grad = ag.grad(y, x, create_graph=True, retain_graph=True)[0]
+    x_grad = ag.grad(heads=y, variables=x, create_graph=True, retain_graph=True)[0]
 x_grad.backward()
 x_grad_grad = x.grad
-print(f"dy/dx: {x_grad}")
-print(f"d2y/dx2: {x_grad_grad}")
+print(f"dL/dx: {x_grad}")
+print(f"d2L/dx2: {x_grad_grad}")
 ```

 Both methods are equivalent, except that in the second case, retain_graph on running backward is set
 to False by default. But both calls run a backward pass on the graph as usual to get the
-gradient of the first gradient `y_grad` with respect to `x` evaluated at the value of `x`.
+gradient of the first gradient `x_grad` with respect to `x` evaluated at the value of `x`.
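For readers without MXNet at hand, the values these snippets compute can be sanity-checked numerically: for f(x) = log(x), the first gradient is 1/x and the second is -1/x². A minimal sketch using central finite differences (plain Python, no MXNet; the `numerical_grad` helper is illustrative, not part of the MXNet API):

``` python
import math

def f(x):
    # Same function as in the example above
    return math.log(x)

def numerical_grad(g, x, h=1e-5):
    # Central finite difference: (g(x+h) - g(x-h)) / (2h)
    return (g(x + h) - g(x - h)) / (2 * h)

for x in [1.0, 2.0, 3.0]:
    dy_dx = numerical_grad(f, x)  # should approximate 1/x
    # Differentiate the numerical first gradient again for the second derivative
    d2y_dx2 = numerical_grad(lambda v: numerical_grad(f, v), x)  # should approximate -1/x**2
    print(f"x={x}: dL/dx ~ {dy_dx:.4f} (exact {1/x:.4f}), "
          f"d2L/dx2 ~ {d2y_dx2:.4f} (exact {-1/x**2:.4f})")
```

The nested call mirrors what the autograd version does symbolically: the outer differentiation is applied to the result of the inner one, which is why `create_graph=True` is needed on the first `ag.grad` call but not the second.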