Commit ad86cb6

Remove dependence of profiling tools on torch (#1233)
* move WeightsTorch class to utils to avoid dependence of profiling tools on torch
* [pre-commit.ci] auto fixes from pre-commit hooks
* add missing import

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 88f1c5f commit ad86cb6
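
The commit splits the torch-dependent profiling code out of hls4ml/model/profiling.py: the WeightsTorch machinery moves to hls4ml/utils/profiling_utils.py, and weights_torch() now imports it only when called (see the diff below). A minimal sketch of that deferred-import pattern, with illustrative names and numpy standing in for the heavy dependency:

def summarize(data):
    # No top-level import: the dependency is pulled in only when the
    # function actually runs, so importing the enclosing module stays
    # cheap and works even on systems where the dependency is absent.
    import numpy as np

    arr = np.asarray(data)
    return {'min': float(arr.min()), 'max': float(arr.max()), 'mean': float(arr.mean())}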

File tree

2 files changed: +121 / -111 lines


hls4ml/model/profiling.py

Lines changed: 2 additions & 111 deletions
@@ -393,121 +393,12 @@ def activations_keras(model, X, fmt='longform', plot='boxplot'):
 
 
 def weights_torch(model, fmt='longform', plot='boxplot'):
+    from hls4ml.utils.profiling_utils import WeightsTorch
+
     wt = WeightsTorch(model, fmt, plot)
     return wt.get_weights()
 
 
-def _torch_batchnorm(layer):
-    weights = list(layer.parameters())
-    epsilon = layer.eps
-
-    gamma = weights[0]
-    beta = weights[1]
-    if layer.track_running_stats:
-        mean = layer.running_mean
-        var = layer.running_var
-    else:
-        mean = torch.tensor(np.ones(20))
-        var = torch.tensor(np.zeros(20))
-
-    scale = gamma / np.sqrt(var + epsilon)
-    bias = beta - gamma * mean / np.sqrt(var + epsilon)
-
-    return [scale, bias], ['s', 'b']
-
-
-def _torch_layer(layer):
-    return list(layer.parameters()), ['w', 'b']
-
-
-def _torch_rnn(layer):
-    return list(layer.parameters()), ['w_ih_l0', 'w_hh_l0', 'b_ih_l0', 'b_hh_l0']
-
-
-torch_process_layer_map = defaultdict(
-    lambda: _torch_layer,
-    {
-        'BatchNorm1d': _torch_batchnorm,
-        'BatchNorm2d': _torch_batchnorm,
-        'RNN': _torch_rnn,
-        'LSTM': _torch_rnn,
-        'GRU': _torch_rnn,
-    },
-)
-
-
-class WeightsTorch:
-    def __init__(self, model: torch.nn.Module, fmt: str = 'longform', plot: str = 'boxplot') -> None:
-        self.model = model
-        self.fmt = fmt
-        self.plot = plot
-        self.registered_layers = list()
-        self._find_layers(self.model, self.model.__class__.__name__)
-
-    def _find_layers(self, model, module_name):
-        for name, module in model.named_children():
-            if isinstance(module, (torch.nn.Sequential, torch.nn.ModuleList)):
-                self._find_layers(module, module_name + "." + name)
-            elif isinstance(module, (torch.nn.Module)) and self._is_parameterized(module):
-                if len(list(module.named_children())) != 0:
-                    # custom nn.Module, continue search
-                    self._find_layers(module, module_name + "." + name)
-                else:
-                    self._register_layer(module_name + "." + name)
-
-    def _is_registered(self, name: str) -> bool:
-        return name in self.registered_layers
-
-    def _register_layer(self, name: str) -> None:
-        if self._is_registered(name) is False:
-            self.registered_layers.append(name)
-
-    def _is_parameterized(self, module: torch.nn.Module) -> bool:
-        return any(p.requires_grad for p in module.parameters())
-
-    def _get_weights(self) -> pandas.DataFrame | list[dict]:
-        if self.fmt == 'longform':
-            data = {'x': [], 'layer': [], 'weight': []}
-        elif self.fmt == 'summary':
-            data = []
-        for layer_name in self.registered_layers:
-            layer = self._get_layer(layer_name, self.model)
-            name = layer.__class__.__name__
-            weights, suffix = torch_process_layer_map[layer.__class__.__name__](layer)
-            for i, w in enumerate(weights):
-                label = f'{name}/{suffix[i]}'
-                w = weights[i].detach().numpy()
-                w = w.flatten()
-                w = abs(w[w != 0])
-                n = len(w)
-                if n == 0:
-                    print(f'Weights for {name} are only zeros, ignoring.')
-                    break
-                if self.fmt == 'longform':
-                    data['x'].extend(w.tolist())
-                    data['layer'].extend([name] * n)
-                    data['weight'].extend([label] * n)
-                elif self.fmt == 'summary':
-                    data.append(array_to_summary(w, fmt=self.plot))
-                    data[-1]['layer'] = name
-                    data[-1]['weight'] = label
-
-        if self.fmt == 'longform':
-            data = pandas.DataFrame(data)
-        return data
-
-    def get_weights(self) -> pandas.DataFrame | list[dict]:
-        return self._get_weights()
-
-    def get_layers(self) -> list[str]:
-        return self.registered_layers
-
-    def _get_layer(self, layer_name: str, module: torch.nn.Module) -> torch.nn.Module:
-        for name in layer_name.split('.')[1:]:
-            module = getattr(module, name)
-        return module
-
-
 def activations_torch(model, X, fmt='longform', plot='boxplot'):
     X = torch.Tensor(X)
     if fmt == 'longform':
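
After this change the public entry point is unchanged. A hedged usage sketch (the model below is a stand-in, not from the commit):

import torch.nn as nn

from hls4ml.model.profiling import weights_torch

model = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 10))

# fmt='longform' returns a pandas DataFrame with 'x', 'layer' and 'weight'
# columns; fmt='summary' returns a list of per-tensor summary dicts instead.
df = weights_torch(model, fmt='longform', plot='boxplot')
print(df.head())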

hls4ml/utils/profiling_utils.py

Lines changed: 119 additions & 0 deletions
@@ -0,0 +1,119 @@
+from collections import defaultdict
+
+import numpy as np
+import pandas
+import torch
+
+from hls4ml.model.profiling import array_to_summary
+
+
+def _torch_batchnorm(layer):
+    weights = list(layer.parameters())
+    epsilon = layer.eps
+
+    gamma = weights[0]
+    beta = weights[1]
+    if layer.track_running_stats:
+        mean = layer.running_mean
+        var = layer.running_var
+    else:
+        mean = torch.tensor(np.ones(20))
+        var = torch.tensor(np.zeros(20))
+
+    scale = gamma / np.sqrt(var + epsilon)
+    bias = beta - gamma * mean / np.sqrt(var + epsilon)
+
+    return [scale, bias], ['s', 'b']
+
+
+def _torch_layer(layer):
+    return list(layer.parameters()), ['w', 'b']
+
+
+def _torch_rnn(layer):
+    return list(layer.parameters()), ['w_ih_l0', 'w_hh_l0', 'b_ih_l0', 'b_hh_l0']
+
+
+torch_process_layer_map = defaultdict(
+    lambda: _torch_layer,
+    {
+        'BatchNorm1d': _torch_batchnorm,
+        'BatchNorm2d': _torch_batchnorm,
+        'RNN': _torch_rnn,
+        'LSTM': _torch_rnn,
+        'GRU': _torch_rnn,
+    },
+)
+
+
+class WeightsTorch:
+
+    def __init__(self, model: torch.nn.Module, fmt: str = 'longform', plot: str = 'boxplot') -> None:
+        self.model = model
+        self.fmt = fmt
+        self.plot = plot
+        self.registered_layers = list()
+        self._find_layers(self.model, self.model.__class__.__name__)
+
+    def _find_layers(self, model, module_name):
+        for name, module in model.named_children():
+            if isinstance(module, (torch.nn.Sequential, torch.nn.ModuleList)):
+                self._find_layers(module, module_name + "." + name)
+            elif isinstance(module, (torch.nn.Module)) and self._is_parameterized(module):
+                if len(list(module.named_children())) != 0:
+                    # custom nn.Module, continue search
+                    self._find_layers(module, module_name + "." + name)
+                else:
+                    self._register_layer(module_name + "." + name)
+
+    def _is_registered(self, name: str) -> bool:
+        return name in self.registered_layers
+
+    def _register_layer(self, name: str) -> None:
+        if self._is_registered(name) is False:
+            self.registered_layers.append(name)
+
+    def _is_parameterized(self, module: torch.nn.Module) -> bool:
+        return any(p.requires_grad for p in module.parameters())
+
+    def _get_weights(self) -> pandas.DataFrame | list[dict]:
+        if self.fmt == 'longform':
+            data = {'x': [], 'layer': [], 'weight': []}
+        elif self.fmt == 'summary':
+            data = []
+        for layer_name in self.registered_layers:
+            layer = self._get_layer(layer_name, self.model)
+            name = layer.__class__.__name__
+            weights, suffix = torch_process_layer_map[layer.__class__.__name__](layer)
+            for i, w in enumerate(weights):
+                label = f'{name}/{suffix[i]}'
+                w = weights[i].detach().numpy()
+                w = w.flatten()
+                w = abs(w[w != 0])
+                n = len(w)
+                if n == 0:
+                    print(f'Weights for {name} are only zeros, ignoring.')
+                    break
+                if self.fmt == 'longform':
+                    data['x'].extend(w.tolist())
+                    data['layer'].extend([name] * n)
+                    data['weight'].extend([label] * n)
+                elif self.fmt == 'summary':
+                    data.append(array_to_summary(w, fmt=self.plot))
+                    data[-1]['layer'] = name
+                    data[-1]['weight'] = label
+
+        if self.fmt == 'longform':
+            data = pandas.DataFrame(data)
+        return data
+
+    def get_weights(self) -> pandas.DataFrame | list[dict]:
+        return self._get_weights()
+
+    def get_layers(self) -> list[str]:
+        return self.registered_layers
+
+    def _get_layer(self, layer_name: str, module: torch.nn.Module) -> torch.nn.Module:
+        for name in layer_name.split('.')[1:]:
+            module = getattr(module, name)
+        return module
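
Because torch_process_layer_map is a defaultdict, any layer class without an entry falls back to the generic _torch_layer handler. A hedged sketch of using the relocated class directly; registering a BatchNorm3d handler here is illustrative, not part of the commit:

import torch.nn as nn

from hls4ml.utils.profiling_utils import WeightsTorch, _torch_batchnorm, torch_process_layer_map

# Unmapped classes dispatch to _torch_layer via the defaultdict factory;
# an explicit handler can be added like any ordinary dict entry.
torch_process_layer_map['BatchNorm3d'] = _torch_batchnorm

model = nn.Sequential(nn.Linear(8, 4), nn.Linear(4, 2))
wt = WeightsTorch(model, fmt='summary', plot='boxplot')
print(wt.get_layers())   # dotted submodule paths, e.g. ['Sequential.0', 'Sequential.1']
print(wt.get_weights())  # list of summary dicts, one per nonzero weight tensor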
