-
Notifications
You must be signed in to change notification settings - Fork 1
/
experience_replay.py
83 lines (69 loc) · 2.56 KB
/
experience_replay.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# Experience Replay
# Importing the libraries
import numpy as np
from collections import namedtuple, deque
# Defining one Step
Step = namedtuple('Step', ['state', 'action', 'reward', 'done'])
# Making the AI progress on several (n_step) steps
class NStepProgress:
realtime_render=True
def __init__(self, env, ai, n_step):
self.ai = ai
self.rewards = []
self.env = env
self.n_step = n_step
def __iter__(self):
state = self.env.reset()
history = deque()
reward = 0.0
while True:
action = self.ai(np.array([state]))[0][0]
if(self.realtime_render):
self.env.render()
next_state, r, is_done, _ = self.env.step(action)
if(self.realtime_render):
self.env.render()
reward += r
history.append(Step(state = state, action = action, reward = r, done = is_done))
while len(history) > self.n_step + 1:
history.popleft()
if len(history) == self.n_step + 1:
yield tuple(history)
state = next_state
if is_done:
if len(history) > self.n_step + 1:
history.popleft()
while len(history) >= 1:
yield tuple(history)
history.popleft()
self.rewards.append(reward)
reward = 0.0
state = self.env.reset()
history.clear()
def disable_rendering(self):
self.realtime_render=False
def rewards_steps(self):
rewards_steps = self.rewards
self.rewards = []
return rewards_steps
# Fixed-capacity replay buffer fed by an iterator of n-step transitions.
class ReplayMemory:

    def __init__(self, n_steps, capacity = 10000):
        """Wrap an iterable of transitions and cap the buffer at `capacity`."""
        self.capacity = capacity
        self.n_steps = n_steps
        self.n_steps_iter = iter(n_steps)
        self.buffer = deque()

    def sample_batch(self, batch_size):
        """Generator of random, non-overlapping batches of stored entries.

        Entries are shuffled once; any remainder smaller than batch_size
        is not yielded.
        """
        shuffled = list(self.buffer)
        np.random.shuffle(shuffled)
        idx = 0
        while (idx + 1) * batch_size <= len(self.buffer):
            yield shuffled[idx * batch_size:(idx + 1) * batch_size]
            idx += 1

    def run_steps(self, samples):
        """Pull `samples` consecutive entries from the step iterator into the
        buffer, then trim the oldest entries down to capacity."""
        for _ in range(samples):
            self.buffer.append(next(self.n_steps_iter))
        # Never hold more than `capacity` entries; drop the oldest first.
        while len(self.buffer) > self.capacity:
            self.buffer.popleft()