1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
|
class AveragedModel(Module):
    r"""Keeps a running average of the weights of a model.

    At construction a deep copy of ``model`` is stored as ``self.module``;
    each call to :meth:`update_parameters` folds the current weights of a
    model into that copy, and :meth:`forward` runs the averaged copy.

    You can also use custom averaging functions with the ``avg_fn``
    parameter. If no averaging function is provided, the default is to
    compute an equally-weighted average of the weights.

    Args:
        model (torch.nn.Module): model whose weights are averaged. It is
            deep-copied, so later changes to ``model`` only reach the
            averaged copy through :meth:`update_parameters`.
        device (torch.device, optional): if given, the averaged copy and the
            ``n_averaged`` counter are placed on this device.
        avg_fn (callable, optional): called as
            ``avg_fn(averaged_param, model_param, num_averaged)`` and must
            return the new averaged value. It is only used from the second
            update onwards; the first update copies the model's values
            directly. Defaults to the equal-weight running mean
            ``avg + (new - avg) / (num_averaged + 1)``.
        use_buffers (bool): if ``True``, buffers are averaged together with
            parameters. If ``False`` (default), only parameters are averaged
            and the buffers of the averaged copy are left untouched by
            :meth:`update_parameters`.
    """

    def __init__(self, model, device=None, avg_fn=None, use_buffers=False):
        super(AveragedModel, self).__init__()
        self.module = deepcopy(model)
        if device is not None:
            self.module = self.module.to(device)
        # Number of updates folded in so far; a buffer so it is saved in
        # state_dict and moves with the module.
        self.register_buffer('n_averaged',
                             torch.tensor(0, dtype=torch.long, device=device))
        if avg_fn is None:
            def avg_fn(averaged_model_parameter, model_parameter, num_averaged):
                # Incremental form of the arithmetic mean over all updates.
                return averaged_model_parameter + \
                    (model_parameter - averaged_model_parameter) / (num_averaged + 1)
        self.avg_fn = avg_fn
        self.use_buffers = use_buffers

    def forward(self, *args, **kwargs):
        """Run the averaged copy of the model on the given inputs."""
        return self.module(*args, **kwargs)

    def update_parameters(self, model):
        """Fold the current weights of ``model`` into the running average.

        Tensors are paired positionally (parameters, then buffers when
        ``use_buffers`` is set), so ``model`` should have the same structure
        as the model this wrapper was built from.

        Args:
            model (torch.nn.Module): model whose current weights are added
                to the average.
        """
        if self.use_buffers:
            averaged_tensors = itertools.chain(self.module.parameters(),
                                               self.module.buffers())
            current_tensors = itertools.chain(model.parameters(), model.buffers())
        else:
            averaged_tensors = self.module.parameters()
            current_tensors = model.parameters()
        # The counter does not change inside the loop; read it once instead of
        # evaluating a tensor comparison (and a possible device sync) per tensor.
        first_update = bool(self.n_averaged == 0)
        for p_avg, p_model in zip(averaged_tensors, current_tensors):
            device = p_avg.device
            p_model_ = p_model.detach().to(device)
            if first_update:
                # Nothing to average with yet: start from the model's values.
                p_avg.detach().copy_(p_model_)
            else:
                p_avg.detach().copy_(self.avg_fn(p_avg.detach(), p_model_,
                                                 self.n_averaged.to(device)))
        self.n_averaged += 1
@torch.no_grad()
def update_bn(loader, model, device=None):
    r"""Updates BatchNorm running_mean, running_var buffers in the model.

    It performs one pass over data in `loader` to estimate the activation
    statistics for BatchNorm layers in the model. Only BatchNorm layers that
    track running statistics are re-estimated; layers built with
    ``track_running_stats=False`` have no running buffers and are skipped.

    Args:
        loader (torch.utils.data.DataLoader): dataset loader to compute the
            activation statistics on. Each data batch should be either a
            tensor, or a list/tuple whose first element is a tensor
            containing data.
        model (torch.nn.Module): model for which we seek to update BatchNorm
            statistics.
        device (torch.device, optional): If set, data will be transferred to
            :attr:`device` before being passed into :attr:`model`.

    Example:
        >>> loader, model = ...
        >>> torch.optim.swa_utils.update_bn(loader, model)

    .. note::
        The `update_bn` utility assumes that each data batch in :attr:`loader`
        is either a tensor or a list or tuple of tensors; in the latter case it
        is assumed that :meth:`model.forward()` should be called on the first
        element of the list or tuple corresponding to the data batch.

    The model's train/eval mode and every layer's ``momentum`` are restored
    afterwards, even if iterating the loader or the forward pass raises.
    """
    momenta = {}
    for module in model.modules():
        if (isinstance(module, torch.nn.modules.batchnorm._BatchNorm)
                and module.track_running_stats):
            # In place: running_mean -> 0, running_var -> 1,
            # num_batches_tracked -> 0, reusing the existing buffer tensors.
            module.reset_running_stats()
            momenta[module] = module.momentum

    if not momenta:
        return

    was_training = model.training
    model.train()
    # momentum=None switches BatchNorm to a cumulative (equal-weight) average
    # over all batches instead of an exponential moving average.
    for module in momenta:
        module.momentum = None
    try:
        for batch in loader:
            if isinstance(batch, (list, tuple)):
                batch = batch[0]
            if device is not None:
                batch = batch.to(device)
            model(batch)
    finally:
        # Restore caller-visible state whether or not the pass succeeded.
        for bn_module, momentum in momenta.items():
            bn_module.momentum = momentum
        model.train(was_training)
|