# NearOptimalLocalPolicy/Main.py
import os
import time
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

from Scripts.Algorithm import train, evaluateMARLNonLocal, evaluateMARLLocal
from Scripts.Parameters import ParseInput
def mean_and_std(samples):
    """Return the mean and population standard deviation of *samples*.

    Uses NumPy's two-pass ``mean``/``std`` (ddof=0) instead of the one-pass
    E[x^2] - E[x]^2 shortcut. The shortcut loses precision through
    cancellation when the spread is small relative to the mean, and can even
    go negative.

    Args:
        samples: Sequence of per-seed results (scalars or size-1 arrays).

    Returns:
        ``(mean, std)`` as floats, or ``(0.0, 0.0)`` when *samples* is empty
        (i.e. no seeds were evaluated).
    """
    if len(samples) == 0:
        return 0.0, 0.0
    values = np.asarray(samples, dtype=float)
    return float(values.mean()), float(values.std())


def save_band_plot(path, x, ylabel, curves):
    """Plot mean curves with +/- 1 standard-deviation bands and save them.

    Args:
        path: Output path passed to ``plt.savefig``.
        x: Shared x-axis values (the grid of N).
        ylabel: Label for the y axis.
        curves: Iterable of ``(mean, std, label)`` triples. ``label`` may be
            ``None``; a legend is drawn only if at least one label is given.
    """
    fig = plt.figure()
    plt.xlabel('N')
    plt.ylabel(ylabel)
    labelled = False
    for mean, std, label in curves:
        if label is None:
            plt.plot(x, mean)
        else:
            plt.plot(x, mean, label=label)
            labelled = True
        plt.fill_between(x, mean - std, mean + std, alpha=0.3)
    if labelled:
        plt.legend()
    plt.savefig(path)
    # pyplot keeps every figure alive until it is explicitly closed.
    plt.close(fig)


def main():
    """Optionally train, then compare local and non-local evaluations over N.

    For each ``N = args.minN + k * args.divN`` (``k = 0 .. args.numN - 1``),
    the script runs ``args.maxSeed`` evaluations of each policy. It records
    the mean and population standard deviation of both values and of their
    absolute difference, and prints progress. It then writes
    ``Results/Values.png`` and ``Results/Error.png`` and prints the elapsed
    wall time.
    """
    args = ParseInput()
    t0 = time.perf_counter()
    NVec = np.zeros(args.numN)
    valueLocalArray = np.zeros(args.numN)
    valueLocalArraySD = np.zeros(args.numN)
    valueNonLocalArray = np.zeros(args.numN)
    valueNonLocalArraySD = np.zeros(args.numN)
    ErrorArray = np.zeros(args.numN)
    ErrorArraySD = np.zeros(args.numN)

    if args.train:
        print('Training is in progress.')
        train(args)

    print('Evaluation is in progress.')
    for indexN in range(args.numN):
        N = args.minN + indexN * args.divN
        NVec[indexN] = N
        localSamples, nonLocalSamples, errorSamples = [], [], []
        for _ in range(args.maxSeed):
            # NOTE(review): .detach() suggests the evaluators return torch
            # tensors; np.array() on them only works for CPU tensors, so a
            # CUDA result would also need .cpu() -- confirm.
            valueLocal = np.array(evaluateMARLLocal(args, N).detach())
            valueNonLocal = np.array(evaluateMARLNonLocal(args, N).detach())
            localSamples.append(valueLocal)
            nonLocalSamples.append(valueNonLocal)
            errorSamples.append(np.abs(valueNonLocal - valueLocal))
        valueLocalArray[indexN], valueLocalArraySD[indexN] = mean_and_std(localSamples)
        valueNonLocalArray[indexN], valueNonLocalArraySD[indexN] = mean_and_std(nonLocalSamples)
        ErrorArray[indexN], ErrorArraySD[indexN] = mean_and_std(errorSamples)
        print(f'N: {N}')

    resultsDir = Path('Results')
    resultsDir.mkdir(exist_ok=True)
    save_band_plot(resultsDir / 'Values.png', NVec, 'Values', [
        (valueLocalArray, valueLocalArraySD, 'Local'),
        (valueNonLocalArray, valueNonLocalArraySD, 'Non-Local'),
    ])
    save_band_plot(resultsDir / 'Error.png', NVec, 'Error', [
        (ErrorArray, ErrorArraySD, None),
    ])
    print(f'Elapsed time is {time.perf_counter() - t0} sec')


if __name__ == '__main__':
    main()