# Source: GitHub permalink view (Sep 7, 2022; file mode 100644; 80 lines, 60 sloc; 2.57 KB).

1
from Scripts.Algorithm import train, evaluateMARLNonLocal, evaluateMARLLocal
2
from Scripts.Parameters import ParseInput
3
import time
4
import numpy as np
5
import matplotlib.pyplot as plt
6
import os
7
8
def _to_numpy(value):
    """Return a detached NumPy copy of one evaluation result.

    NOTE(review): presumably ``evaluateMARLLocal`` / ``evaluateMARLNonLocal``
    return torch tensors (their results are ``.detach()``-ed) -- confirm in
    Scripts.Algorithm. ``.cpu()`` is applied when available because
    ``np.array`` cannot convert a tensor that lives on a CUDA device.
    """
    value = value.detach()
    if hasattr(value, 'cpu'):
        value = value.cpu()
    return np.array(value)


def _mean_and_sd(samples):
    """Return ``(mean, population standard deviation)`` of *samples*.

    Uses NumPy's two-pass ``mean``/``std`` rather than
    ``sqrt(E[x^2] - E[x]^2)``, which loses precision (and can go negative)
    when the spread is small relative to the values themselves.

    An empty list (no seeds were run, i.e. ``maxSeed <= 0``) yields
    ``(0.0, 0.0)`` so those entries stay at zero.
    """
    if not samples:
        return 0.0, 0.0
    data = np.asarray(samples, dtype=float)
    # ddof=0: population SD, i.e. sqrt(E[x^2] - E[x]^2) computed stably.
    return float(data.mean()), float(data.std())


def _plot_with_band(ax, x, mean, sd, **plot_kwargs):
    """Plot *mean* against *x* on *ax* with a shaded +/- one-SD band."""
    ax.plot(x, mean, **plot_kwargs)
    ax.fill_between(x, mean - sd, mean + sd, alpha=0.3)


def main():
    """Optionally train, then compare local vs. non-local MARL evaluation.

    For ``numN`` values ``N = minN + k * divN`` (k = 0, 1, ...), both
    evaluators are run ``maxSeed`` times. The mean and population standard
    deviation of each evaluator's value, and of their absolute difference,
    are plotted to ``Results/Values.png`` and ``Results/Error.png``.
    """
    args = ParseInput()

    # perf_counter is monotonic, so the elapsed time is immune to clock changes.
    t0 = time.perf_counter()

    valueLocalArray = np.zeros(args.numN)
    valueLocalArraySD = np.zeros(args.numN)
    valueNonLocalArray = np.zeros(args.numN)
    valueNonLocalArraySD = np.zeros(args.numN)
    ErrorArray = np.zeros(args.numN)
    ErrorArraySD = np.zeros(args.numN)
    NVec = np.zeros(args.numN)

    if args.train:
        print('Training is in progress.')
        train(args)

    print('Evaluation is in progress.')
    for indexN in range(args.numN):
        N = args.minN + indexN * args.divN
        NVec[indexN] = N

        localSamples, nonLocalSamples, errorSamples = [], [], []
        for _ in range(args.maxSeed):
            # Evaluate local before non-local in case the evaluators consume
            # shared random state.
            valueLocal = _to_numpy(evaluateMARLLocal(args, N))
            valueNonLocal = _to_numpy(evaluateMARLNonLocal(args, N))
            localSamples.append(valueLocal)
            nonLocalSamples.append(valueNonLocal)
            errorSamples.append(np.abs(valueNonLocal - valueLocal))

        valueLocalArray[indexN], valueLocalArraySD[indexN] = _mean_and_sd(localSamples)
        valueNonLocalArray[indexN], valueNonLocalArraySD[indexN] = _mean_and_sd(nonLocalSamples)
        ErrorArray[indexN], ErrorArraySD[indexN] = _mean_and_sd(errorSamples)

        print(f'N: {N}')

    # exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
    os.makedirs('Results', exist_ok=True)

    fig, ax = plt.subplots()
    ax.set_xlabel('N')
    ax.set_ylabel('Values')
    _plot_with_band(ax, NVec, valueLocalArray, valueLocalArraySD, label='Local')
    _plot_with_band(ax, NVec, valueNonLocalArray, valueNonLocalArraySD, label='Non-Local')
    ax.legend()
    fig.savefig('Results/Values.png')
    plt.close(fig)  # release the figure; pyplot otherwise keeps it alive

    fig, ax = plt.subplots()
    ax.set_xlabel('N')
    ax.set_ylabel('Error')
    _plot_with_band(ax, NVec, ErrorArray, ErrorArraySD)
    fig.savefig('Results/Error.png')
    plt.close(fig)

    t1 = time.perf_counter()
    print(f'Elapsed time is {t1-t0} sec')


if __name__ == '__main__':
    main()