Skip to content

Commit 745e2a8

Browse files
committed
Added things
1 parent e7d3aa1 commit 745e2a8

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

fx/vmap.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
#
2020
# How is this feat accomplished? One observation is that to "batch" a model, it
2121
# suffices to batch each individual operation. In other words, given an
22-
# operation that works on the current shape, how do we make it work on another
23-
# batch dimension? This leads us to batching rules.
22+
# operation that works on the current shape, how do we make it work with an
23+
# additional batch dimension? This leads us to batching rules.
2424
#
2525
# Batching Rules
2626
# ---------------
@@ -194,4 +194,4 @@ def forward(self, a, b):
194194
# outer product computation. ((B, N), (M,)) -> (B, N, M)
195195

196196
model = vmap(model, in_axes=(0, None), example_args=(x[0], y))
197-
print(model(x, y).shape) # ((3, 5), (2,)) -> (3, 5, 2)
197+
print(model(x, y).shape) # ((3, 5), (2,)) -> (3, 5, 2)

0 commit comments

Comments
 (0)