-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
washim-uddin-mondal
committed
May 25, 2023
0 parents
commit 4d75ced
Showing
7 changed files
with
505 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
Results/ | ||
.DS_Store | ||
.idea/ | ||
Scripts/__pycache__/ |
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
from Scripts.Algorithm import train, evaluateMFC, evaluateMARL | ||
from Scripts.Parameters import ParseInput | ||
import time | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
import os | ||
import logging | ||
|
||
if __name__ == '__main__': | ||
args = ParseInput() | ||
|
||
if not os.path.exists('Results'): | ||
os.mkdir('Results') | ||
|
||
# Logging | ||
args.logFileName = 'Results/progress.log' | ||
open(args.logFileName, 'w').close() | ||
logging.basicConfig(filename=args.logFileName, | ||
format='%(asctime)s %(message)s', | ||
filemode='w') | ||
args.logger = logging.getLogger() | ||
args.logger.setLevel(logging.INFO) | ||
|
||
t0 = time.time() | ||
|
||
indexN = 0 | ||
valueRewardMFCArray = np.zeros(args.numN) | ||
valueRewardMFCArraySD = np.zeros(args.numN) | ||
|
||
valueRewardMARLArray = np.zeros(args.numN) | ||
valueRewardMARLArraySD = np.zeros(args.numN) | ||
|
||
RewardErrorArray = np.zeros(args.numN) | ||
RewardErrorArraySD = np.zeros(args.numN) | ||
|
||
NVec = np.zeros(args.numN) | ||
|
||
if args.train: | ||
args.logger.info('Training is in progress.') | ||
train(args) | ||
|
||
args.logger.info('Evaluation is in progress.') | ||
while indexN < args.numN: | ||
N = args.minN + indexN * args.divN | ||
NVec[indexN] = N | ||
|
||
for _ in range(0, args.maxSeed): | ||
valueRewardMFC = evaluateMFC(args) | ||
valueRewardMFC = np.array(valueRewardMFC.detach()) | ||
|
||
valueRewardMFCArray[indexN] += valueRewardMFC/args.maxSeed | ||
valueRewardMFCArraySD[indexN] += valueRewardMFC ** 2 / args.maxSeed | ||
|
||
valueRewardMARL = evaluateMARL(args, N) | ||
valueRewardMARL = np.array(valueRewardMARL.detach()) | ||
|
||
valueRewardMARLArray[indexN] += valueRewardMARL/args.maxSeed | ||
valueRewardMARLArraySD[indexN] += valueRewardMARL**2/args.maxSeed | ||
|
||
RewardError = np.abs(valueRewardMARL - valueRewardMFC) | ||
RewardErrorArray[indexN] += RewardError/args.maxSeed | ||
RewardErrorArraySD[indexN] += RewardError**2/args.maxSeed | ||
|
||
indexN += 1 | ||
args.logger.info(f'Evaluation N: {N}') | ||
|
||
valueRewardMFCArraySD = np.sqrt(np.maximum(0, valueRewardMFCArraySD - valueRewardMFCArray ** 2)) | ||
valueRewardMARLArraySD = np.sqrt(np.maximum(0, valueRewardMARLArraySD - valueRewardMARLArray ** 2)) | ||
RewardErrorArraySD = np.sqrt(np.maximum(0, RewardErrorArraySD - RewardErrorArray ** 2)) | ||
|
||
plt.figure() | ||
plt.xlabel('Number of Agents') | ||
plt.ylabel('Reward Values') | ||
plt.plot(NVec, valueRewardMFCArray, label='MFC') | ||
plt.fill_between(NVec, valueRewardMFCArray - valueRewardMFCArraySD, valueRewardMFCArray + valueRewardMFCArraySD, alpha=0.3) | ||
plt.plot(NVec, valueRewardMARLArray, label='MARL') | ||
plt.fill_between(NVec, valueRewardMARLArray - valueRewardMARLArraySD, valueRewardMARLArray + valueRewardMARLArraySD, alpha=0.3) | ||
plt.legend() | ||
plt.savefig(f'Results/RewardValues.png') | ||
|
||
plt.figure() | ||
plt.xlabel('Number of Agents') | ||
plt.ylabel('Error') | ||
plt.plot(NVec, RewardErrorArray) | ||
plt.fill_between(NVec, RewardErrorArray - RewardErrorArraySD, RewardErrorArray + RewardErrorArraySD, alpha=0.3) | ||
plt.savefig(f'Results/RewardError.png') | ||
|
||
t1 = time.time() | ||
|
||
args.logger.info(f'Elapsed time is {t1-t0} sec') |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# Introduction | ||
|
||
This repository contains codes that are used for generating numerical results in the following paper: | ||
|
||
"Mean-Field Control based Approximation of Multi-Agent Reinforcement Learning in Presence of a Shared Global State", | ||
Transactions on Machine Learning Research, May, 2023. | ||
|
||
[[arXiv]](https://arxiv.org/abs/2301.06889) [[TMLR]]() | ||
|
||
``` | ||
@article{mondal2023mean, | ||
title={Mean-Field Control based Approximation of Multi-Agent Reinforcement Learning in Presence of a Non-decomposable Shared Global State}, | ||
author={Mondal, Washim Uddin and Aggarwal, Vaneet and Ukkusuri, Satish V}, | ||
journal={arXiv preprint arXiv:2301.06889}, | ||
year={2023} | ||
} | ||
``` | ||
|
||
# Parameters | ||
|
||
Various parameters used in the experiments can be found in [Scripts/Parameters.py]() file. | ||
|
||
# Software and Packages | ||
|
||
``` | ||
python 3.8.12 | ||
pytorch 1.10.1 | ||
numpy 1.21.2 | ||
matplotlib 3.5.0 | ||
``` | ||
# Results | ||
|
||
Generated results will be stored in Results folder (will be created on the fly). | ||
Some pre-generated results are available for display in the Display folder. Specifically, | ||
[Fig. 1]() depicts the error as a function | ||
of N (the number of agents). | ||
|
||
# Run Experiments | ||
|
||
``` | ||
python3 Main.py | ||
``` | ||
|
||
The progress of the experiment is logged in Results/progress.log | ||
|
||
# Command Line Options | ||
|
||
Various command line options are given below: | ||
|
||
``` | ||
--train : if training is required from scratch, otherwise a pre-trained model will be used | ||
--minN : minimum value of N | ||
--numN : number of N values | ||
--divN : difference between two consecutive N values | ||
--maxSeed: number of random seeds | ||
``` |
Oops, something went wrong.