mirror of
https://github.com/AntoineHX/BU_Stoch_pool.git
synced 2025-05-04 01:30:48 +02:00
added faster filtering and convolution, but not working yet for BU
This commit is contained in:
parent
9d68bc30bd
commit
f7436d0002
5 changed files with 295 additions and 11 deletions
17
main.py
17
main.py
|
@ -49,7 +49,7 @@ checkpoint=False
|
|||
|
||||
# Data
|
||||
print('==> Preparing data..')
|
||||
dataroot="~/scratch/data"
|
||||
dataroot="./data"
|
||||
download_data=False
|
||||
transform_train = [
|
||||
# transforms.RandomCrop(32, padding=4),
|
||||
|
@ -116,7 +116,7 @@ print('==> Building model..')
|
|||
# net = MyLeNetMatStochBU() # 10.5s - 45.3% #1.3GB
|
||||
|
||||
net=globals()[args.net]()
|
||||
print(net)
|
||||
#print(net)
|
||||
net = net.to(device)
|
||||
if device == 'cuda':
|
||||
net = torch.nn.DataParallel(net)
|
||||
|
@ -244,7 +244,7 @@ def stest(epoch,times=10):
|
|||
best_acc = acc
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
def plot_res(log, fig_name='res'):
|
||||
def plot_res(log, best_acc,fig_name='res'):
|
||||
"""Save a visual graph of the logs.
|
||||
|
||||
Args:
|
||||
|
@ -260,7 +260,7 @@ def plot_res(log, fig_name='res'):
|
|||
ax[0].plot(epochs,[x["test_loss"] for x in log], label='Test')
|
||||
ax[0].legend()
|
||||
|
||||
ax[1].set_title('Acc')
|
||||
ax[1].set_title('Acc %s'%best_acc)
|
||||
ax[1].plot(epochs,[x["train_acc"] for x in log], label='Train')
|
||||
ax[1].plot(epochs,[x["test_acc"] for x in log], label='Test')
|
||||
ax[1].legend()
|
||||
|
@ -270,7 +270,7 @@ def plot_res(log, fig_name='res'):
|
|||
plt.savefig(fig_name, bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
from warmup_scheduler import GradualWarmupScheduler
|
||||
#from warmup_scheduler import GradualWarmupScheduler
|
||||
def get_scheduler(schedule, epochs, warmup_mul, warmup_ep):
|
||||
scheduler=None
|
||||
if schedule=='cosine':
|
||||
|
@ -328,11 +328,12 @@ for epoch in range(start_epoch, start_epoch+args.epochs):
|
|||
print('\nEpoch: %d' % epoch)
|
||||
print("Acc : %.2f / %.2f"%(train_acc, test_acc))
|
||||
print("Loss : %.2f / %.2f"%(train_loss, test_loss))
|
||||
print('Time:',time.perf_counter() - t0)
|
||||
|
||||
exec_time=time.perf_counter() - t0
|
||||
print('-'*9)
|
||||
print('Best Acc : %.2f'%best_acc)
|
||||
print('Training time (s):',exec_time)
|
||||
print('Training time (min):',exec_time/60)
|
||||
|
||||
|
||||
import json
|
||||
|
@ -344,8 +345,8 @@ except:
|
|||
print("Failed to save logs :",filename)
|
||||
print(sys.exc_info()[1])
|
||||
try:
|
||||
plot_res(log, fig_name=res_folder+filename)
|
||||
plot_res(log,best_acc, fig_name=res_folder+filename)
|
||||
print('Plot :\"',res_folder+filename, '\" saved !')
|
||||
except:
|
||||
print("Failed to plot res")
|
||||
print(sys.exc_info()[1])
|
||||
print(sys.exc_info()[1])
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue