diff --git a/gym_random_walk/envs/random_walk_env.py b/gym_random_walk/envs/random_walk_env.py
index bc9d22a..a432785 100644
--- a/gym_random_walk/envs/random_walk_env.py
+++ b/gym_random_walk/envs/random_walk_env.py
@@ -8,6 +8,7 @@
 class RandomWalkEnv(gym.Env):
     def __init__(self):
         self.action_space = spaces.Discrete(2)
+        self.size = 6  # right-hand boundary: reaching it ends the episode with reward 1
         #print("init")
     def _step(self, action):
         #print("step")
@@ -17,7 +18,7 @@ class RandomWalkEnv(gym.Env):
             self.state -= 1
         if (action == 1):
             self.state += 1
-        if (self.state >= 6):
+        if (self.state >= self.size):
             reward = 1
             done = True
         if (self.state <= 0):
@@ -25,7 +26,15 @@ class RandomWalkEnv(gym.Env):
         return np.array(self.state), reward, done, {}
     def _reset(self):
         #print("reset")
-        self.state = 1 # TODO start in a random position
+        print("#self.size:",self.size)
+        # randint's upper bound is exclusive: draw from 1..size-1 so every
+        # position strictly between the two boundaries (0 and size) can start.
+        self.state = np.random.randint(1, self.size)
+        print("starting: ", self.state)
+        # Hand the initial observation back, in the same form _step returns it.
+        return np.array(self.state)
     def _render(self, mode='human', close=False):
+        if close:
+            return
         #print("render")
-        print(self.state)
+        print("current state: ",self.state)