@@ -191,53 +191,54 @@ def check_for_block_diag(x):
191
191
)
192
192
193
193
# Check that the BlockDiagonal is an input to a Dot node:
194
- clients = list ( get_clients_at_depth (fgraph , node , depth = 1 ))
195
- if not clients or len ( clients ) > 1 or not isinstance (clients [ 0 ] .op , Dot ):
196
- return
194
+ for client in get_clients_at_depth (fgraph , node , depth = 1 ):
195
+ if not isinstance (client .op , Dot ):
196
+ return
197
197
198
- [dot_node ] = clients
199
- op = dot_node .op
200
- x , y = dot_node .inputs
198
+ op = client .op
199
+ x , y = client .inputs
201
200
202
- if not (check_for_block_diag (x ) or check_for_block_diag (y )):
203
- return None
201
+ if not (check_for_block_diag (x ) or check_for_block_diag (y )):
202
+ return None
204
203
205
- # Case 1: Only one input is BlockDiagonal. In this case, multiply all components of the block-diagonal with the
206
- # non-block diagonal, and return a new block diagonal
207
- if check_for_block_diag (x ) and not check_for_block_diag (y ):
208
- components = x .owner .inputs
209
- y_splits = split (
210
- y ,
211
- splits_size = [component .shape [- 1 ] for component in components ],
212
- n_splits = len (components ),
213
- )
214
- new_components = [
215
- op (component , y_split ) for component , y_split in zip (components , y_splits )
216
- ]
217
- new_output = join (0 , * new_components )
218
-
219
- elif not check_for_block_diag (x ) and check_for_block_diag (y ):
220
- components = y .owner .inputs
221
- x_splits = split (
222
- x ,
223
- splits_size = [component .shape [0 ] for component in components ],
224
- n_splits = len (components ),
225
- axis = 1 ,
226
- )
204
+ # Case 1: Only one input is BlockDiagonal. In this case, multiply all components of the block-diagonal with the
205
+ # non-block diagonal, and return a new block diagonal
206
+ if check_for_block_diag (x ) and not check_for_block_diag (y ):
207
+ components = x .owner .inputs
208
+ y_splits = split (
209
+ y ,
210
+ splits_size = [component .shape [- 1 ] for component in components ],
211
+ n_splits = len (components ),
212
+ )
213
+ new_components = [
214
+ op (component , y_split )
215
+ for component , y_split in zip (components , y_splits )
216
+ ]
217
+ new_output = join (0 , * new_components )
218
+
219
+ elif not check_for_block_diag (x ) and check_for_block_diag (y ):
220
+ components = y .owner .inputs
221
+ x_splits = split (
222
+ x ,
223
+ splits_size = [component .shape [0 ] for component in components ],
224
+ n_splits = len (components ),
225
+ axis = 1 ,
226
+ )
227
227
228
- new_components = [
229
- op (x_split , component ) for component , x_split in zip (components , x_splits )
230
- ]
231
- new_output = join (1 , * new_components )
228
+ new_components = [
229
+ op (x_split , component )
230
+ for component , x_split in zip (components , x_splits )
231
+ ]
232
+ new_output = join (1 , * new_components )
232
233
233
- # Case 2: Both inputs are BlockDiagonal. Do nothing
234
- else :
235
- # TODO: If shapes are statically known and all components have equal shapes, we could rewrite
236
- # this case to block_diag(*[dot(comp_1, comp_2) for comp_1, comp_2 in zip(x.owner.inputs, y.owner.inputs)])
237
- return None
234
+ # Case 2: Both inputs are BlockDiagonal. Do nothing
235
+ else :
236
+ # TODO: If shapes are statically known and all components have equal shapes, we could rewrite
237
+ # this case to block_diag(*[dot(comp_1, comp_2) for comp_1, comp_2 in zip(x.owner.inputs, y.owner.inputs)])
238
+ return None
238
239
239
- copy_stack_trace (node .outputs [0 ], new_output )
240
- return {dot_node .outputs [0 ]: new_output }
240
+ copy_stack_trace (node .outputs [0 ], new_output )
241
+ return {client .outputs [0 ]: new_output }
241
242
242
243
243
244
@register_canonicalize
0 commit comments