Mirror of https://github.com/AntoineHX/smart_augmentation.git (synced 2025-05-04 20:20:46 +02:00)
Add more control/visibility over the optimizers
This commit is contained in:
parent d1ee0c632e
commit 41c7273241

3 changed files with 49 additions and 23 deletions
@@ -323,7 +323,7 @@ class Bottleneck(nn.Module):
 #ResNet18 : block=BasicBlock, layers=[2, 2, 2, 2]
 class ResNet(nn.Module):
 
-    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
+    def __init__(self, block=BasicBlock, layers=[2, 2, 2, 2], num_classes=1000, zero_init_residual=False,
                  groups=1, width_per_group=64, replace_stride_with_dilation=None,
                  norm_layer=None):
         super(ResNet, self).__init__()
@@ -419,11 +419,14 @@ class ResNet(nn.Module):
     def forward(self, x):
         return self._forward_impl(x)
 
+    def __str__(self):
+        return "ResNet18"
 
 ## Wide ResNet ##
 #https://github.com/xternalz/WideResNet-pytorch/blob/master/wideresnet.py
 #https://github.com/arcelien/pba/blob/master/pba/wrn.py
 #https://github.com/szagoruyko/wide-residual-networks/blob/master/pytorch/resnet.py
+'''
 class BasicBlock(nn.Module):
     def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
         super(BasicBlock, self).__init__()
@@ -516,3 +519,4 @@ class WideResNet(nn.Module):
 
     def __str__(self):
         return "WideResNet(s{}-d{})".format(self.kernel_size, self.depth)
+'''
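
With the new defaults, ResNet no longer needs the block type and layer counts spelled out, and str(model) is what the scripts below use for printing and for log file names. A minimal usage sketch (assuming the class is imported as in the repo's test script):

    model = ResNet(num_classes=10)   # defaults to block=BasicBlock, layers=[2, 2, 2, 2], i.e. a ResNet-18
    print(str(model))                # -> "ResNet18", reused below when building log file names
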
@@ -65,16 +65,28 @@ else:
 if __name__ == "__main__":
 
     tasks={
-        #'classic',
-        'aug_dataset',
+        'classic',
+        #'aug_dataset',
         #'aug_model'
     }
     n_inner_iter = 1
-    epochs = 150
+    epochs = 100
     dataug_epoch_start=0
+    optim_param={
+        'Meta':{
+            'optim':'Adam',
+            'lr':1e-2, #1e-2
+        },
+        'Inner':{
+            'optim': 'SGD',
+            'lr':1e-2, #1e-2
+            'momentum':0.9, #0.9
+        }
+    }
 
-    model = LeNet(3,10)
+    #model = LeNet(3,10)
     #model = MobileNetV2(num_classes=10)
+    model = ResNet(num_classes=10)
     #model = WideResNet(num_classes=10, wrn_size=32)
 
     #### Classic ####
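
The new optim_param dict gathers the meta (augmentation) and inner (task model) optimizer settings in one place. Note that in this commit the 'optim' names ('Adam', 'SGD') are stored but the optimizer classes themselves remain hard-coded in the training functions; the hypothetical helper below (not part of the commit) sketches how the name could also be dispatched:

    import torch

    def build_optimizer(params, conf):
        # conf is one sub-dict of optim_param, e.g. optim_param['Inner']
        kwargs = {k: v for k, v in conf.items() if k != 'optim'}      # lr, momentum, ...
        return getattr(torch.optim, conf['optim'])(params, **kwargs)  # -> torch.optim.SGD / torch.optim.Adam

    # inner_opt = build_optimizer(model.parameters(), optim_param['Inner'])
    # meta_opt  = build_optimizer(aug_model['data_aug'].parameters(), optim_param['Meta'])
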
@@ -83,14 +95,14 @@ if __name__ == "__main__":
         model = model.to(device)
 
         print("{} on {} for {} epochs".format(str(model), device_name, epochs))
-        log= train_classic(model=model, epochs=epochs, print_freq=1)
+        log= train_classic(model=model, opt_param=optim_param, epochs=epochs, print_freq=1)
         #log= train_classic_higher(model=model, epochs=epochs)
 
         exec_time=time.process_time() - t0
         ####
         print('-'*9)
         times = [x["time"] for x in log]
-        out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times), exec_time), "Device": device_name, "Log": log}
+        out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times), exec_time), 'Optimizer': optim_param['Inner'], "Device": device_name, "Log": log}
         print(str(model),": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
         filename = "{}-{} epochs".format(str(model),epochs)
         with open("res/log/%s.json" % filename, "w+") as f:
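
Recording optim_param['Inner'] in the result dict makes each JSON log self-describing. A small read-back sketch, assuming the out dict is json-dumped into the file opened above (with the settings in this commit the file would be named "ResNet18-100 epochs.json"):

    import json
    with open("res/log/ResNet18-100 epochs.json") as f:
        run = json.load(f)
    print(run["Optimizer"])   # {"optim": "SGD", "lr": 0.01, "momentum": 0.9}
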
@@ -123,7 +135,7 @@ if __name__ == "__main__":
         ##log= train_classic_higher(model=model, epochs=epochs)
 
         data_train_aug = AugmentedDatasetV2("./data", train=True, download=download_data, transform=transform, subset=(0,int(len(data_train)/2)))
-        data_train_aug.augement_data(aug_copy=10)
+        data_train_aug.augement_data(aug_copy=1)
         print(data_train_aug)
         unsup_ratio = 5
         dl_unsup = torch.utils.data.DataLoader(data_train_aug, batch_size=BATCH_SIZE*unsup_ratio, shuffle=True)
@@ -135,13 +147,13 @@ if __name__ == "__main__":
         model = model.to(device)
 
         print("{} on {} for {} epochs".format(str(model), device_name, epochs))
-        log= train_UDA(model=model, dl_unsup=dl_unsup, epochs=epochs, print_freq=10)
+        log= train_UDA(model=model, dl_unsup=dl_unsup, epochs=epochs, opt_param=optim_param, print_freq=10)
 
         exec_time=time.process_time() - t0
         ####
         print('-'*9)
         times = [x["time"] for x in log]
-        out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times), exec_time), "Device": device_name, "Log": log}
+        out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times), exec_time), 'Optimizer': optim_param['Inner'], "Device": device_name, "Log": log}
         print(str(model),": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
         filename = "{}-{}-{} epochs".format(str(data_train_aug),str(model),epochs)
         with open("res/log/%s.json" % filename, "w+") as f:
@@ -164,13 +176,20 @@ if __name__ == "__main__":
         #aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), model).to(device)
 
         print("{} on {} for {} epochs - {} inner_it".format(str(aug_model), device_name, epochs, n_inner_iter))
-        log= run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, KLdiv=False, loss_patience=None)
+        log= run_dist_dataugV2(model=aug_model,
+             epochs=epochs,
+             inner_it=n_inner_iter,
+             dataug_epoch_start=dataug_epoch_start,
+             opt_param=optim_param,
+             print_freq=10,
+             KLdiv=True,
+             loss_patience=None)
 
         exec_time=time.process_time() - t0
         ####
         print('-'*9)
         times = [x["time"] for x in log]
-        out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times), exec_time), "Device": device_name, "Param_names": aug_model.TF_names(), "Log": log}
+        out = {"Accuracy": max([x["acc"] for x in log]), "Time": (np.mean(times),np.std(times), exec_time), 'Optimizer': optim_param, "Device": device_name, "Log": log}
         print(str(aug_model),": acc", out["Accuracy"], "in:", out["Time"][0], "+/-", out["Time"][1])
         filename = "{}-{} epochs (dataug:{})- {} in_it".format(str(aug_model),epochs,dataug_epoch_start,n_inner_iter)
         with open("res/log/%s.json" % filename, "w+") as f:
@@ -47,10 +47,10 @@ def compute_vaLoss(model, dl_it, dl):
 
     return F.cross_entropy(model(xs), ys)
 
-def train_classic(model, epochs=1, print_freq=1):
+def train_classic(model, opt_param, epochs=1, print_freq=1):
     device = next(model.parameters()).device
     #opt = torch.optim.Adam(model.parameters(), lr=1e-3)
-    optim = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
+    optim = torch.optim.SGD(model.parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
 
     model.train()
     dl_val_it = iter(dl_val)
@@ -305,11 +305,12 @@ def train_classic_tests(model, epochs=1):
     print("Copy ", countcopy)
     return log
 
-def train_UDA(model, dl_unsup, epochs=1, print_freq=1):
+def train_UDA(model, dl_unsup, opt_param, epochs=1, print_freq=1):
 
     device = next(model.parameters()).device
     #opt = torch.optim.Adam(model.parameters(), lr=1e-3)
-    optim = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
+    opt = torch.optim.SGD(model.parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
+
 
     model.train()
     dl_val_it = iter(dl_val)
@@ -340,14 +341,13 @@ def train_UDA(model, dl_unsup, epochs=1, print_freq=1):
             sup_logits = model.forward(origin_xs)
             unsup_logits = model.forward(aug_xs)
 
-            #print(unsup_logits.shape, sup_logits.shape)
             log_sup=F.log_softmax(sup_logits, dim=1)
             log_unsup=F.log_softmax(unsup_logits, dim=1)
             #KL div w/ logits
             unsup_loss = F.softmax(sup_logits, dim=1)*(log_sup-log_unsup)
             unsup_loss=unsup_loss.sum(dim=-1).mean()
 
-            #print(unsup_loss.shape)
+            #print(unsup_loss)
             unsupp_coeff = 1
             loss = sup_loss + unsup_loss * unsupp_coeff
 
@@ -629,7 +629,7 @@ def run_dist_dataug(model, epochs=1, inner_it=1, dataug_epoch_start=0):
     print("Copy ", countcopy)
     return log
 
-def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_freq=1, KLdiv=False, loss_patience=None, save_sample=False):
+def run_dist_dataugV2(model, opt_param, epochs=1, inner_it=0, dataug_epoch_start=0, print_freq=1, KLdiv=False, loss_patience=None, save_sample=False):
     device = next(model.parameters()).device
     log = []
     countcopy=0
@@ -637,8 +637,8 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_freq=1, KLdiv=False, loss_patience=None, save_sample=False):
     dl_val_it = iter(dl_val)
 
     #if inner_it!=0:
-    meta_opt = torch.optim.Adam(model['data_aug'].parameters(), lr=1e-2) #lr=1e-2
-    inner_opt = torch.optim.SGD(model['model'].parameters(), lr=1e-2, momentum=0.9)
+    meta_opt = torch.optim.Adam(model['data_aug'].parameters(), lr=opt_param['Meta']['lr']) #lr=1e-2
+    inner_opt = torch.optim.SGD(model['model'].parameters(), lr=opt_param['Inner']['lr'], momentum=opt_param['Inner']['momentum']) #lr=1e-2 / momentum=0.9
 
     high_grad_track = True
     if inner_it == 0:
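
Here meta_opt updates the augmentation parameters while inner_opt is the optimizer that higher differentiates through for the task-model updates. The following is a condensed, hypothetical sketch of how the two optimizers typically interact in such a bilevel loop; it is not the body of run_dist_dataugV2, and dl_train, dl_val, dl_val_it and compute_vaLoss are assumed to exist as in the repo:

    import higher
    import torch.nn.functional as F

    with higher.innerloop_ctx(model['model'], inner_opt,
                              copy_initial_weights=True,
                              track_higher_grads=high_grad_track) as (fmodel, diffopt):
        for _ in range(inner_it):
            xs, ys = next(iter(dl_train))
            inner_loss = F.cross_entropy(fmodel(model['data_aug'](xs)), ys)  # inner (training) loss
            diffopt.step(inner_loss)             # differentiable inner update of the task model

        meta_opt.zero_grad()
        val_loss = compute_vaLoss(model=fmodel, dl_it=dl_val_it, dl=dl_val)  # outer (validation) loss
        val_loss.backward()                      # gradients flow back to the data_aug parameters
        meta_opt.step()                          # update the augmentation parameters
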
@@ -703,7 +703,10 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_freq=1, KLdiv=False, loss_patience=None, save_sample=False):
                 #aug_loss = F.kl_div(aug_logits, sup_logits, reduction='none') #Similarite predictions (distributions)
 
                 w_loss = fmodel['data_aug'].loss_weight() #Weight loss
-                aug_loss = (w_loss * aug_loss).mean()
+                aug_loss = (w_loss * aug_loss).mean() #apprentissage differe ?
+
+                aug_loss += (F.cross_entropy(log_aug, ys , reduction='none') * w_loss).mean()
+                #print(aug_loss)
                 unsupp_coeff = 1
                 loss += aug_loss * unsupp_coeff
 