Permalink
Cannot retrieve contributors at this time
Name already in use
A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
OREPS-OPIX/run.py
Go to fileThis commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
126 lines (111 sloc)
5.81 KB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from EnvGrid import EnvGrid | |
from OREPS import OREPS | |
from SARSA import SARSA | |
from datetime import datetime | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import pickle | |
import os | |
if __name__ == '__main__':
    # Run-wide settings, kept as module-level names so that the class body
    # and the code further down in this script can read them directly.
    timeout, num_episodes, print_interval = 200, 10000, 100

    class SampleConfig:
        """Grid-world settings handed to ``EnvGrid``.

        How each field is interpreted is decided by ``EnvGrid``, which is
        imported from another module and not visible here.
        """

        height, width = 5, 5
        obstacle_list = [(2, 2), (1, 3), (3, 1)]
        obstacle_stochasticity = 0.0
        obs_type = "1D"
        # The right-hand side is resolved in the module namespace, so this
        # copies the run-wide ``timeout`` above into the class. It only works
        # while the class is defined at module scope (not inside a function).
        timeout = timeout
        default_reward = -0.01

    config = SampleConfig()
    env = EnvGrid(config)
def experiment():
    """Run every OREPS variant once, in a fixed order, on the shared ``env``.

    Reads the module-level ``env``, ``num_episodes`` and ``print_interval``.

    Returns:
        A ``(result, plot_cfg)`` pair of equal-length lists. ``result[i]`` is
        ``list(OREPS(...))`` for the i-th variant and ``plot_cfg[i]`` is its
        ``[line_style, legend_label]`` pair.
    """
    # Keyword arguments that every variant receives unchanged.
    shared = dict(
        force_move=1000,
        num_episodes=num_episodes,
        print_interval=print_interval,
    )

    # Step size used by the two non-optimistic baselines.
    eta = np.sqrt(np.log(100) / (100 * num_episodes))
    # Fixed step size for the optimistic-predictor variants; the alternative
    # np.sqrt(np.log(100) / num_episodes) is left disabled.
    eta_full = 0.2
    gamma_full = np.sqrt(eta_full)

    # (legend label, step size, variant-specific keyword arguments)
    variants = [
        (r"OREPS ($\gamma=0$)", eta,
         dict(gamma=0.0)),
        (r"OREPS-IX ($\gamma>0$)", eta,
         dict(gamma=np.sqrt(eta))),
        (fr"OREPS-OPIX (latest predictor, $\eta={eta_full}$)", eta_full,
         dict(optimistic_predictor="latest", gamma=gamma_full)),
        (fr"OREPS-OPIX (latest predictor, slow reset, $\eta={eta_full}$)", eta_full,
         dict(reset_period=2, optimistic_predictor="latest", gamma=gamma_full)),
        (fr"OREPS-OPIX (latest predictor, fast reset, $\eta={eta_full}$)", eta_full,
         dict(reset_period=1/2, optimistic_predictor="latest", gamma=gamma_full)),
        (fr"OREPS-OPIX (perfect predictor, $\eta={eta_full}$)", eta_full,
         dict(optimistic_predictor="super", gamma=gamma_full)),
    ]

    result, plot_cfg = [], []
    for label, step_size, extra in variants:
        result.append(list(OREPS(env, step_size, **extra, **shared)))
        plot_cfg.append(["-", label])
    return result, plot_cfg
if not os.path.exists('results'): | |
os.makedirs('results') | |
now = datetime.now().strftime("%y%m%d.%H%M") | |
repeat = 10 | |
result, sq_result = [], [] | |
for rep_idx in range(repeat): | |
rep_result, plot_cfg = experiment() | |
for exp_idx, exp_result in enumerate(rep_result): | |
for type_idx, data in enumerate(exp_result): | |
if rep_idx==0: | |
if type_idx==0: | |
result.append([[x/repeat for x in data]]) # result[exp_idx][0] | |
sq_result.append([[x**2/repeat for x in data]]) | |
else: | |
result[exp_idx].append([x/repeat for x in data]) # result[exp_idx][type_idx] | |
sq_result[exp_idx].append([x**2/repeat for x in data]) | |
else: | |
new_len = min(len(data), len(result[exp_idx][type_idx])) | |
result[exp_idx][type_idx] = [x+y/repeat for x,y in zip(result[exp_idx][type_idx][:new_len], data[:new_len])] | |
sq_result[exp_idx][type_idx] = [x+y**2/repeat for x,y in zip(sq_result[exp_idx][type_idx][:new_len], data[:new_len])] | |
with open(f'results/{now}_{rep_idx}.pkl', 'wb') as file: | |
pickle.dump([result, sq_result], file) | |
# fig, ax = plt.subplots(figsize=(10,6)) | |
# for plot_idx in range(len(result)): | |
# cum_avg = np.cumsum(result[plot_idx][1]) | |
# for i in range(len(cum_avg)): | |
# cum_avg[i] = cum_avg[i]/(i+1) | |
# ax.plot(result[plot_idx][0], cum_avg, plot_cfg[plot_idx][0], label=plot_cfg[plot_idx][1]) | |
# ax.fill_between(result[plot_idx][0], | |
# cum_avg+np.sqrt(sq_result[plot_idx][1]-np.square(result[plot_idx][1])), | |
# cum_avg-np.sqrt(sq_result[plot_idx][1]-np.square(result[plot_idx][1])), | |
# alpha = 0.5) | |
# ax.legend() | |
# plt.xlim(0, num_episodes-print_interval) | |
# plt.xlabel('episodes') | |
# plt.ylabel('average regret') | |
# plt.grid(axis='y') | |
# plt.savefig(f"results/{now}_{rep_idx}_regret.pdf") | |
# if rep_idx == repeat-1: | |
# plt.show() | |
# fig, ax = plt.subplots(figsize=(10,6)) | |
# for plot_idx in range(len(result)): | |
# ax.plot(result[plot_idx][0], result[plot_idx][2], plot_cfg[plot_idx][0], label=plot_cfg[plot_idx][1]) | |
# ax.legend() | |
# plt.xlim(0, num_episodes-print_interval) | |
# plt.xlabel('episodes') | |
# plt.ylabel('prediction error') | |
# plt.grid(axis='y') | |
# plt.savefig(f"results/{now}_{rep_idx}_pred_error.pdf") | |
# if rep_idx == repeat-1: | |
# plt.show() | |
fig, ax = plt.subplots(figsize=(10,6)) | |
for plot_idx in range(len(result)): | |
cum_avg = np.cumsum(result[plot_idx][1]) | |
for i in range(len(cum_avg)): | |
cum_avg[i] = cum_avg[i]/(i+1) | |
ax.plot(result[plot_idx][0], cum_avg, plot_cfg[plot_idx][0], label=plot_cfg[plot_idx][1]) | |
ax.fill_between(result[plot_idx][0], | |
cum_avg+np.sqrt(sq_result[plot_idx][1]-np.square(result[plot_idx][1])), | |
cum_avg-np.sqrt(sq_result[plot_idx][1]-np.square(result[plot_idx][1])), | |
alpha = 0.5) | |
ax.legend() | |
plt.xlim(0, num_episodes-print_interval) | |
plt.xlabel('episodes') | |
plt.ylabel('average regret') | |
plt.grid(axis='y') | |
plt.savefig(f"results/{now}_regret.pdf") | |
if rep_idx == repeat-1: | |
plt.show() |