-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
washim-uddin-mondal
committed
Sep 7, 2022
0 parents
commit d21b27d
Showing
8 changed files
with
495 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
Results/ | ||
.DS_Store | ||
.idea/ | ||
Scripts/__pycache__/ |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
from Scripts.Algorithm import train, evaluateMARLNonLocal, evaluateMARLLocal | ||
from Scripts.Parameters import ParseInput | ||
import time | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
import os | ||
|
||
if __name__ == '__main__': | ||
args = ParseInput() | ||
|
||
t0 = time.time() | ||
|
||
indexN = 0 | ||
valueLocalArray = np.zeros(args.numN) | ||
valueLocalArraySD = np.zeros(args.numN) | ||
|
||
valueNonLocalArray = np.zeros(args.numN) | ||
valueNonLocalArraySD = np.zeros(args.numN) | ||
|
||
ErrorArray = np.zeros(args.numN) | ||
ErrorArraySD = np.zeros(args.numN) | ||
|
||
NVec = np.zeros(args.numN) | ||
|
||
if args.train: | ||
print('Training is in progress.') | ||
train(args) | ||
|
||
print('Evaluation is in progress.') | ||
while indexN < args.numN: | ||
N = args.minN + indexN * args.divN | ||
NVec[indexN] = N | ||
|
||
for _ in range(0, args.maxSeed): | ||
valueLocal = evaluateMARLLocal(args, N) | ||
valueLocal = np.array(valueLocal.detach()) | ||
|
||
valueLocalArray[indexN] += valueLocal/args.maxSeed | ||
valueLocalArraySD[indexN] += valueLocal ** 2 / args.maxSeed | ||
|
||
valueNonLocal = evaluateMARLNonLocal(args, N) | ||
valueNonLocal = np.array(valueNonLocal.detach()) | ||
|
||
valueNonLocalArray[indexN] += valueNonLocal/args.maxSeed | ||
valueNonLocalArraySD[indexN] += valueNonLocal**2/args.maxSeed | ||
|
||
Error = np.abs(valueNonLocal - valueLocal) | ||
ErrorArray[indexN] += Error/args.maxSeed | ||
ErrorArraySD[indexN] += Error**2/args.maxSeed | ||
|
||
indexN += 1 | ||
print(f'N: {N}') | ||
|
||
valueLocalArraySD = np.sqrt(np.maximum(0, valueLocalArraySD - valueLocalArray ** 2)) | ||
valueNonLocalArraySD = np.sqrt(np.maximum(0, valueNonLocalArraySD - valueNonLocalArray ** 2)) | ||
ErrorArraySD = np.sqrt(np.maximum(0, ErrorArraySD - ErrorArray ** 2)) | ||
|
||
if not os.path.exists('Results'): | ||
os.mkdir('Results') | ||
|
||
plt.figure() | ||
plt.xlabel('N') | ||
plt.ylabel('Values') | ||
plt.plot(NVec, valueLocalArray, label='Local') | ||
plt.fill_between(NVec, valueLocalArray - valueLocalArraySD, valueLocalArray + valueLocalArraySD, alpha=0.3) | ||
plt.plot(NVec, valueNonLocalArray, label='Non-Local') | ||
plt.fill_between(NVec, valueNonLocalArray - valueNonLocalArraySD, valueNonLocalArray + valueNonLocalArraySD, alpha=0.3) | ||
plt.legend() | ||
plt.savefig(f'Results/Values.png') | ||
|
||
plt.figure() | ||
plt.xlabel('N') | ||
plt.ylabel('Error') | ||
plt.plot(NVec, ErrorArray) | ||
plt.fill_between(NVec, ErrorArray - ErrorArraySD, ErrorArray + ErrorArraySD, alpha=0.3) | ||
plt.savefig(f'Results/Error.png') | ||
|
||
t1 = time.time() | ||
|
||
print(f'Elapsed time is {t1-t0} sec') |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# Introduction | ||
|
||
This repository contains codes that are used for generating numerical results in the following paper: | ||
|
||
"On the Near-Optimality of Local Policies in Large Cooperative | ||
Multi-Agent Reinforcement Learning", Transactions on Machine Learning Research, 2022. | ||
|
||
# Parameters | ||
|
||
Various parameters used in the experiments can be found in Scripts/Parameters.py file. | ||
|
||
# Results | ||
|
||
Generated results will be stored in Results folder (will be created on the fly). | ||
Some pre-generated results are available for display in the Display folder. Specifically, | ||
Fig. 1 depicts the percentage error between the values generated by local and non-local policies in an N-agent system | ||
as a function of N. | ||
|
||
# Run Experiments | ||
|
||
``` | ||
python3 Main.py | ||
``` | ||
|
||
# Command Line Options | ||
|
||
Various command line options are given below: | ||
|
||
``` | ||
--train : if training is required from scratch, otherwise a pre-trained model will be used | ||
--minN : minimum value of N | ||
--numN : number of N values | ||
--divN : difference between two consecutive N values | ||
--maxSeed: number of random seeds | ||
``` |
Oops, something went wrong.