added faster filtering and convolution, but not working yet for BU

This commit is contained in:
Marco Pedersoli 2020-06-13 20:47:19 -04:00
parent 9d68bc30bd
commit f7436d0002
5 changed files with 295 additions and 11 deletions

17
main.py
View file

@ -49,7 +49,7 @@ checkpoint=False
# Data
print('==> Preparing data..')
dataroot="~/scratch/data"
dataroot="./data"
download_data=False
transform_train = [
# transforms.RandomCrop(32, padding=4),
@ -116,7 +116,7 @@ print('==> Building model..')
# net = MyLeNetMatStochBU() # 10.5s - 45.3% #1.3GB
net=globals()[args.net]()
print(net)
#print(net)
net = net.to(device)
if device == 'cuda':
net = torch.nn.DataParallel(net)
@ -244,7 +244,7 @@ def stest(epoch,times=10):
best_acc = acc
import matplotlib.pyplot as plt
def plot_res(log, fig_name='res'):
def plot_res(log, best_acc,fig_name='res'):
"""Save a visual graph of the logs.
Args:
@ -260,7 +260,7 @@ def plot_res(log, fig_name='res'):
ax[0].plot(epochs,[x["test_loss"] for x in log], label='Test')
ax[0].legend()
ax[1].set_title('Acc')
ax[1].set_title('Acc %s'%best_acc)
ax[1].plot(epochs,[x["train_acc"] for x in log], label='Train')
ax[1].plot(epochs,[x["test_acc"] for x in log], label='Test')
ax[1].legend()
@ -270,7 +270,7 @@ def plot_res(log, fig_name='res'):
plt.savefig(fig_name, bbox_inches='tight')
plt.close()
from warmup_scheduler import GradualWarmupScheduler
#from warmup_scheduler import GradualWarmupScheduler
def get_scheduler(schedule, epochs, warmup_mul, warmup_ep):
scheduler=None
if schedule=='cosine':
@ -328,11 +328,12 @@ for epoch in range(start_epoch, start_epoch+args.epochs):
print('\nEpoch: %d' % epoch)
print("Acc : %.2f / %.2f"%(train_acc, test_acc))
print("Loss : %.2f / %.2f"%(train_loss, test_loss))
print('Time:',time.perf_counter() - t0)
exec_time=time.perf_counter() - t0
print('-'*9)
print('Best Acc : %.2f'%best_acc)
print('Training time (s):',exec_time)
print('Training time (min):',exec_time/60)
import json
@ -344,8 +345,8 @@ except:
print("Failed to save logs :",filename)
print(sys.exc_info()[1])
try:
plot_res(log, fig_name=res_folder+filename)
plot_res(log,best_acc, fig_name=res_folder+filename)
print('Plot :\"',res_folder+filename, '\" saved !')
except:
print("Failed to plot res")
print(sys.exc_info()[1])
print(sys.exc_info()[1])