mirror of https://github.com/AntoineHX/smart_augmentation.git (synced 2025-05-04 12:10:45 +02:00)
Modif Dataugv6
This commit is contained in:
parent
ebee1b789f
commit
3ec99bf729
6 changed files with 334 additions and 36 deletions
@@ -2,12 +2,12 @@ from utils import *
if __name__ == "__main__":

    #'''
    '''
    files=[
        #"res/good_TF_tests/log/Aug_mod(Data_augV5(Mix0.5-14TFx2-MagFxSh)-LeNet)-100 epochs (dataug:0)- 0 in_it.json",
        #"res/good_TF_tests/log/Aug_mod(Data_augV5(Uniform-14TFx2-MagFxSh)-LeNet)-100 epochs (dataug:0)- 0 in_it.json",
        #"res/brutus-tests/log/Aug_mod(Data_augV5(Mix0.5-14TFx1-Mag)-LeNet)-150epochs(dataug:0)-1in_it-0.json",
        "res/log/Aug_mod(RandAugUDA(18TFx2-Mag1)-LeNet)-100 epochs (dataug:0)- 0 in_it.json",
        #"res/log/Aug_mod(RandAugUDA(18TFx2-Mag1)-LeNet)-100 epochs (dataug:0)- 0 in_it.json",
    ]

    for idx, file in enumerate(files):
@@ -16,7 +16,7 @@ if __name__ == "__main__":
            data = json.load(json_file)

        plot_resV2(data['Log'], fig_name=file.replace('.json','').replace('log/',''), param_names=data['Param_names'])
        #plot_TF_influence(data['Log'], param_names=data['Param_names'])
    #'''
    '''
    ## Loss , Acc, Proba = f(epoch) ##
    #plot_compare(filenames=files, fig_name="res/compare")
@@ -76,11 +76,11 @@ if __name__ == "__main__":
    '''

    #Print results
    '''
    #'''
    nb_run=3
    accs = []
    times = []
    files = ["res/brutus-tests/log/Aug_mod(Data_augV5(Mix1.0-14TFx2-MagFxSh)-LeNet)-150epochs(dataug:0)-10in_it-%s.json"%str(run) for run in range(nb_run)]
    files = ["res/brutus-tests/log/Aug_mod(Data_augV5(Mix1-14TFx4-Mag)-LeNet)-150epochs(dataug:0)-1in_it-%s.json"%str(run) for run in range(nb_run)]

    for idx, file in enumerate(files):
        #legend+=str(idx)+'-'+file+'\n'
@@ -90,4 +90,4 @@ if __name__ == "__main__":
            times.append(data['Time'][0])

    print(files[0], np.mean(accs), np.std(accs), np.mean(times))
    '''
    #'''
@@ -28,9 +28,105 @@ data_test = torchvision.datasets.MNIST(
"./data", train=False, download=True, transform=torchvision.transforms.ToTensor()
|
||||
)
|
||||
'''
|
||||
data_train = torchvision.datasets.CIFAR10(
|
||||
"./data", train=True, download=True, transform=transform
|
||||
)
|
||||
|
||||
from torchvision.datasets.vision import VisionDataset
|
||||
from PIL import Image
|
||||
import augmentation_transforms
|
||||
import numpy as np
|
||||
|
||||
class AugmentedDataset(VisionDataset):
|
||||
def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
|
||||
|
||||
super(AugmentedDataset, self).__init__(root, transform=transform, target_transform=target_transform)
|
||||
|
||||
supervised_dataset = torchvision.datasets.CIFAR10(root, train=train, download=download, transform=transform)
|
||||
|
||||
self.sup_data = supervised_dataset.data
|
||||
self.sup_targets = supervised_dataset.targets
|
||||
|
||||
for idx, img in enumerate(self.sup_data):
|
||||
self.sup_data[idx]= Image.fromarray(img) #to PIL Image
|
||||
|
||||
self.unsup_data=[]
|
||||
self.unsup_targets=[]
|
||||
|
||||
self.data= self.sup_data
|
||||
self.targets= self.sup_targets
|
||||
|
||||
|
||||
self._TF = [
|
||||
'Invert', 'Cutout', 'Sharpness', 'AutoContrast', 'Posterize',
|
||||
'ShearX', 'TranslateX', 'TranslateY', 'ShearY', 'Rotate',
|
||||
'Equalize', 'Contrast', 'Color', 'Solarize', 'Brightness']
|
||||
self._op_list =[]
|
||||
self.prob=0.5
|
||||
for tf in self._TF:
|
||||
for mag in range(1, 10):
|
||||
self._op_list+=[(tf, self.prob, mag)]
|
||||
self._nb_op = len(self._op_list)
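
        # Added note (not in the original commit): with the 15 TF names above and
        # magnitudes 1..9, _op_list holds 15*9 = 135 (tf, prob, mag) ops, so the
        # two-op policies enumerated in augement_data() number 135**2 = 18225.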

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        #img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def augement_data(self, aug_copy=1):

        policies = []
        for op_1 in self._op_list:
            for op_2 in self._op_list:
                policies += [[op_1, op_2]]

        for idx, image in enumerate(self.sup_data):
            for _ in range(aug_copy):
                chosen_policy = policies[np.random.choice(len(policies))]
                aug_image = augmentation_transforms.apply_policy(chosen_policy, image)
                #aug_image = augmentation_transforms.cutout_numpy(aug_image)

                self.unsup_data += [aug_image]
                self.unsup_targets += [self.sup_targets[idx]]

        print(type(self.data), type(self.sup_data), type(self.unsup_data))
        print(len(self.data), len(self.sup_data), len(self.unsup_data))
        #self.data = self.sup_data+self.unsup_data
        self.data = np.concatenate((self.sup_data, self.unsup_data), axis=0)
        print(len(self.data))
        self.targets = self.sup_targets+self.unsup_targets

    def len_supervised(self):
        return len(self.sup_data)

    def len_unsupervised(self):
        return len(self.unsup_data)

    def __len__(self):
        return len(self.data)


data_train = torchvision.datasets.CIFAR10("./data", train=True, download=True, transform=transform)
#print(len(data_train))
#data_train = AugmentedDataset("./data", train=True, download=True, transform=transform)
#print(len(data_train), data_train.len_supervised(), data_train.len_unsupervised())
#data_train.augement_data()
#print(len(data_train), data_train.len_supervised(), data_train.len_unsupervised())
#data_val = torchvision.datasets.CIFAR10(
#    "./data", train=True, download=True, transform=transform
#)
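
# A minimal usage sketch (added, not part of the commit): it assumes the
# commented-out AugmentedDataset lines above are enabled, with "transform" being
# whatever pipeline this file defines earlier, and shows how aug_copy scales the data.
if False:  # illustration only
    ds = AugmentedDataset("./data", train=True, download=True, transform=transform)
    print(len(ds))                      # 50000 supervised CIFAR10 images
    ds.augement_data(aug_copy=1)        # one augmented copy per supervised image
    print(ds.len_supervised(), ds.len_unsupervised(), len(ds))  # 50000 50000 100000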
@@ -45,4 +141,4 @@ val_subset_indices=range(int(len(data_train)/2),len(data_train))
dl_train = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(train_subset_indices))
dl_val = torch.utils.data.DataLoader(data_train, batch_size=BATCH_SIZE, shuffle=False, sampler=SubsetRandomSampler(val_subset_indices))
dl_test = torch.utils.data.DataLoader(data_test, batch_size=TEST_SIZE, shuffle=False)
dl_test = torch.utils.data.DataLoader(data_test, batch_size=TEST_SIZE, shuffle=False)
higher/dataug.py (202 changed lines)
@@ -692,6 +692,208 @@ class Data_augV5(nn.Module): #Joint optimization (mag, proba)
        else:
            return "Data_augV5(Mix%.1f%s-%dTFx%d-%s)" % (self._mix_factor, dist_param, self._nb_tf, self._N_seqTF, mag_param)

import numpy as np
class Data_augV6(nn.Module): #Sequential optimization
    def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True):
        super(Data_augV6, self).__init__()
        assert len(TF_dict)>0

        self._data_augmentation = True

        self._TF_dict = TF_dict
        self._TF = list(self._TF_dict.keys())
        self._nb_tf = len(self._TF)

        self._N_seqTF = N_TF
        self._shared_mag = shared_mag
        self._fixed_mag = fixed_mag

        self._TF_set_size = 3
        #if self._TF_set_size>self._nb_tf:
        #    print("Warning : TF sets size higher than number of TF. Reducing set size to %d"%self._nb_tf)
        #    self._TF_set_size=self._nb_tf
        assert self._nb_tf >= self._TF_set_size
        self._TF_sets = []
        for i in range(1, self._nb_tf):
            for j in range(i, self._nb_tf):
                if i != j:
                    self._TF_sets += [torch.tensor([0, i, j])]
        #print(self._TF_sets)
        #self._TF_sets=[torch.tensor([0, i, j]) for i in range(1,self._nb_tf)] #All VS Identity
        self._TF_schedule = [list(range(len(self._TF_sets))) for _ in range(self._N_seqTF)]
        for n_tf in range(self._N_seqTF):
            TF.random.shuffle(self._TF_schedule[n_tf])
        #print(self._TF_schedule)
        self._current_TF_idx = 0 #random.randint
        self._start_prob = 1/self._TF_set_size

        self._params = nn.ParameterDict({
            "prob": nn.Parameter(torch.tensor(self._start_prob).expand(self._nb_tf)), #Independent probabilities
            "mag" : nn.Parameter(torch.tensor(float(TF.PARAMETER_MAX)) if self._shared_mag
                                 else torch.tensor(float(TF.PARAMETER_MAX)).expand(self._nb_tf)), #[0, PARAMETER_MAX]
        })
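
        # Added note (not in the original): "prob" holds one entry per TF in
        # TF_dict, but each forward pass only reads the entries selected by the
        # current 3-TF set, renormalized to sum to 1 over that set.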

        #for t in TF.TF_no_mag: self._params['mag'][self._TF.index(t)].data-=self._params['mag'][self._TF.index(t)].data #Mag useless for the ignore_mag TFs

        #Distribution
        self._fixed_prob = fixed_prob
        self._samples = []
        self._mix_dist = False
        if mix_dist != 0.0:
            self._mix_dist = True
            self._mix_factor = max(min(mix_dist, 1.0), 0.0)

        #Mag regularization
        if not self._fixed_mag:
            if self._shared_mag:
                self._reg_tgt = torch.tensor(TF.PARAMETER_MAX, dtype=torch.float) #Encourage max amplitude
            else:
                self._reg_mask = [self._TF.index(t) for t in self._TF if t not in TF.TF_ignore_mag]
                self._reg_tgt = torch.full(size=(len(self._reg_mask),), fill_value=TF.PARAMETER_MAX) #Encourage max amplitude

    def forward(self, x):
        self._samples = []
        if self._data_augmentation:# and TF.random.random() < 0.5:
            device = x.device
            batch_size, h, w = x.shape[0], x.shape[2], x.shape[3]

            x = copy.deepcopy(x) #Avoid modifying the samples by reference (problematic for parallel uses)

            for n_tf in range(self._N_seqTF):

                tf_set = self._TF_sets[self._TF_schedule[n_tf][self._current_TF_idx]].to(device)
                #print(n_tf, tf_set)
                ## Sampling ##
                uniforme_dist = torch.ones(1, len(tf_set), device=device).softmax(dim=1)

                if not self._mix_dist:
                    self._distrib = uniforme_dist
                else:
                    prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]
                    curr_prob = torch.index_select(prob, 0, tf_set)
                    curr_prob = curr_prob / sum(curr_prob) #Constraint: sum(p)=1
                    self._distrib = (self._mix_factor*curr_prob+(1-self._mix_factor)*uniforme_dist).softmax(dim=1) #Mix the learned / uniform distributions with mix_factor

                cat_distrib = Categorical(probs=torch.ones((batch_size, len(tf_set)), device=device)*self._distrib)
                sample = cat_distrib.sample()
                self._samples.append(sample)

                ## Transformations ##
                x = self.apply_TF(x, sample)
        return x
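
    # Worked example (added, not in the original): with mix_dist=0.5, a learned
    # prob of (0.6, 0.3, 0.1) over the active set and the uniform (1/3, 1/3, 1/3),
    # the pre-softmax mix is 0.5*(0.6, 0.3, 0.1) + 0.5*(1/3, 1/3, 1/3)
    # = (0.47, 0.32, 0.22), which is then softmaxed and sampled per batch
    # element by the Categorical above.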

    def apply_TF(self, x, sampled_TF):
        device = x.device
        batch_size, channels, h, w = x.shape
        smps_x = []

        for sel_idx, tf_idx in enumerate(self._TF_sets[self._current_TF_idx]):
            mask = sampled_TF==sel_idx #Create selection mask
            smp_x = x[mask] #torch.masked_select()? (Requires expanding the mask to the same dims)

            if smp_x.shape[0]!=0: #if there's data to TF
                magnitude = self._params["mag"] if self._shared_mag else self._params["mag"][tf_idx]
                if self._fixed_mag: magnitude=magnitude.detach() #fmodel systematically tries to track the gradients of all the params

                tf = self._TF[tf_idx]
                #print(magnitude)

                #In place
                #x[mask]=self._TF_dict[tf](x=smp_x, mag=magnitude)

                #Out of place
                smp_x = self._TF_dict[tf](x=smp_x, mag=magnitude)
                idx = mask.nonzero()
                idx = idx.expand(-1,channels).unsqueeze(dim=2).expand(-1,channels, h).unsqueeze(dim=3).expand(-1,channels, h, w) #There must be a simpler way...
                x = x.scatter(dim=0, index=idx, src=smp_x)

        return x
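
    # Added note (not in the original): forward() picks the TF set through
    # self._TF_schedule[n_tf][self._current_TF_idx], while apply_TF() indexes
    # self._TF_sets[self._current_TF_idx] directly, so the two only agree when
    # the schedule is the identity permutation. The out-of-place scatter above
    # is equivalent, for illustration, to the in-place form:
    #   x = x.clone(); x[mask] = self._TF_dict[tf](x=smp_x, mag=magnitude)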

    def adjust_param(self, soft=False): #Detach from gradient?
        if not self._fixed_prob:
            if soft:
                self._params['prob'].data = F.softmax(self._params['prob'].data, dim=0) #Too 'soft': gets stuck at a uniform dist if the lr is too low
            else:
                self._params['prob'].data = F.relu(self._params['prob'].data)
                #self._params['prob'].data = self._params['prob'].clamp(min=0.0,max=1.0)
                #self._params['prob'].data = self._params['prob']/sum(self._params['prob']) #Constraint: sum(p)=1

            self._params['prob'].data[0] = self._start_prob #Fix the identity proba

        if not self._fixed_mag:
            #self._params['mag'].data = self._params['mag'].data.clamp(min=0.0,max=TF.PARAMETER_MAX) #Gets stuck once at an extreme
            self._params['mag'].data = F.relu(self._params['mag'].data) - F.relu(self._params['mag'].data - TF.PARAMETER_MAX)

    def loss_weight(self): #To verify
        if len(self._samples)==0: return 1 #No samples = no weighting

        prob = self._params["prob"].detach() if self._fixed_prob else self._params["prob"]

        #Several sequential TFs (warning: does not take the order into account!)
        w_loss = torch.zeros((self._samples[0].shape[0], self._TF_set_size), device=self._samples[0].device)
        for n_tf in range(self._N_seqTF):
            tmp_w = torch.zeros(w_loss.size(), device=w_loss.device)
            tmp_w.scatter_(dim=1, index=self._samples[n_tf].view(-1,1), value=1/self._N_seqTF)

            tf_set = self._TF_sets[self._TF_schedule[n_tf][self._current_TF_idx]].to(prob.device)
            curr_prob = torch.index_select(prob, 0, tf_set)
            curr_prob = curr_prob / sum(curr_prob) #Constraint: sum(p)=1

            #WARNING: THE DISTRIBUTION IS DIFFERENT WITH MIX
            assert not self._mix_dist
            w_loss += tmp_w * curr_prob / self._distrib #Weight by the probas (divided by the distrib so the loss does not shrink)

        w_loss = torch.sum(w_loss, dim=1)
        return w_loss
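
    # Worked example (added, not in the original): for one sampled TF per image
    # (N_seqTF=1) over a 3-TF set with renormalized probs (0.5, 0.3, 0.2) and a
    # uniform sampling distrib of 1/3, an image that drew TF 0 gets weight
    # 0.5/(1/3) = 1.5, TF 1 gets 0.9 and TF 2 gets 0.6; the expected weight under
    # the sampling distribution is 1, so the weighted loss matches the unweighted
    # one on average while staying differentiable in prob.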

    def reg_loss(self, reg_factor=0.005):
        if self._fixed_mag: # or self._fixed_prob: #No regularization if too few DOF
            return torch.tensor(0)
        else:
            #return reg_factor * F.l1_loss(self._params['mag'][self._reg_mask], target=self._reg_tgt, reduction='mean')
            params = self._params['mag'] if self._params['mag'].shape==torch.Size([]) else self._params['mag'][self._reg_mask]
            return reg_factor * F.mse_loss(params, target=self._reg_tgt.to(params.device), reduction='mean')

    def next_TF_set(self, idx=None):
        if idx is not None: #NB: was "if idx:", which silently ignored an explicit idx=0
            self._current_TF_idx = idx
        else:
            self._current_TF_idx += 1
            if self._current_TF_idx == len(self._TF_schedule[0]):
                self._current_TF_idx = 0
                #for n_tf in range(self._N_seqTF):
                #    TF.random.shuffle(self._TF_schedule[n_tf])
        #print(self._TF_schedule)
        #print("Current TF :", self._TF_sets[self._current_TF_idx])

    def train(self, mode=None):
        if mode is None:
            mode = self._data_augmentation
        self.augment(mode=mode) #Useless if mode=None
        super(Data_augV6, self).train(mode)

    def eval(self):
        self.train(mode=False)

    def augment(self, mode=True):
        self._data_augmentation = mode

    def __getitem__(self, key):
        return self._params[key]

    def __str__(self):
        dist_param = ''
        if self._fixed_prob: dist_param += 'Fx'
        mag_param = 'Mag'
        if self._fixed_mag: mag_param += 'Fx'
        if self._shared_mag: mag_param += 'Sh'
        if not self._mix_dist:
            return "Data_augV6(Uniform%s-%dTF(%d)x%d-%s)" % (dist_param, self._nb_tf, self._TF_set_size, self._N_seqTF, mag_param)
        else:
            return "Data_augV6(Mix%.1f%s-%dTF(%d)x%d-%s)" % (self._mix_factor, dist_param, self._nb_tf, self._TF_set_size, self._N_seqTF, mag_param)


class RandAug(nn.Module): #RandAugment = UniformFx-MagFxSh + fast
    def __init__(self, TF_dict=TF.TF_dict, N_TF=1, mag=TF.PARAMETER_MAX):
        super(RandAug, self).__init__()
@@ -5,21 +5,21 @@ from train_utils import *

tf_names = [
    ## Geometric TF ##
    #'Identity',
    #'FlipUD',
    #'FlipLR',
    #'Rotate',
    #'TranslateX',
    #'TranslateY',
    #'ShearX',
    #'ShearY',
    'Identity',
    'FlipUD',
    'FlipLR',
    'Rotate',
    'TranslateX',
    'TranslateY',
    'ShearX',
    'ShearY',

    ## Color TF (Expect image in the range of [0, 1]) ##
    #'Contrast',
    #'Color',
    #'Brightness',
    #'Sharpness',
    #'Posterize',
    'Contrast',
    'Color',
    'Brightness',
    'Sharpness',
    'Posterize',
    'Solarize', #=>Image in [0,1] #Not optimized for batches

    #Color TF (Common mag scale)
@@ -44,10 +44,10 @@ tf_names = [
    #'BadTranslateY',
    #'BadTranslateY_neg',

    #'BadColor',
    #'BadSharpness',
    #'BadContrast',
    #'BadBrightness',
    'BadColor',
    'BadSharpness',
    'BadContrast',
    'BadBrightness',

    #Non-functional
    #'Auto_Contrast', #Not optimized for batches (very slow)
@@ -65,7 +65,7 @@ else:
if __name__ == "__main__":

    n_inner_iter = 10
    epochs = 1
    epochs = 100
    dataug_epoch_start = 0

    #### Classic ####
@@ -95,12 +95,12 @@ if __name__ == "__main__":
    t0 = time.process_time()
    tf_dict = {k: TF.TF_dict[k] for k in tf_names}
    #tf_dict = TF.TF_dict
    aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=1, mix_dist=0.0, fixed_prob=True, fixed_mag=False, shared_mag=True), LeNet(3,10)).to(device)
    aug_model = Augmented_model(Data_augV6(TF_dict=tf_dict, N_TF=3, mix_dist=0.0, fixed_prob=False, fixed_mag=True, shared_mag=True), LeNet(3,10)).to(device)
    #aug_model = Augmented_model(Data_augV5(TF_dict=tf_dict, N_TF=2, mix_dist=0.5, fixed_mag=True, shared_mag=True), WideResNet(num_classes=10, wrn_size=160)).to(device)
    #aug_model = Augmented_model(RandAug(TF_dict=tf_dict, N_TF=2), LeNet(3,10)).to(device)
    print(str(aug_model), 'on', device_name)
    #run_simple_dataug(inner_it=n_inner_iter, epochs=epochs)
    log = run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=1, loss_patience=None)
    log = run_dist_dataugV2(model=aug_model, epochs=epochs, inner_it=n_inner_iter, dataug_epoch_start=dataug_epoch_start, print_freq=10, loss_patience=None)

    ####
    print('-'*9)
@@ -618,6 +618,7 @@ def run_dist_dataugV2(model, epochs=1, inner_it=0, dataug_epoch_start=0, print_f
        meta_opt.step()
        model['data_aug'].adjust_param(soft=False) #Constraint: sum(proba)=1
        model['data_aug'].next_TF_set()

        fmodel = higher.patch.monkeypatch(model, device=None, copy_initial_weights=True)
        diffopt = higher.optim.get_diff_optim(inner_opt, model.parameters(), fmodel=fmodel, track_higher_grads=high_grad_track)
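
        # Added note (not in the original): the higher fmodel/diffopt pair is
        # rebuilt after every meta step so the differentiable inner optimizer
        # starts from the freshly updated weights rather than tracking stale ones.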
@@ -651,7 +652,7 @@
        print('TF Proba :', model['data_aug']['prob'].data)
        #print('proba grad',model['data_aug']['prob'].grad)
        print('TF Mag :', model['data_aug']['mag'].data)
        print('Mag grad',model['data_aug']['mag'].grad)
        #print('Mag grad',model['data_aug']['mag'].grad)
        #print('Reg loss:', model['data_aug'].reg_loss().item())
        #############
        #### Log ####
@@ -46,7 +46,7 @@ TF_dict={ #Dataugv5 #AutoAugment
    'Brightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
    'Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
    'Posterize': (lambda x, mag: posterize(x, bits=rand_floats(size=x.shape[0], mag=mag, minval=4., maxval=8.))), #Gradient lost
    'Solarize': (lambda x, mag: solarize(x, thresholds=rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Gradient lost #=>Image in [0,1] #Not optimized for batches
    'Solarize': (lambda x, mag: solarize(x, thresholds=rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Gradient lost #=>Image in [0,1]

    #Non-functional
    #'Auto_Contrast': (lambda mag: None), #Not optimized for batches (very slow)
@@ -70,7 +70,7 @@ TF_dict={ #Dataugv5
    'Brightness':(lambda x, mag: brightness(x, brightness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
    'Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.9))),
    'Posterize': (lambda x, mag: posterize(x, bits=rand_floats(size=x.shape[0], mag=mag, minval=4., maxval=8.))), #Gradient lost
    'Solarize': (lambda x, mag: solarize(x, thresholds=rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Gradient lost #=>Image in [0,1] #Not optimized for batches
    'Solarize': (lambda x, mag: solarize(x, thresholds=rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Gradient lost #=>Image in [0,1]

    #Color TF (Common mag scale)
    '+Contrast': (lambda x, mag: contrast(x, contrast_factor=rand_floats(size=x.shape[0], mag=mag, minval=1.0, maxval=1.9))),
@@ -82,7 +82,7 @@ TF_dict={ #Dataugv5
    '-Brightness':(lambda x, mag: brightness(x, brightness_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
    '-Sharpness':(lambda x, mag: sharpeness(x, sharpness_factor=invScale_rand_floats(size=x.shape[0], mag=mag, minval=0.1, maxval=1.0))),
    '=Posterize': (lambda x, mag: posterize(x, bits=invScale_rand_floats(size=x.shape[0], mag=mag, minval=4., maxval=8.))), #Gradient lost
    '=Solarize': (lambda x, mag: solarize(x, thresholds=invScale_rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Gradient lost #=>Image in [0,1] #Not optimized for batches
    '=Solarize': (lambda x, mag: solarize(x, thresholds=invScale_rand_floats(size=x.shape[0], mag=mag, minval=1/256., maxval=256/256.))), #Gradient lost #=>Image in [0,1]


    'BRotate': (lambda x, mag: rotate(x, angle=rand_floats(size=x.shape[0], mag=mag, maxval=30*3))),
@@ -295,8 +295,7 @@ def equalize(x): #NOT OPTIMIZED FOR BATCHES

    return float_image(x)

def solarize(x, thresholds): #NOT OPTIMIZED FOR BATCHES
# Optimization: mask applied directly over all the data (Mask = (B,C,H,W) > (B))
def solarize(x, thresholds):
    batch_size, channels, h, w = x.shape
    #imgs=[]
    #for idx, t in enumerate(thresholds): #Per-image operation
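
# A hedged sketch (added; the commit's hunk is truncated above) of the batched
# solarize that the new comment describes: one boolean mask over the whole batch,
# inverting every pixel at or above its image's threshold. The name and exact
# broadcasting are assumptions, not the repository's final code.
import torch

def solarize_batched(x, thresholds):
    # x: (B,C,H,W) float images in [0,1]; thresholds: (B,) one threshold per image
    mask = x >= thresholds.view(-1, 1, 1, 1)  # broadcast (B,) against (B,C,H,W)
    return torch.where(mask, 1.0 - x, x)      # invert only the masked pixels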