Add basic TTT

This commit is contained in:
Nicolai Czempin 2017-04-03 23:51:35 +02:00
parent d995b315e1
commit 46f077ef1f
8 changed files with 83 additions and 52 deletions

View File

@ -1,8 +1,8 @@
import gym
import numpy as np
import gym_random_walk
import gym_tic_tac_toe
env = gym.make('random_walk-v0')
env = gym.make('tic_tac_toe-v0')
num_episodes = 20
num_steps_per_episode = 200
@ -10,14 +10,21 @@ num_steps_per_episode = 200
collected_rewards = []
for i in range(num_episodes):
s = env.reset()
print (s)
print ("starting new episode")
env.render()
print ("started")
total_reward = 0
done = False
om = 1
for j in range(num_steps_per_episode):
a = np.random.randint(env.action_space.n)
a = env.action_space.sample()
print (a[0])
#sm = s['on_move']
#print (sm)
a = tuple((om, a[1]))
s1, reward, done, _ = env.step(a)
om = -om
env.render()
total_reward += reward
s = s1

View File

@ -1,10 +0,0 @@
from gym.envs.registration import register
register(
id='random_walk-v0',
entry_point='gym_random_walk.envs:RandomWalkEnv',
)
#register(
# id='foo-extrahard-v0',
# entry_point='gym_foo.envs:FooExtraHardEnv',
#)

View File

@ -1 +0,0 @@
from gym_random_walk.envs.random_walk_env import RandomWalkEnv

View File

@ -1,36 +0,0 @@
import gym
from gym import error, spaces, utils
from gym.utils import seeding
import numpy as np
class RandomWalkEnv(gym.Env):
metadata = {'render.modes': ['human']}
def __init__(self):
self.action_space = spaces.Discrete(2)
self.size = 6
#print("init")
def _step(self, action):
#print("step")
reward = 0
done = False
if (action == 0):
self.state -= 1
if (action == 1):
self.state += 1
if (self.state >= self.size):
reward = 1
done = True
if (self.state <= 0):
done = True
return np.array(self.state), reward, done, {}
def _reset(self):
#print("reset")
print("#self.size:",self.size)
self.state = np.random.randint(1,self.size-1)
print("starting: ", self.state)
def _render(self, mode='human', close=False):
if close:
return
#print("render")
print("current state: ",self.state)

View File

@ -0,0 +1,6 @@
from gym.envs.registration import register
register(
id='tic_tac_toe-v0',
entry_point='gym_tic_tac_toe.envs:TicTacToeEnv',
)

View File

@ -0,0 +1 @@
from gym_tic_tac_toe.envs.tic_tac_toe_env import TicTacToeEnv

View File

@ -0,0 +1,64 @@
import gym
from gym import spaces
import numpy as np
class TicTacToeEnv(gym.Env):
metadata = {'render.modes': ['human']}
def __init__(self):
self.action_space = spaces.Tuple((spaces.Discrete(2), spaces.Discrete(9)))
self.observation_space = spaces.Discrete(3)#Tuple(spaces.Discrete(3), spaces.Discrete(9))
def _step(self, action):
done = False
reward = 0
p, square = action
# p = p*2 - 1
# check move legality
proposed = self.state['board'][square]
om = self.state['on_move']
print ("on move: ", om)
if (proposed != 0): # wrong player, not empty
print("illegal move ", action, ". (square occupied): ", square)
done = True
reward = -om # player who did NOT make the illegal move
if (p != om): # wrong player, not empty
print("illegal move ", action, " not on move: ", p)
done = True
reward = -om # player who did NOT make the illegal move
else:
self.state['board'][square] = p
self.state['on_move'] = -p
# check game over
for i in range(3):
if (self.state['board'][i * 3] == p and self.state['board'][i*3 + 1] == p and self.state['board'][i*3+2] == 2):
reward = p
done = True
break
#TODO other cases
return np.array(self.state), reward, done, {}
def _reset(self):
self.state = {}
self.state['board'] = [0,0,0,0,0,0,0,0,0]
self.state['on_move'] = 1
return self.state
def _render(self, mode='human', close=False):
if close:
return
print("on move: " , self.state['on_move'])
for i in range (9):
print (self.state['board'][i], end=" ")
print()
def move_generator(self):
moves = []
for i in range (9):
if (self.state.state['board'][i]== 0):
p = self.state.on_move
if (p == 2):
p = -1
m = [p, i]
moves.append(m)

View File

@ -1,9 +1,9 @@
from setuptools import setup
from setuptools import find_packages
setup(name='gym_random_walk',
setup(name='gym_tic_tac_toe',
version='0.0.1',
install_requires=['gym'],
url="https://github.com/nczempin/gym-random-walk",
url="https://github.com/nczempin/gym-tic-tac-toe",
packages=find_packages()
)