@@ -381,15 +381,87 @@ def activations_keras(model, X, fmt='longform', plot='boxplot'):
381
381
382
382
383
383
def weights_torch(model, fmt='longform', plot='boxplot'):
    """Profile the weights of a PyTorch model.

    Thin functional wrapper kept for backward compatibility: all the work
    is done by the WeightsTorch helper class.

    Args:
        model: a torch.nn.Module whose layer weights will be collected.
        fmt: 'longform' for a pandas DataFrame of individual values, or
            'summary' for a list of per-tensor summary dicts.
        plot: summary style forwarded to array_to_summary (e.g. 'boxplot').

    Returns:
        pandas.DataFrame or list[dict], depending on *fmt*.
    """
    profiler = WeightsTorch(model, fmt, plot)
    return profiler.get_weights()
386
+
387
+
388
+ def _torch_batchnorm (layer ):
389
+ weights = list (layer .parameters ())
390
+ epsilon = layer .eps
391
+
392
+ gamma = weights [0 ]
393
+ beta = weights [1 ]
394
+ if layer .track_running_stats :
395
+ mean = layer .running_mean
396
+ var = layer .running_var
397
+ else :
398
+ mean = torch .tensor (np .ones (20 ))
399
+ var = torch .tensor (np .zeros (20 ))
400
+
401
+ scale = gamma / np .sqrt (var + epsilon )
402
+ bias = beta - gamma * mean / np .sqrt (var + epsilon )
403
+
404
+ return [scale , bias ], ['s' , 'b' ]
405
+
406
+
407
+ def _torch_layer (layer ):
408
+ return list (layer .parameters ()), ['w' , 'b' ]
409
+
410
+
411
+ def _torch_rnn (layer ):
412
+ return list (layer .parameters ()), ['w_ih_l0' , 'w_hh_l0' , 'b_ih_l0' , 'b_hh_l0' ]
413
+
414
+
415
# Dispatch table mapping a layer's class name to its weight extractor.
# Any layer type not listed below falls back to the generic _torch_layer.
torch_process_layer_map = defaultdict(lambda: _torch_layer)
torch_process_layer_map.update(
    BatchNorm1d=_torch_batchnorm,
    BatchNorm2d=_torch_batchnorm,
    RNN=_torch_rnn,
    LSTM=_torch_rnn,
    GRU=_torch_rnn,
)
425
+
426
+
427
class WeightsTorch:
    """Discovers the parameterized layers of a torch model and collects
    their weights in either long-form or summary format."""

    def __init__(self, model: torch.nn.Module, fmt: str = 'longform', plot: str = 'boxplot') -> None:
        # fmt: 'longform' (DataFrame of values) or 'summary' (list of dicts);
        # plot: summary style forwarded to array_to_summary.
        self.model = model
        self.fmt = fmt
        self.plot = plot
        self.registered_layers = []
        # Walk the module tree immediately so registered_layers is ready.
        self._find_layers(self.model, self.model.__class__.__name__)
434
+
435
    def _find_layers(self, model: torch.nn.Module, module_name: str) -> None:
        """Recursively walk *model*'s children, registering the dotted name
        of every leaf module that has trainable parameters.

        Names are built as "<module_name>.<child_name>" so _get_layer can
        later resolve them back to module objects.
        """
        for name, module in model.named_children():
            if isinstance(module, (torch.nn.Sequential, torch.nn.ModuleList)):
                # Pure containers: always descend, regardless of parameters.
                self._find_layers(module, module_name + "." + name)
            elif isinstance(module, (torch.nn.Module)) and self._is_parameterized(module):
                if len(list(module.named_children())) != 0:
                    # custom nn.Module, continue search
                    self._find_layers(module, module_name + "." + name)
                else:
                    # Leaf layer with trainable parameters: record it.
                    self._register_layer(module_name + "." + name)
445
+
446
+ def _is_registered (self , name : str ) -> bool :
447
+ return name in self .registered_layers
448
+
449
+ def _register_layer (self , name : str ) -> None :
450
+ if self ._is_registered (name ) is False :
451
+ self .registered_layers .append (name )
452
+
453
+ def _is_parameterized (self , module : torch .nn .Module ) -> bool :
454
+ return any (p .requires_grad for p in module .parameters ())
455
+
456
    def _get_weights(self) -> pandas.DataFrame | list[dict]:
        """Collect weight values from every registered layer.

        Returns:
            pandas.DataFrame with columns x/layer/weight when
            self.fmt == 'longform'; a list of summary dicts (one per weight
            tensor, tagged with 'layer' and 'weight') when
            self.fmt == 'summary'.
        """
        if self.fmt == 'longform':
            data = {'x': [], 'layer': [], 'weight': []}
        elif self.fmt == 'summary':
            data = []
        for layer_name in self.registered_layers:
            layer = self._get_layer(layer_name, self.model)
            name = layer.__class__.__name__
            # Dispatch on the class name; unknown types fall back to the
            # generic weight/bias extractor via the defaultdict.
            weights, suffix = torch_process_layer_map[layer.__class__.__name__](layer)
            for i, w in enumerate(weights):
                label = f'{name}/{suffix[i]}'
                w = weights[i].detach().numpy()
                # NOTE(review): a few lines are elided in this diff hunk
                # (presumably flattening w and dropping zeros to define n) —
                # confirm against the full file.
                if n == 0:
                    print(f'Weights for {name} are only zeros, ignoring.')
                    # Skips the layer's remaining weight tensors as well.
                    break
                if self.fmt == 'longform':
                    data['x'].extend(w.tolist())
                    data['layer'].extend([name] * n)
                    data['weight'].extend([label] * n)
                elif self.fmt == 'summary':
                    data.append(array_to_summary(w, fmt=self.plot))
                    data[-1]['layer'] = name
                    data[-1]['weight'] = label

        if self.fmt == 'longform':
            data = pandas.DataFrame(data)
        return data
486
+
487
+ def get_weights (self ) -> pandas .DataFrame | list [dict ]:
488
+ return self ._get_weights ()
489
+
490
+ def get_layers (self ) -> list [str ]:
491
+ return self .registered_layers
492
+
493
+ def _get_layer (self , layer_name : str , module : torch .nn .Module ) -> torch .nn .Module :
494
+ for name in layer_name .split ('.' )[1 :]:
495
+ module = getattr (module , name )
496
+ return module
414
497
415
498
416
499
def activations_torch (model , X , fmt = 'longform' , plot = 'boxplot' ):
@@ -484,11 +567,11 @@ def numerical(model=None, hls_model=None, X=None, plot='boxplot'):
484
567
elif model_present :
485
568
if __tf_profiling_enabled__ and isinstance (model , keras .Model ):
486
569
data = weights_keras (model , fmt = 'summary' , plot = plot )
487
- elif __torch_profiling_enabled__ and isinstance (model , torch .nn .Sequential ):
570
+ elif __torch_profiling_enabled__ and isinstance (model , torch .nn .Module ):
488
571
data = weights_torch (model , fmt = 'summary' , plot = plot )
489
572
490
573
if data is None :
491
- print ("Only keras, PyTorch (Sequential) and ModelGraph models " + "can currently be profiled" )
574
+ print ("Only keras, PyTorch and ModelGraph models " + "can currently be profiled" )
492
575
493
576
if hls_model_present and os .path .exists (tmp_output_dir ):
494
577
shutil .rmtree (tmp_output_dir )
0 commit comments