Skip to content

Commit fd594e0

Browse files
jicampos, pre-commit-ci[bot], and JanFSchulte
authored
Update Torch profiler (#1156)
* Updated the PyTorch weight profiler
* Fixed a type
* [pre-commit.ci] auto fixes from pre-commit hooks
* Updated a comparison to False
* Fixed the numerical condition for PyTorch models and updated type hints
* Created test_pytorch_profiler.py
* Updated layer processing and added batch-norm testing
* Removed a typo
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jan-Frederik Schulte <[email protected]>
1 parent fb07e9c commit fd594e0

File tree

2 files changed

+184
-16
lines changed

2 files changed

+184
-16
lines changed

hls4ml/model/profiling.py

Lines changed: 99 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -381,15 +381,87 @@ def activations_keras(model, X, fmt='longform', plot='boxplot'):
381381

382382

383383
def weights_torch(model, fmt='longform', plot='boxplot'):
    """Profile the weight distributions of a PyTorch model.

    Delegates to :class:`WeightsTorch`, which discovers every parameterized
    layer (including layers nested inside Sequential/ModuleList containers)
    and collects their weights in the requested format.

    fmt: 'longform' (pandas DataFrame) or 'summary' (list of dicts).
    plot: summary statistic style forwarded to the profiler.
    """
    profiler = WeightsTorch(model, fmt, plot)
    return profiler.get_weights()
386+
387+
388+
def _torch_batchnorm(layer):
389+
weights = list(layer.parameters())
390+
epsilon = layer.eps
391+
392+
gamma = weights[0]
393+
beta = weights[1]
394+
if layer.track_running_stats:
395+
mean = layer.running_mean
396+
var = layer.running_var
397+
else:
398+
mean = torch.tensor(np.ones(20))
399+
var = torch.tensor(np.zeros(20))
400+
401+
scale = gamma / np.sqrt(var + epsilon)
402+
bias = beta - gamma * mean / np.sqrt(var + epsilon)
403+
404+
return [scale, bias], ['s', 'b']
405+
406+
407+
def _torch_layer(layer):
408+
return list(layer.parameters()), ['w', 'b']
409+
410+
411+
def _torch_rnn(layer):
412+
return list(layer.parameters()), ['w_ih_l0', 'w_hh_l0', 'b_ih_l0', 'b_hh_l0']
413+
414+
415+
# Dispatch table mapping a layer's class name to the function that extracts
# its (weights, suffix-labels) pair. Unknown layer types fall back to the
# generic _torch_layer extractor via the defaultdict factory.
torch_process_layer_map = defaultdict(lambda: _torch_layer)
torch_process_layer_map.update(
    {
        'BatchNorm1d': _torch_batchnorm,
        'BatchNorm2d': _torch_batchnorm,
        'RNN': _torch_rnn,
        'LSTM': _torch_rnn,
        'GRU': _torch_rnn,
    }
)
425+
426+
427+
class WeightsTorch:
428+
def __init__(self, model: torch.nn.Module, fmt: str = 'longform', plot: str = 'boxplot') -> None:
429+
self.model = model
430+
self.fmt = fmt
431+
self.plot = plot
432+
self.registered_layers = list()
433+
self._find_layers(self.model, self.model.__class__.__name__)
434+
435+
def _find_layers(self, model, module_name):
436+
for name, module in model.named_children():
437+
if isinstance(module, (torch.nn.Sequential, torch.nn.ModuleList)):
438+
self._find_layers(module, module_name + "." + name)
439+
elif isinstance(module, (torch.nn.Module)) and self._is_parameterized(module):
440+
if len(list(module.named_children())) != 0:
441+
# custom nn.Module, continue search
442+
self._find_layers(module, module_name + "." + name)
443+
else:
444+
self._register_layer(module_name + "." + name)
445+
446+
def _is_registered(self, name: str) -> bool:
447+
return name in self.registered_layers
448+
449+
def _register_layer(self, name: str) -> None:
450+
if self._is_registered(name) is False:
451+
self.registered_layers.append(name)
452+
453+
def _is_parameterized(self, module: torch.nn.Module) -> bool:
454+
return any(p.requires_grad for p in module.parameters())
455+
456+
def _get_weights(self) -> pandas.DataFrame | list[dict]:
457+
if self.fmt == 'longform':
458+
data = {'x': [], 'layer': [], 'weight': []}
459+
elif self.fmt == 'summary':
460+
data = []
461+
for layer_name in self.registered_layers:
462+
layer = self._get_layer(layer_name, self.model)
391463
name = layer.__class__.__name__
392-
weights = list(layer.parameters())
464+
weights, suffix = torch_process_layer_map[layer.__class__.__name__](layer)
393465
for i, w in enumerate(weights):
394466
label = f'{name}/{suffix[i]}'
395467
w = weights[i].detach().numpy()
@@ -399,18 +471,29 @@ def weights_torch(model, fmt='longform', plot='boxplot'):
399471
if n == 0:
400472
print(f'Weights for {name} are only zeros, ignoring.')
401473
break
402-
if fmt == 'longform':
474+
if self.fmt == 'longform':
403475
data['x'].extend(w.tolist())
404476
data['layer'].extend([name] * n)
405477
data['weight'].extend([label] * n)
406-
elif fmt == 'summary':
407-
data.append(array_to_summary(w, fmt=plot))
478+
elif self.fmt == 'summary':
479+
data.append(array_to_summary(w, fmt=self.plot))
408480
data[-1]['layer'] = name
409481
data[-1]['weight'] = label
410482

411-
if fmt == 'longform':
412-
data = pandas.DataFrame(data)
413-
return data
483+
if self.fmt == 'longform':
484+
data = pandas.DataFrame(data)
485+
return data
486+
487+
    def get_weights(self) -> pandas.DataFrame | list[dict]:
        """Return the weight-profile data for the registered layers.

        Shape follows the ``fmt`` chosen at construction: a pandas DataFrame
        for 'longform', a list of per-array summary dicts for 'summary'.
        """
        return self._get_weights()
489+
490+
    def get_layers(self) -> list[str]:
        """Return the dotted names of the registered (parameterized) layers.

        Note: this returns the internal list itself, not a copy.
        """
        return self.registered_layers
492+
493+
def _get_layer(self, layer_name: str, module: torch.nn.Module) -> torch.nn.Module:
494+
for name in layer_name.split('.')[1:]:
495+
module = getattr(module, name)
496+
return module
414497

415498

416499
def activations_torch(model, X, fmt='longform', plot='boxplot'):
@@ -484,11 +567,11 @@ def numerical(model=None, hls_model=None, X=None, plot='boxplot'):
484567
elif model_present:
485568
if __tf_profiling_enabled__ and isinstance(model, keras.Model):
486569
data = weights_keras(model, fmt='summary', plot=plot)
487-
elif __torch_profiling_enabled__ and isinstance(model, torch.nn.Sequential):
570+
elif __torch_profiling_enabled__ and isinstance(model, torch.nn.Module):
488571
data = weights_torch(model, fmt='summary', plot=plot)
489572

490573
if data is None:
491-
print("Only keras, PyTorch (Sequential) and ModelGraph models " + "can currently be profiled")
574+
print("Only keras, PyTorch and ModelGraph models " + "can currently be profiled")
492575

493576
if hls_model_present and os.path.exists(tmp_output_dir):
494577
shutil.rmtree(tmp_output_dir)

test/pytest/test_pytorch_profiler.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import pytest
2+
3+
import hls4ml
4+
5+
try:
6+
import torch
7+
import torch.nn as nn
8+
9+
__torch_profiling_enabled__ = True
10+
except ImportError:
11+
__torch_profiling_enabled__ = False
12+
13+
14+
class SubClassModel(torch.nn.Module):
    """Model storing each layer as a direct attribute: layer_0, layer_1, ..."""

    def __init__(self, layers) -> None:
        super().__init__()
        for index, module in enumerate(layers):
            setattr(self, f'layer_{index}', module)
19+
20+
21+
class ModuleListModel(torch.nn.Module):
    """Model wrapping its layers in a torch.nn.ModuleList."""

    def __init__(self, layers) -> None:
        super().__init__()
        wrapped = torch.nn.ModuleList(layers)
        self.layer = wrapped
25+
26+
27+
class NestedSequentialModel(torch.nn.Module):
    """Model nesting its layers inside an inner torch.nn.Sequential."""

    def __init__(self, layers) -> None:
        super().__init__()
        inner = torch.nn.Sequential(*layers)
        self.model = inner
31+
32+
33+
def count_bars_in_figure(fig):
    """Total number of patch artists (histogram/box bars) across all axes of *fig*."""
    return sum(len(ax.patches) for ax in fig.get_axes())
38+
39+
40+
# Reusable parameter list: (expected_bar_count, layers) pairs.
# expected_bar_count is presumably the number of profiled weight arrays the
# figure should contain (Linear/Conv contribute 2, recurrent layers 4;
# the BatchNorm case expects 3 — its all-zero bias array looks dropped by
# the profiler — TODO confirm against weights_torch).
test_layers = [
    (4, [nn.Linear(10, 20), nn.Linear(20, 5)]),
    (3, [nn.Linear(10, 20), nn.BatchNorm1d(20)]),
    (6, [nn.Linear(10, 20), nn.Linear(20, 5), nn.Conv1d(3, 16, kernel_size=3)]),
    (6, [nn.Linear(15, 30), nn.Linear(30, 15), nn.Conv2d(1, 32, kernel_size=3)]),
    (6, [nn.RNN(64, 128), nn.Linear(128, 10)]),
    (6, [nn.LSTM(64, 128), nn.Linear(128, 10)]),
    (6, [nn.GRU(64, 128), nn.Linear(128, 10)]),
]
50+
51+
52+
@pytest.mark.parametrize("layers", test_layers)
def test_sequential_model(layers):
    """Profile a plain nn.Sequential and check the number of profiled arrays."""
    if not __torch_profiling_enabled__:
        # Skip loudly instead of silently passing when torch is missing:
        # the original `if enabled:` guard reported a misleading green PASS.
        pytest.skip("PyTorch is not available")
    param_count, modules = layers
    model = torch.nn.Sequential(*modules)
    wp, _, _, _ = hls4ml.model.profiling.numerical(model)
    assert count_bars_in_figure(wp) == param_count
59+
60+
61+
@pytest.mark.parametrize("layers", test_layers)
def test_subclass_model(layers):
    """Profile a subclassed nn.Module with attribute-stored layers."""
    if not __torch_profiling_enabled__:
        # Skip loudly instead of silently passing when torch is missing.
        pytest.skip("PyTorch is not available")
    param_count, modules = layers
    model = SubClassModel(modules)
    wp, _, _, _ = hls4ml.model.profiling.numerical(model)
    assert count_bars_in_figure(wp) == param_count
68+
69+
70+
@pytest.mark.parametrize("layers", test_layers)
def test_modulelist_model(layers):
    """Profile a model whose layers live inside a torch.nn.ModuleList."""
    if not __torch_profiling_enabled__:
        # Skip loudly instead of silently passing when torch is missing.
        pytest.skip("PyTorch is not available")
    param_count, modules = layers
    model = ModuleListModel(modules)
    wp, _, _, _ = hls4ml.model.profiling.numerical(model)
    assert count_bars_in_figure(wp) == param_count
77+
78+
79+
@pytest.mark.parametrize("layers", test_layers)
def test_nested_model(layers):
    """Profile a model with an inner nn.Sequential nested in a custom Module."""
    if not __torch_profiling_enabled__:
        # Skip loudly instead of silently passing when torch is missing.
        pytest.skip("PyTorch is not available")
    param_count, modules = layers
    model = NestedSequentialModel(modules)
    wp, _, _, _ = hls4ml.model.profiling.numerical(model)
    assert count_bars_in_figure(wp) == param_count

0 commit comments

Comments
 (0)