import numpy as np
from tqdm import tqdm
from OREPS import StateIter


def SARSA(env, alpha, num_episodes=10000, print_interval=20, gamma=0.9,
          epsilon=0.1, fixed=True, force_move=None):
    """Tabular SARSA with an epsilon-greedy behaviour policy.

    Returns (episode_list, regret_list, length_list), where regret is the
    running average of the optimal score minus the achieved score.
    """
    state_dim = env.config.height * env.config.width * env.config.height * env.config.width
    state_iter = StateIter(env.config.width, env.config.height)
    score, episode_length = 0, 0
    length_list, score_list, episode_list = [], [], []
    regret_list = []
    optimal_score = 0
    # For training
    obs_list, action_list, reward_list = [], [], []
    if env.state_dim != 4:
        raise Exception("OREPS only supports state_dim==4")
    Q = np.zeros((state_dim, env.action_dim))  # No bias on action value
    for episode in tqdm(range(num_episodes)):
        if force_move and (episode + 1) % force_move == 0:
            obs, _ = env.reset(True)
        else:
            obs, _ = env.reset()
        obs_idx = state_iter.obs_to_index(obs)
        # Epsilon-greedy selection of the initial action
        if np.random.random() < epsilon:
            action = np.random.choice(env.action_dim)
        else:
            action = np.argmax(Q[obs_idx, :])
        optimal_score += env.get_optimal_score()
        done, truncated = False, False
        # Run episode
        while not done and not truncated:
            obs_list.append(obs)
            next_obs, reward, done, truncated, _ = env.step(action)
            action_list.append(action)
            reward_list.append(reward)
            score += reward
            episode_length += 1
            next_obs_idx = state_iter.obs_to_index(next_obs)
            # Epsilon-greedy selection of the next action (on-policy)
            if np.random.random() < epsilon:
                next_action = np.random.choice(env.action_dim)
            else:
                next_action = np.argmax(Q[next_obs_idx, :])
            obs_idx = state_iter.obs_to_index(obs)
            # SARSA update: Q(s,a) += alpha * (r + gamma * Q(s',a') - Q(s,a))
            Q[obs_idx, action] = Q[obs_idx, action] + alpha * (
                reward + gamma * Q[next_obs_idx, next_action] - Q[obs_idx, action])
            action = next_action
            obs = next_obs
        if episode % print_interval == (print_interval - 1):
            # Update running averages of episode length, score, and optimal score
            if len(episode_list) == 0:
                prev_length = episode_length / print_interval
                prev_score = score / print_interval
                prev_optimal_score = optimal_score / print_interval
            else:
                prev_length = (prev_length * (episode - print_interval) + episode_length) / episode
                prev_score = (prev_score * (episode - print_interval) + score) / episode
                prev_optimal_score = (prev_optimal_score * (episode - print_interval) + optimal_score) / episode
            length_list.append(prev_length)
            score_list.append(prev_score)
            episode_list.append(episode)
            regret_list.append(prev_optimal_score - prev_score)
            score, episode_length = 0, 0
            optimal_score = 0
    # return episode_list, length_list, score_list
    return episode_list, regret_list, length_list
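

if __name__ == "__main__":
    # Usage sketch (not part of the original file): SARSA only assumes an
    # OREPS grid-world environment exposing env.config.width/height,
    # env.state_dim == 4, env.action_dim, env.get_optimal_score(), and a
    # Gymnasium-style reset()/step() API. `GridWorld` is a hypothetical
    # class name; substitute the actual environment class from this repo.
    try:
        from OREPS import GridWorld  # hypothetical import, may not exist
    except ImportError:
        GridWorld = None

    if GridWorld is not None:
        env = GridWorld()
        episodes, regrets, lengths = SARSA(env, alpha=0.1, num_episodes=5000)
        print("Final averaged regret:", regrets[-1])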