
Commit fe5bdb2

bottler authored and facebook-github-bot committed
Different learning rates for different parts of the model
Summary: Adds the ability to use different learning rates for different parts of the model. The trainable parts of Implicitron get a new member, `param_groups`: a dictionary whose keys are names of individual parameters or module members, and whose values are the parameter group that the parameter/member is assigned to. The "self" key denotes the parameter group at the module level. No key, including "self", has to be defined. By default all parameters are put into the "default" parameter group and use the learning rate defined in the optimizer. This can be overridden at the:
- module level, with the "self" key: all of the module's parameters and its child modules' parameters are put into that parameter group
- member level, which is equivalent to that member having `param_groups` with key "self" and value equal to that parameter group. This is useful for members that do not have `param_groups`, for example torch.nn.Linear.
- parameter level: the parameter with the same name as the key is put into that parameter group.
In the optimizer factory, parameters and their learning rates are gathered recursively (a short sketch follows below).

Reviewed By: shapovalov

Differential Revision: D40145802

fbshipit-source-id: 631c02b8d79ee1c0eb4c31e6e42dbd3d2882078a
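
A minimal sketch (not part of the commit) of how a model can annotate its parts with `param_groups`; the module names and the group names "color" and "geometry" are invented for illustration:

    import torch

    class Decoder(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.mlp = torch.nn.Linear(4, 4)
            # member level: every parameter of `self.mlp` is assigned to the
            # "color" group, even though torch.nn.Linear has no `param_groups`.
            self.param_groups = {"mlp": "color"}

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.decoder = Decoder()
            self.density = torch.nn.Parameter(torch.zeros(8))
            # parameter level: only the parameter named "density" goes to
            # "geometry"; everything not mentioned stays in "default".
            self.param_groups = {"density": "geometry"}

    model = Model()

With the optimizer factory below, a config such as group_learning_rates={"color": 1e-3, "geometry": 1e-2} would then train those groups at their own rates, while everything left in "default" keeps the optimizer's `lr`.
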
1 parent a819ecb commit fe5bdb2

File tree

6 files changed (+293, -5 lines)

projects/implicitron_trainer/impl/optimizer_factory.py

Lines changed: 93 additions & 3 deletions
@@ -7,7 +7,9 @@
 import inspect
 import logging
 import os
-from typing import Any, Dict, Optional, Tuple
+from collections import defaultdict
+from dataclasses import field
+from typing import Any, Dict, List, Optional, Tuple

 import torch.optim

@@ -64,6 +66,12 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
         weight_decay: The optimizer weight_decay (L2 penalty on model weights).
         foreach: Whether to use new "foreach" implementation of optimizer where
             available (e.g. requires PyTorch 1.12.0 for Adam)
+        group_learning_rates: Parameters or modules can be assigned to parameter
+            groups. This dictionary has names of those parameter groups as keys
+            and learning rates as values. All parameter group names have to be
+            defined in this dictionary. Parameters which do not have a predefined
+            parameter group are put into the "default" parameter group, which has
+            `lr` as its learning rate.
     """

     betas: Tuple[float, ...] = (0.9, 0.999)
@@ -78,6 +86,7 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
     linear_exponential_lr_milestone: int = 200
     linear_exponential_start_gamma: float = 0.1
     foreach: Optional[bool] = True
+    group_learning_rates: Dict[str, float] = field(default_factory=lambda: {})

     def __post_init__(self):
         run_auto_creation(self)
@@ -115,8 +124,10 @@ def __call__(
             # pyre-ignore[29]
             p_groups = model._get_param_groups(self.lr, wd=self.weight_decay)
         else:
-            allprm = [prm for prm in model.parameters() if prm.requires_grad]
-            p_groups = [{"params": allprm, "lr": self.lr}]
+            p_groups = [
+                {"params": params, "lr": self._get_group_learning_rate(group)}
+                for group, params in self._get_param_groups(model).items()
+            ]

         # Initialize the optimizer
         optimizer_kwargs: Dict[str, Any] = {
@@ -233,3 +244,82 @@ def _get_optimizer_state(
         else:
             raise FileNotFoundError(f"Optimizer state {opt_path} does not exist.")
         return optimizer_state
+
+    def _get_param_groups(
+        self, module: torch.nn.Module
+    ) -> Dict[str, List[torch.nn.Parameter]]:
+        """
+        Recursively visits all the modules inside the `module` and sorts all the
+        parameters into parameter groups.
+
+        Uses the `param_groups` dictionary member, where keys are names of individual
+        parameters or module members and values are the names of the parameter groups
+        for those parameters or members. The "self" key is used to denote the parameter
+        group at the module level. Possible keys, including the "self" key, do not have
+        to be defined. By default all parameters have the learning rate defined in the
+        optimizer. This can be overridden by setting the parameter group in the
+        `param_groups` member of a specific module; it can be overridden at the:
+            - module level with the "self" key: all the parameters and child
+                modules' parameters will inherit it
+            - member level, which is the same as if the `param_groups` in that
+                member had key="self" and value equal to that parameter group.
+                This is useful if members do not have `param_groups`, for
+                example torch.nn.Linear.
+            - parameter level: only the parameter with the same name as the key
+                will have it.
+
+        Args:
+            module: module from which to extract the parameters and their parameter
+                groups
+        Returns:
+            dictionary with parameter groups as keys and lists of parameters as values
+        """
+
+        param_groups = defaultdict(list)
+
+        def traverse(module, default_group):
+            # If the key "self" is defined in param_groups then change the default
+            # param group for all parameters and children in the module.
+            if hasattr(module, "param_groups") and "self" in module.param_groups:
+                default_group = module.param_groups["self"]
+
+            # Collect all the parameters that are directly inside the `module`;
+            # they will be in the default param group if they don't have a
+            # defined group.
+            for name, param in module.named_parameters(recurse=False):
+                if param.requires_grad:
+                    if hasattr(module, "param_groups") and name in module.param_groups:
+                        param_groups[module.param_groups[name]].append(param)
+                    else:
+                        param_groups[default_group].append(param)
+
+            # If a child has a defined default param group then use it, else pass
+            # down our own default.
+            for child_name, child in module.named_children():
+                if (
+                    hasattr(module, "param_groups")
+                    and child_name in module.param_groups
+                ):
+                    traverse(child, module.param_groups[child_name])
+                else:
+                    traverse(child, default_group)
+
+        traverse(module, "default")
+        return param_groups
+
+    def _get_group_learning_rate(self, group_name: str) -> float:
+        """
+        Wraps the `group_learning_rates` dictionary, providing errors, and returns
+        `self.lr` for the "default" group_name.
+
+        Args:
+            group_name: a string representing the name of the group
+        Returns:
+            learning rate for a specific group
+        """
+        if group_name == "default":
+            return self.lr
+        lr = self.group_learning_rates.get(group_name, None)
+        if lr is None:
+            raise ValueError(f"no learning rate given for group {group_name}")
+        return lr
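
To make the effect of the two new helpers concrete, here is a hedged, self-contained sketch of what the `__call__` branch above builds from the gathered groups; the parameters, the group name "color", and the rates are invented for illustration:

    import torch

    # Hypothetical parameters standing in for what _get_param_groups(model) collects.
    p_bias = torch.nn.Parameter(torch.zeros(3))
    p_w1 = torch.nn.Parameter(torch.zeros(3, 3))
    p_w2 = torch.nn.Parameter(torch.zeros(3, 3))

    lr = 1e-4                               # optimizer-level learning rate (self.lr)
    group_learning_rates = {"color": 1e-3}  # per-group overrides
    gathered = {"default": [p_bias], "color": [p_w1, p_w2]}

    # Mirrors the list comprehension added to __call__: "default" falls back to lr;
    # any other group must have an entry in group_learning_rates, otherwise
    # _get_group_learning_rate raises a ValueError.
    p_groups = [
        {"params": params, "lr": lr if group == "default" else group_learning_rates[group]}
        for group, params in gathered.items()
    ]
    optimizer = torch.optim.Adam(p_groups, lr=lr)
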

projects/implicitron_trainer/tests/experiment.yaml

Lines changed: 1 addition & 0 deletions
@@ -409,6 +409,7 @@ optimizer_factory_ImplicitronOptimizerFactory_args:
   linear_exponential_lr_milestone: 200
   linear_exponential_start_gamma: 0.1
   foreach: true
+  group_learning_rates: {}
 training_loop_ImplicitronTrainingLoop_args:
   evaluator_class_type: ImplicitronEvaluator
   evaluator_ImplicitronEvaluator_args:
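
The same default can also be overridden programmatically when building the factory directly; a sketch using the config helpers that the new test below relies on (the import path, which assumes the repository root is on sys.path, and the group name "color" are assumptions):

    from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
    from projects.implicitron_trainer.impl.optimizer_factory import (
        ImplicitronOptimizerFactory,
    )

    expand_args_fields(ImplicitronOptimizerFactory)
    cfg = get_default_args(ImplicitronOptimizerFactory)
    cfg.lr = 1e-4
    # Every group other than "default" needs an entry here; "color" is illustrative.
    cfg.group_learning_rates = {"color": 1e-3}
    factory = ImplicitronOptimizerFactory(cfg)
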
Lines changed: 162 additions & 0 deletions
@@ -0,0 +1,162 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import unittest

import torch
from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args

from ..impl.optimizer_factory import ImplicitronOptimizerFactory

internal = os.environ.get("FB_TEST", False)


class TestOptimizerFactory(unittest.TestCase):
    def setUp(self) -> None:
        torch.manual_seed(42)
        expand_args_fields(ImplicitronOptimizerFactory)

    def _get_param_groups(self, model):
        default_cfg = get_default_args(ImplicitronOptimizerFactory)
        factory = ImplicitronOptimizerFactory(default_cfg)
        return factory._get_param_groups(model)

    def _assert_allin(self, a, param_groups, key):
        with self.subTest(f"Testing key {key}"):
            b = param_groups[key]
            for el in a:
                if el not in b:
                    raise ValueError(
                        f"Element {el}\n\n from:\n\n {a}\n\n not in:\n\n {b}\n\n."
                        + f" Full param groups = \n\n{param_groups}"
                    )
            for el in b:
                if el not in a:
                    raise ValueError(
                        f"Element {el}\n\n from:\n\n {b}\n\n not in:\n\n {a}\n\n."
                        + f" Full param groups = \n\n{param_groups}"
                    )

    def test_default_param_group_assignment(self):
        pa, pb, pc = [torch.nn.Parameter(data=torch.tensor(i * 1.0)) for i in range(3)]
        na, nb = Node(params=[pa]), Node(params=[pb])
        root = Node(children=[na, nb], params=[pc])
        param_groups = self._get_param_groups(root)
        self._assert_allin([pa, pb, pc], param_groups, "default")

    def test_member_overrides_default_param_group_assignment(self):
        pa, pb, pc = [torch.nn.Parameter(data=torch.tensor(i * 1.0)) for i in range(3)]
        na, nb = Node(params=[pa]), Node(params=[pb])
        root = Node(children=[na, nb], params=[pc], param_groups={"m1": "pb"})
        param_groups = self._get_param_groups(root)
        self._assert_allin([pa, pc], param_groups, "default")
        self._assert_allin([pb], param_groups, "pb")

    def test_self_overrides_member_param_group_assignment(self):
        pa, pb, pc = [torch.nn.Parameter(data=torch.tensor(i * 1.0)) for i in range(3)]
        na, nb = Node(params=[pa]), Node(params=[pb], param_groups={"self": "pb_self"})
        root = Node(children=[na, nb], params=[pc], param_groups={"m1": "pb_member"})
        param_groups = self._get_param_groups(root)
        self._assert_allin([pa, pc], param_groups, "default")
        self._assert_allin([pb], param_groups, "pb_self")
        assert len(param_groups["pb_member"]) == 0, param_groups

    def test_param_overrides_self_param_group_assignment(self):
        pa, pb, pc = [torch.nn.Parameter(data=torch.tensor(i * 1.0)) for i in range(3)]
        na, nb = Node(params=[pa]), Node(
            params=[pb], param_groups={"self": "pb_self", "p1": "pb_param"}
        )
        root = Node(children=[na, nb], params=[pc], param_groups={"m1": "pb_member"})
        param_groups = self._get_param_groups(root)
        self._assert_allin([pa, pc], param_groups, "default")
        self._assert_allin([pb], param_groups, "pb_self")
        assert len(param_groups["pb_member"]) == 0, param_groups

    def test_no_param_groups_defined(self):
        pa, pb, pc = [torch.nn.Parameter(data=torch.tensor(i * 1.0)) for i in range(3)]
        na, nb = Node(params=[pa]), Node(params=[pb])
        root = Node(children=[na, nb], params=[pc])
        param_groups = self._get_param_groups(root)
        self._assert_allin([pa, pb, pc], param_groups, "default")

    def test_tree_param_groups_defined(self):
        """
        Test generic tree assignment.

                 A0
                 |---------------------------
                 |             |            |
                 Bb            M            J-
                 |-----                     |-------
                 |    |                     |      |
                 C    Ddg                   K      Ll
                      |---------------
                      |    |    |    |
                      E4   Ff   G    H-

        All nodes have one parameter. The character next to the capital
        letter means the node has added something to its `param_groups`:
            - a small letter equal to the capital means "self" is set to that letter
            - a small letter different from the capital means that member is set
                (the one that is named like that)
            - a number means the parameter's param group is set to it
            - "-" means the node does not have a `param_groups` member
        """
        p = [torch.nn.Parameter(data=torch.tensor(i * 1.0)) for i in range(12)]
        L = Node(params=[p[11]], param_groups={"self": "l"})
        K = Node(params=[p[10]], param_groups={})
        J = Node(params=[p[9]], param_groups=None, children=[K, L])
        M = Node(params=[p[8]], param_groups={})

        E = Node(params=[p[4]], param_groups={"p0": "4"})
        F = Node(params=[p[5]], param_groups={"self": "f"})
        G = Node(params=[p[6]], param_groups={})
        H = Node(params=[p[7]], param_groups=None)

        D = Node(
            params=[p[3]], param_groups={"self": "d", "m2": "g"}, children=[E, F, G, H]
        )
        C = Node(params=[p[2]], param_groups={})

        B = Node(params=[p[1]], param_groups={"self": "b"}, children=[C, D])

        A = Node(params=[p[0]], param_groups={"p0": "0"}, children=[B, M, J])

        param_groups = self._get_param_groups(A)

        # if parts of the group belong to two different categories the assert is repeated
        # parameter level
        self._assert_allin([p[0]], param_groups, "0")
        self._assert_allin([p[4]], param_groups, "4")
        # self level
        self._assert_allin([p[5]], param_groups, "f")
        self._assert_allin([p[11]], param_groups, "l")
        self._assert_allin([p[2], p[1]], param_groups, "b")
        self._assert_allin([p[7], p[3]], param_groups, "d")
        # member level
        self._assert_allin([p[6]], param_groups, "g")
        # inherit level
        self._assert_allin([p[7], p[3]], param_groups, "d")
        self._assert_allin([p[2], p[1]], param_groups, "b")
        # default level
        self._assert_allin([p[8], p[9], p[10]], param_groups, "default")


class Node(torch.nn.Module):
    def __init__(self, children=(), params=(), param_groups=None):
        super().__init__()
        for i, child in enumerate(children):
            self.add_module("m" + str(i), child)
        for i, param in enumerate(params):
            setattr(self, "p" + str(i), param)
        if param_groups is not None:
            self.param_groups = param_groups

    def __str__(self):
        return (
            "modules:\n" + str(self._modules) + "\nparameters\n" + str(self._parameters)
        )
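
As a brief recap of the precedence these tests exercise, a parameter-level key beats a module's own "self" key, which in turn beats an assignment made by the parent; a small sketch reusing the `Node` helper defined above (group names invented):

    import torch

    p_inner = torch.nn.Parameter(torch.tensor(1.0))
    # Node names its parameters "p0", "p1", ... so the parameter-level key is "p0".
    child = Node(params=[p_inner], param_groups={"self": "child_self", "p0": "child_param"})
    root = Node(children=[child], param_groups={"m0": "parent_member"})
    # ImplicitronOptimizerFactory._get_param_groups(root) puts p_inner into
    # "child_param": the "p0" entry overrides "self", which overrides the
    # parent's "m0" assignment.
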

pytorch3d/implicitron/models/implicit_function/decoding_functions.py

Lines changed: 21 additions & 1 deletion
@@ -13,9 +13,10 @@
 """

 import logging
+from dataclasses import field

 from enum import Enum
-from typing import Optional, Tuple
+from typing import Dict, Optional, Tuple

 import torch

@@ -42,8 +43,27 @@ class DecoderFunctionBase(ReplaceableBase, torch.nn.Module):
     """
     Decoding function is a torch.nn.Module which takes the embedding of a location in
     space and transforms it into the required quantity (for example density and color).
+
+    Members:
+        param_groups: dictionary where keys are names of individual parameters
+            or module members and values are the parameter group where the
+            parameter/member will be sorted to. The "self" key is used to denote the
+            parameter group at the module level. Possible keys, including the "self"
+            key, do not have to be defined. By default all parameters are put into the
+            "default" parameter group and have the learning rate defined in the
+            optimizer; it can be overridden at the:
+                - module level with the "self" key: all the parameters and child
+                    modules' parameters will be put to that parameter group
+                - member level, which is the same as if the `param_groups` in that
+                    member had key="self" and value equal to that parameter group.
+                    This is useful if members do not have `param_groups`, for
+                    example torch.nn.Linear.
+                - parameter level: the parameter with the same name as the key
+                    will be put to that parameter group.
     """

+    param_groups: Dict[str, str] = field(default_factory=lambda: {})
+
     def __post_init__(self):
         super().__init__()

pytorch3d/implicitron/models/implicit_function/voxel_grid.py

Lines changed: 16 additions & 0 deletions
@@ -808,6 +808,21 @@ class VoxelGridModule(Configurable, torch.nn.Module):
             with mean=init_mean and std=init_std. Default 0.
         hold_voxel_grid_as_parameters: if True components of the underlying voxel grids
             will be saved as parameters and therefore be trainable. Default True.
+        param_groups: dictionary where keys are names of individual parameters
+            or module members and values are the parameter group where the
+            parameter/member will be sorted to. The "self" key is used to denote the
+            parameter group at the module level. Possible keys, including the "self"
+            key, do not have to be defined. By default all parameters are put into the
+            "default" parameter group and have the learning rate defined in the
+            optimizer; it can be overridden at the:
+                - module level with the "self" key: all the parameters and child
+                    modules' parameters will be put to that parameter group
+                - member level, which is the same as if the `param_groups` in that
+                    member had key="self" and value equal to that parameter group.
+                    This is useful if members do not have `param_groups`, for
+                    example torch.nn.Linear.
+                - parameter level: the parameter with the same name as the key
+                    will be put to that parameter group.
     """

     voxel_grid_class_type: str = "FullResolutionVoxelGrid"
@@ -820,6 +835,7 @@ class VoxelGridModule(Configurable, torch.nn.Module):
     init_mean: float = 0

     hold_voxel_grid_as_parameters: bool = True
+    param_groups: Dict[str, str] = field(default_factory=lambda: {})

     def __post_init__(self):
         super().__init__()

tests/implicitron/test_voxel_grids.py

Lines changed: 0 additions & 1 deletion
@@ -19,7 +19,6 @@
 from pytorch3d.implicitron.models.implicit_function.voxel_grid import (
     CPFactorizedVoxelGrid,
     FullResolutionVoxelGrid,
-    FullResolutionVoxelGridValues,
     VMFactorizedVoxelGrid,
     VoxelGridModule,
 )
