mirror of https://github.com/AntoineHX/smart_augmentation.git (synced 2025-06-29 00:15:25 +02:00)
Finish script example
This commit is contained in:
parent
96ed9fe2ae
commit
5cd50ca9f3
4 changed files with 198 additions and 103 deletions
@@ -919,25 +919,51 @@ class Augmented_model(nn.Module):
        self._data_augmentation=mode
        self._mods['data_aug'].augment(mode)

    #### Encapsulation Meta Opt ####
    def start_bilevel_opt(self, inner_it, hp_list, opt_param, dl_val):
        """ Set up the Augmented Model for bi-level optimisation.

            Creates and keeps in the Augmented Model the objects necessary for meta-optimisation.
            This allows an almost transparent use by hiding the bi-level optimisation (see ''run_dist_dataugV3'') behind ::

                model.step(loss)

            See ''run_simple_smartaug'' for a complete example.

            Args:
                inner_it (int): Number of inner iterations before a meta-step. 0 inner iterations means there is no meta-step.
                hp_list (list): List of hyper-parameters to be learned.
                opt_param (dict): Dictionary containing the optimizers' parameters.
                dl_val (DataLoader): Data loader of validation data.
        """
        self._it_count=0
        self._in_it=inner_it

        self._opt_param=opt_param
        #Inner Opt
        inner_opt = torch.optim.SGD(self._mods['model']['original'].parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9

        #Validation data
        self._dl_val=dl_val
        self._dl_val_it=iter(dl_val)
        self._val_loss=0.

        if inner_it==0 or len(hp_list)==0: #No meta-opt
            print("No meta optimization")

            #Inner Opt
            self._diffopt = self._mods['model'].get_diffopt(
                inner_opt,
                grad_callback=(lambda grads: clip_norm(grads, max_norm=10)),
                track_higher_grads=False)

            self._meta_opt=None

        else: #Bi-level opt
            print("Bi-Level optimization")

            self._diffopt = self._mods['model'].get_diffopt(
                inner_opt,
                grad_callback=(lambda grads: clip_norm(grads, max_norm=10)),
@@ -945,15 +971,34 @@ class Augmented_model(nn.Module):

            #Meta Opt
            self._meta_opt = torch.optim.Adam(hp_list, lr=opt_param['Meta']['lr'])

            self._dl_val=dl_val
            self._dl_val_it=iter(dl_val)
            self._val_loss=0.

            self._meta_opt.zero_grad()
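
    # A minimal usage sketch for ''start_bilevel_opt'' (all names below,
    # e.g. 'aug_model', 'hp_list', 'dl_val', are assumptions; only the
    # opt_param keys actually read above are shown):
    #
    #   opt_param = {'Inner': {'lr': 1e-2, 'momentum': 0.9},  # inner SGD settings
    #                'Meta':  {'lr': 1e-2}}                    # meta Adam settings
    #   aug_model.start_bilevel_opt(inner_it=1, hp_list=hp_list,
    #                               opt_param=opt_param, dl_val=dl_val)
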
    def step(self, loss):
        """ Perform a model update.

            The ''start_bilevel_opt'' method needs to be called once before using this method.

            Performs a step of inner optimization and, if needed, a step of meta optimization,
            replacing ::

                opt.zero_grad()
                loss.backward()
                opt.step()

                val_loss=...
                val_loss.backward()
                meta_opt.step()
                adjust_param()
                detach()
                meta_opt.zero_grad()

            with ::

                model.step(loss)

            Args:
                loss (Tensor): the training loss tensor.
        """
        self._it_count+=1
        self._diffopt.step(loss) #(opt.zero_grad, loss.backward, opt.step)

@@ -982,6 +1027,22 @@ class Augmented_model(nn.Module):

            self._it_count=0
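
    # A minimal training-loop sketch built on ''step'', assuming 'aug_model'
    # was set up with ''start_bilevel_opt'' and that 'dl_train' yields
    # (xs, ys) batches (both names are assumptions):
    #
    #   for xs, ys in dl_train:
    #       loss = F.cross_entropy(aug_model(xs), ys)
    #       aug_model.step(loss)  # inner step; meta-step every inner_it iterations
    #
    # This replaces the manual zero_grad/backward/step and meta-update
    # sequence shown in the ''step'' docstring above.
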
    def val_loss(self):
        """ Get the validation loss.

            Computes the validation loss, if needed, and returns it.

            The ''start_bilevel_opt'' method needs to be called once before using this method.

            Returns:
                (Tensor) Validation loss on a single batch of data.
        """
        if(self._meta_opt): #Bi-level opt
            return self._val_loss
        else:
            return compute_vaLoss(model=self._mods['model'], dl_it=self._dl_val_it, dl=self._dl_val)
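
    # Sketch of reading the validation loss after some training steps,
    # assuming the setup above ('aug_model' is an assumption):
    #
    #   vloss = aug_model.val_loss()  # Tensor, single validation batch
    #   print('Val. loss: %.4f' % vloss.item())
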

    ##########################

    def train(self, mode=True):
        """ Set the module training mode.