@@ -393,121 +393,12 @@ def activations_keras(model, X, fmt='longform', plot='boxplot'):
393
393
394
394
395
395
def weights_torch (model , fmt = 'longform' , plot = 'boxplot' ):
396
+ from hls4ml .utils .profiling_utils import WeightsTorch
397
+
396
398
wt = WeightsTorch (model , fmt , plot )
397
399
return wt .get_weights ()
398
400
399
401
400
- def _torch_batchnorm (layer ):
401
- weights = list (layer .parameters ())
402
- epsilon = layer .eps
403
-
404
- gamma = weights [0 ]
405
- beta = weights [1 ]
406
- if layer .track_running_stats :
407
- mean = layer .running_mean
408
- var = layer .running_var
409
- else :
410
- mean = torch .tensor (np .ones (20 ))
411
- var = torch .tensor (np .zeros (20 ))
412
-
413
- scale = gamma / np .sqrt (var + epsilon )
414
- bias = beta - gamma * mean / np .sqrt (var + epsilon )
415
-
416
- return [scale , bias ], ['s' , 'b' ]
417
-
418
-
419
- def _torch_layer (layer ):
420
- return list (layer .parameters ()), ['w' , 'b' ]
421
-
422
-
423
- def _torch_rnn (layer ):
424
- return list (layer .parameters ()), ['w_ih_l0' , 'w_hh_l0' , 'b_ih_l0' , 'b_hh_l0' ]
425
-
426
-
427
- torch_process_layer_map = defaultdict (
428
- lambda : _torch_layer ,
429
- {
430
- 'BatchNorm1d' : _torch_batchnorm ,
431
- 'BatchNorm2d' : _torch_batchnorm ,
432
- 'RNN' : _torch_rnn ,
433
- 'LSTM' : _torch_rnn ,
434
- 'GRU' : _torch_rnn ,
435
- },
436
- )
437
-
438
-
439
- class WeightsTorch :
440
- def __init__ (self , model : torch .nn .Module , fmt : str = 'longform' , plot : str = 'boxplot' ) -> None :
441
- self .model = model
442
- self .fmt = fmt
443
- self .plot = plot
444
- self .registered_layers = list ()
445
- self ._find_layers (self .model , self .model .__class__ .__name__ )
446
-
447
- def _find_layers (self , model , module_name ):
448
- for name , module in model .named_children ():
449
- if isinstance (module , (torch .nn .Sequential , torch .nn .ModuleList )):
450
- self ._find_layers (module , module_name + "." + name )
451
- elif isinstance (module , (torch .nn .Module )) and self ._is_parameterized (module ):
452
- if len (list (module .named_children ())) != 0 :
453
- # custom nn.Module, continue search
454
- self ._find_layers (module , module_name + "." + name )
455
- else :
456
- self ._register_layer (module_name + "." + name )
457
-
458
- def _is_registered (self , name : str ) -> bool :
459
- return name in self .registered_layers
460
-
461
- def _register_layer (self , name : str ) -> None :
462
- if self ._is_registered (name ) is False :
463
- self .registered_layers .append (name )
464
-
465
- def _is_parameterized (self , module : torch .nn .Module ) -> bool :
466
- return any (p .requires_grad for p in module .parameters ())
467
-
468
- def _get_weights (self ) -> pandas .DataFrame | list [dict ]:
469
- if self .fmt == 'longform' :
470
- data = {'x' : [], 'layer' : [], 'weight' : []}
471
- elif self .fmt == 'summary' :
472
- data = []
473
- for layer_name in self .registered_layers :
474
- layer = self ._get_layer (layer_name , self .model )
475
- name = layer .__class__ .__name__
476
- weights , suffix = torch_process_layer_map [layer .__class__ .__name__ ](layer )
477
- for i , w in enumerate (weights ):
478
- label = f'{ name } /{ suffix [i ]} '
479
- w = weights [i ].detach ().numpy ()
480
- w = w .flatten ()
481
- w = abs (w [w != 0 ])
482
- n = len (w )
483
- if n == 0 :
484
- print (f'Weights for { name } are only zeros, ignoring.' )
485
- break
486
- if self .fmt == 'longform' :
487
- data ['x' ].extend (w .tolist ())
488
- data ['layer' ].extend ([name ] * n )
489
- data ['weight' ].extend ([label ] * n )
490
- elif self .fmt == 'summary' :
491
- data .append (array_to_summary (w , fmt = self .plot ))
492
- data [- 1 ]['layer' ] = name
493
- data [- 1 ]['weight' ] = label
494
-
495
- if self .fmt == 'longform' :
496
- data = pandas .DataFrame (data )
497
- return data
498
-
499
- def get_weights (self ) -> pandas .DataFrame | list [dict ]:
500
- return self ._get_weights ()
501
-
502
- def get_layers (self ) -> list [str ]:
503
- return self .registered_layers
504
-
505
- def _get_layer (self , layer_name : str , module : torch .nn .Module ) -> torch .nn .Module :
506
- for name in layer_name .split ('.' )[1 :]:
507
- module = getattr (module , name )
508
- return module
509
-
510
-
511
402
def activations_torch (model , X , fmt = 'longform' , plot = 'boxplot' ):
512
403
X = torch .Tensor (X )
513
404
if fmt == 'longform' :
0 commit comments