File tree Expand file tree Collapse file tree 2 files changed +15
-2
lines changed Expand file tree Collapse file tree 2 files changed +15
-2
lines changed Original file line number Diff line number Diff line change @@ -119,8 +119,8 @@ def convolve1d(
119
119
if mode == "same" :
120
120
# We implement "same" as "valid" with padded `in1`.
121
121
in1_batch_shape = tuple (in1 .shape )[:- 1 ]
122
- zeros_left = in2 .shape [0 ] // 2
123
- zeros_right = (in2 .shape [0 ] - 1 ) // 2
122
+ zeros_left = in2 .shape [- 1 ] // 2
123
+ zeros_right = (in2 .shape [- 1 ] - 1 ) // 2
124
124
in1 = join (
125
125
- 1 ,
126
126
zeros ((* in1_batch_shape , zeros_left ), dtype = in2 .dtype ),
Original file line number Diff line number Diff line change @@ -47,3 +47,16 @@ def test_convolve1d_batch():
47
47
res_np = np .convolve (x_test [0 ], y_test [0 ])
48
48
np .testing .assert_allclose (res [0 ], res_np , rtol = rtol )
49
49
np .testing .assert_allclose (res [1 ], res_np , rtol = rtol )
50
+
51
+
52
+ def test_convolve1d_batch_same ():
53
+ x = matrix ("data" )
54
+ y = matrix ("kernel" )
55
+ out = convolve1d (x , y , mode = "same" )
56
+
57
+ rng = np .random .default_rng (38 )
58
+ x_test = rng .normal (size = (2 , 8 )).astype (x .dtype )
59
+ y_test = rng .normal (size = (2 , 8 )).astype (x .dtype )
60
+
61
+ res = out .eval ({x : x_test , y : y_test })
62
+ assert res .shape == (2 , 8 )
You can’t perform that action at this time.
0 commit comments