import random
import numpy as np
from gymnasium import Space
from abc import ABC, abstractmethod
from typing import Optional, Union, Dict
from collections import deque
from xuance.common.common_tools import discount_cumsum
from xuance.common.segtree_tool import SumSegmentTree, MinSegmentTree
from xuance.environment.utils import space2shape
[docs]
def create_memory(shape: Optional[Union[tuple, dict]],
                  n_envs: int,
                  n_size: int,
                  dtype: type = np.float32):
    """
    Allocate zero-initialised numpy storage for one field of a buffer.

    Args:
        shape: per-step data shape. A tuple yields a single array; a dict maps each key to its own
            shape (a ``None`` entry allocates an ``object`` array for that key); ``None`` allocates nothing.
        n_envs: number of parallel environments (leading axis).
        n_size: number of steps stored per environment (second axis).
        dtype: numpy dtype used for every non-object array.

    Returns:
        ``None``, a ``numpy.ndarray`` of shape ``(n_envs, n_size, *shape)``, or a dict of such arrays.

    Raises:
        NotImplementedError: if ``shape`` is neither ``None``, a tuple, nor a dict.
    """
    if shape is None:
        return None
    leading = [n_envs, n_size]
    if isinstance(shape, dict):
        return {
            key: (np.zeros(leading, dtype=object) if sub_shape is None
                  else np.zeros(leading + list(sub_shape), dtype=dtype))
            for key, sub_shape in shape.items()
        }
    if isinstance(shape, tuple):
        return np.zeros(leading + list(shape), dtype)
    raise NotImplementedError
[docs]
def store_element(data: Optional[Union[np.ndarray, dict, float]],
                  memory: Union[dict, np.ndarray],
                  ptr: int):
    """
    Write one time step of data (for every environment) into column ``ptr`` of ``memory``.

    Args:
        data: values for the current step. A dict writes each entry into the same-keyed array of
            ``memory``; ``None`` means there is nothing to store.
        memory: destination array (or dict of arrays) indexed as ``[env, step, ...]``.
        ptr: step index to overwrite.
    """
    if data is None:
        return
    if not isinstance(data, dict):
        memory[:, ptr] = data
        return
    for key, value in data.items():
        memory[key][:, ptr] = value
[docs]
def sample_batch(memory: Optional[Union[np.ndarray, dict]],
                 index: Optional[Union[np.ndarray, tuple]]):
    """
    Gather the entries selected by ``index`` from ``memory``.

    Args:
        memory: an array, a dict of arrays, or ``None``.
        index: any numpy index, e.g. a tuple ``(env_ids, step_ids)`` of integer arrays.

    Returns:
        ``None`` when ``memory`` is ``None``; a dict with ``index`` applied to every value when
        ``memory`` is a dict; otherwise ``memory[index]``.
    """
    if isinstance(memory, dict):
        return {key: array[index] for key, array in memory.items()}
    return None if memory is None else memory[index]
[docs]
class Buffer(ABC):
    """
    Basic buffer for single-agent DRL algorithms.

    Storage is organised per environment: ``n_size = buffer_size // num_envs`` steps are kept for
    each of the ``num_envs`` parallel environments.

    Args:
        observation_space: the space for observation data.
        action_space: the space for action data.
        auxiliary_info_shape: the shape for auxiliary data if needed.
        num_envs: number of parallel environments.
        buffer_size: total number of steps stored across all environments; must be divisible by
            ``num_envs``.
    """

    def __init__(
            self,
            observation_space: Space,
            action_space: Space,
            auxiliary_info_shape: Optional[dict],
            num_envs: int,
            buffer_size: int
    ):
        assert buffer_size % num_envs == 0, "buffer_size must be divisible by the number of envs (parallels)"
        self.observation_space = observation_space
        self.action_space = action_space
        self.auxiliary_shape = auxiliary_info_shape
        # Pre-define the data that might be stored in replay buffer for training.
        self.observations: Optional[np.ndarray] = None
        self.next_observations: Optional[np.ndarray] = None
        self.actions: Optional[np.ndarray] = None
        # Optional[...] takes a single type argument, so alternatives must go through Union.
        self.auxiliary_infos: Optional[Union[np.ndarray, dict]] = None
        self.rewards: Optional[np.ndarray] = None
        self.terminals: Optional[np.ndarray] = None
        self.returns: Optional[np.ndarray] = None
        self.values: Optional[np.ndarray] = None
        self.n_envs = num_envs
        self.buffer_size = buffer_size
        self.n_size = self.buffer_size // self.n_envs  # steps stored per environment
        self.advantages: Optional[np.ndarray] = None
        self.ptr = 0  # last data pointer
        self.size = 0  # current buffer size per environment.

    @property
    def full(self):
        """Whether ``n_size`` steps have been stored for each environment."""
        return self.size >= self.n_size

    @abstractmethod
    def store(self, *args):
        """Store data in the buffer; implemented by subclasses."""
        raise NotImplementedError

    @abstractmethod
    def clear(self, *args):
        """Discard stored data; implemented by subclasses."""
        raise NotImplementedError

    @abstractmethod
    def sample(self, *args):
        """Draw a batch from the buffer; implemented by subclasses."""
        raise NotImplementedError

    def finish_path(self, *args):
        """Hook for end-of-trajectory processing; a no-op by default."""
        pass
[docs]
class EpisodeBuffer:
    """
    Episode buffer for DRQN agent.

    Collects the transitions of one episode in parallel Python lists (``obs``, ``action``,
    ``reward``, ``done``) and returns contiguous windows of them as numpy arrays.
    """

    def __init__(self):
        self.obs = []
        self.action = []
        self.reward = []
        self.done = []

    def put(self, transition):
        """
        Append one transition.

        Args:
            transition: an indexable ``(obs, action, reward, done)`` record.
        """
        self.obs.append(transition[0])
        self.action.append(transition[1])
        self.reward.append(transition[2])
        self.done.append(transition[3])

    def sample(self, lookup_step=None, idx=None) -> Dict[str, np.ndarray]:
        """
        Return a contiguous window of the episode.

        Args:
            lookup_step: number of actions/rewards/dones in the window. ``None`` takes everything
                from the start index to the end of the episode.
            idx: index of the first step of the window. ``None`` starts at 0.

        Returns:
            dict with ``obs`` (up to ``lookup_step + 1`` entries -- one more than the other fields,
            presumably to include the observation after the last action; confirm with the caller
            how the obs list is filled), ``acts``, ``rews`` and ``done``.
        """
        # Previously both defaults crashed with TypeError (``None + int``); treat None as "open-ended".
        start = 0 if idx is None else idx
        if lookup_step is None:
            step_stop = obs_stop = None
        else:
            step_stop = start + lookup_step
            obs_stop = step_stop + 1
        obs = np.array(self.obs)[start:obs_stop]
        action = np.array(self.action)[start:step_stop]
        reward = np.array(self.reward)[start:step_stop]
        done = np.array(self.done)[start:step_stop]
        return dict(obs=obs,
                    acts=action,
                    rews=reward,
                    done=done)

    def __len__(self) -> int:
        """Number of stored transitions (length of the action list)."""
        return len(self.action)
[docs]
class DummyOnPolicyBuffer(Buffer):
    """
    Replay buffer for on-policy DRL algorithms.

    Args:
        observation_space: the observation space of the environment.
        action_space: the action space of the environment.
        auxiliary_shape: data shape of auxiliary information (if exists).
        n_envs: number of parallel environments.
        horizon_size: max length of steps to store for one environment.
        use_gae: if use GAE trick.
        use_advnorm: if use Advantage normalization trick.
        gamma: discount factor.
        gae_lam: gae lambda.
    """

    def __init__(self,
                 observation_space: Space,
                 action_space: Space,
                 auxiliary_shape: Optional[dict],
                 n_envs: int,
                 horizon_size: int,
                 use_gae: bool = True,
                 use_advnorm: bool = True,
                 gamma: float = 0.99,
                 gae_lam: float = 0.95):
        self.buffer_size = horizon_size * n_envs
        super(DummyOnPolicyBuffer, self).__init__(observation_space, action_space, auxiliary_shape,
                                                  n_envs, self.buffer_size)
        self.horizon_size = horizon_size
        self.n_size = self.horizon_size
        self.use_gae, self.use_advnorm = use_gae, use_advnorm
        self.gamma, self.gae_lam = gamma, gae_lam
        # Per-environment index of the first step not yet processed by finish_path.
        # Kept here as well as in clear() because subclasses may override clear().
        self.start_ids = np.zeros(self.n_envs, np.int64)
        self.clear()

    @property
    def full(self):
        """Whether ``horizon_size`` steps have been stored for each environment."""
        return self.size >= self.n_size

    def clear(self):
        """Discard all stored data and reset the write pointer, size and trajectory start indices."""
        self.ptr, self.size = 0, 0
        # All data is discarded, so pending trajectories restart at step 0. A stale start index
        # larger than ``ptr`` would make finish_path's ``arange(start, ptr)`` skip the new data.
        self.start_ids = np.zeros(self.n_envs, np.int64)
        self.observations = create_memory(space2shape(self.observation_space), self.n_envs, self.n_size)
        self.actions = create_memory(space2shape(self.action_space), self.n_envs, self.n_size)
        self.rewards = create_memory((), self.n_envs, self.n_size)
        self.returns = create_memory((), self.n_envs, self.n_size)
        self.values = create_memory((), self.n_envs, self.n_size)
        self.terminals = create_memory((), self.n_envs, self.n_size)
        self.advantages = create_memory((), self.n_envs, self.n_size)
        self.auxiliary_infos = create_memory(self.auxiliary_shape, self.n_envs, self.n_size)

    def store(self, obs, acts, rews, value, terminals, aux_info=None):
        """
        Write one step of data for all environments at the current pointer, then advance it.

        Args:
            obs: observations (one per environment).
            acts: actions.
            rews: rewards.
            value: value estimates.
            terminals: terminal flags.
            aux_info: optional auxiliary data matching ``auxiliary_shape``.
        """
        store_element(obs, self.observations, self.ptr)
        store_element(acts, self.actions, self.ptr)
        store_element(rews, self.rewards, self.ptr)
        store_element(value, self.values, self.ptr)
        store_element(terminals, self.terminals, self.ptr)
        store_element(aux_info, self.auxiliary_infos, self.ptr)
        self.ptr = (self.ptr + 1) % self.n_size
        self.size = min(self.size + 1, self.n_size)

    def finish_path(self, val, i):
        """
        Compute returns and advantages for environment ``i``'s trajectory, from
        ``self.start_ids[i]`` up to the current write position (or the end of storage when full).

        Args:
            val: value used to bootstrap after the last stored step (NOTE(review): callers
                presumably pass 0 for terminal states -- confirm at call sites).
            i: index of the environment whose trajectory is finished.
        """
        if self.full:
            path_slice = np.arange(self.start_ids[i], self.n_size).astype(np.int32)
        else:
            path_slice = np.arange(self.start_ids[i], self.ptr).astype(np.int32)
        # Values of the path followed by the bootstrap value, so vs[t + 1] is always defined.
        vs = np.append(np.array(self.values[i, path_slice]), [val], axis=0)
        if self.use_gae:  # use gae
            rewards = np.array(self.rewards[i, path_slice])
            advantages = np.zeros_like(rewards)
            dones = np.array(self.terminals[i, path_slice])
            last_gae_lam = 0
            step_nums = len(path_slice)
            for t in reversed(range(step_nums)):
                # (1 - done) stops bootstrapping / accumulation across a terminal step.
                delta = rewards[t] + (1 - dones[t]) * self.gamma * vs[t + 1] - vs[t]
                advantages[t] = last_gae_lam = delta + (1 - dones[t]) * self.gamma * self.gae_lam * last_gae_lam
            returns = advantages + vs[:-1]
        else:
            rewards = np.append(np.array(self.rewards[i, path_slice]), [val], axis=0)
            returns = discount_cumsum(rewards, self.gamma)[:-1]
            advantages = rewards[:-1] + self.gamma * vs[1:] - vs[:-1]
        self.returns[i, path_slice] = returns
        self.advantages[i, path_slice] = advantages
        self.start_ids[i] = self.ptr

    def sample(self, indexes):
        """
        Gather a minibatch by flat indexes.

        Args:
            indexes: integer array of flat positions in ``[0, n_envs * n_size)``; each one is split
                into ``(env, step) = divmod(index, n_size)``.

        Returns:
            dict with ``obs``, ``actions``, ``returns``, ``values``, ``aux_batch``, ``batch_size`` and
            ``advantages`` (shifted/scaled to zero mean and unit std when ``use_advnorm`` is set).
        """
        assert self.full, "Not enough transitions for on-policy buffer to random sample"
        env_choices, step_choices = divmod(indexes, self.n_size)
        index = tuple([env_choices, step_choices])
        samples_dict = {
            'obs': sample_batch(self.observations, index),
            'actions': sample_batch(self.actions, index),
            'returns': sample_batch(self.returns, index),
            'values': sample_batch(self.values, index),
            'aux_batch': sample_batch(self.auxiliary_infos, index),
            'batch_size': len(indexes),
        }
        adv_batch = sample_batch(self.advantages, index)
        if self.use_advnorm:
            adv_batch = (adv_batch - np.mean(adv_batch)) / (np.std(adv_batch) + 1e-8)
        samples_dict.update({
            'advantages': adv_batch
        })
        return samples_dict
[docs]
class DummyOnPolicyBuffer_Atari(DummyOnPolicyBuffer):
    """
    Replay buffer for on-policy DRL algorithms and Atari tasks.

    Identical to ``DummyOnPolicyBuffer`` except that observations are stored as ``uint8``.

    Args:
        observation_space: the observation space of the environment.
        action_space: the action space of the environment.
        auxiliary_shape: data shape of auxiliary information (if exists).
        n_envs: number of parallel environments.
        horizon_size: max length of steps to store for one environment.
        use_gae: if use GAE trick.
        use_advnorm: if use Advantage normalization trick.
        gamma: discount factor.
        gae_lam: gae lambda.
    """

    def __init__(self,
                 observation_space: Space,
                 action_space: Space,
                 auxiliary_shape: Optional[dict],
                 n_envs: int,
                 horizon_size: int,
                 use_gae: bool = True,
                 use_advnorm: bool = True,
                 gamma: float = 0.99,
                 gae_lam: float = 0.95):
        super(DummyOnPolicyBuffer_Atari, self).__init__(observation_space, action_space, auxiliary_shape,
                                                        n_envs, horizon_size, use_gae, use_advnorm, gamma, gae_lam)

    def clear(self):
        """
        Discard all stored data and reset the write pointer, size and trajectory start indices.
        Observations are allocated as ``uint8`` (1 byte per element instead of float32).
        """
        self.ptr, self.size = 0, 0
        # All data is discarded, so pending trajectories restart at step 0; stale start indices
        # would make finish_path compute an empty or partial path.
        self.start_ids = np.zeros(self.n_envs, np.int64)
        self.observations = create_memory(space2shape(self.observation_space), self.n_envs, self.n_size, np.uint8)
        self.actions = create_memory(space2shape(self.action_space), self.n_envs, self.n_size)
        self.rewards = create_memory((), self.n_envs, self.n_size)
        self.returns = create_memory((), self.n_envs, self.n_size)
        self.values = create_memory((), self.n_envs, self.n_size)
        self.terminals = create_memory((), self.n_envs, self.n_size)
        self.advantages = create_memory((), self.n_envs, self.n_size)
        self.auxiliary_infos = create_memory(self.auxiliary_shape, self.n_envs, self.n_size)
[docs]
class DummyOffPolicyBuffer(Buffer):
    """
    Replay buffer for off-policy DRL algorithms.

    Args:
        observation_space: the observation space of the environment.
        action_space: the action space of the environment.
        auxiliary_shape: data shape of auxiliary information (if exists).
        n_envs: number of parallel environments.
        buffer_size: the total size of the replay buffer.
        batch_size: size of transition data for a batch of sample.
    """

    def __init__(self,
                 observation_space: Space,
                 action_space: Space,
                 auxiliary_shape: Optional[dict],
                 n_envs: int,
                 buffer_size: int,
                 batch_size: int):
        super(DummyOffPolicyBuffer, self).__init__(observation_space, action_space, auxiliary_shape, n_envs, buffer_size)
        self.batch_size = batch_size
        assert buffer_size % self.n_envs == 0, "buffer_size must be divisible by the number of envs (parallels)"
        self.n_size = buffer_size // self.n_envs
        self.clear()

    def clear(self):
        """Allocate empty storage and reset the write pointer and size."""
        # Without this reset, ``sample`` would keep drawing steps from the zeroed arrays.
        self.ptr, self.size = 0, 0
        self.observations = create_memory(space2shape(self.observation_space), self.n_envs, self.n_size)
        self.next_observations = create_memory(space2shape(self.observation_space), self.n_envs, self.n_size)
        self.actions = create_memory(space2shape(self.action_space), self.n_envs, self.n_size)
        self.auxiliary_infos = create_memory(self.auxiliary_shape, self.n_envs, self.n_size)
        self.rewards = create_memory((), self.n_envs, self.n_size)
        self.terminals = create_memory((), self.n_envs, self.n_size)

    def store(self, obs, acts, rews, terminals, next_obs):
        """
        Write one transition per environment at the current pointer (overwriting the oldest data once
        the ring is full), then advance the pointer.
        """
        store_element(obs, self.observations, self.ptr)
        store_element(acts, self.actions, self.ptr)
        store_element(rews, self.rewards, self.ptr)
        store_element(terminals, self.terminals, self.ptr)
        store_element(next_obs, self.next_observations, self.ptr)
        self.ptr = (self.ptr + 1) % self.n_size
        self.size = min(self.size + 1, self.n_size)

    def sample(self, batch_size=None):
        """
        Uniformly sample transitions with replacement.

        Args:
            batch_size: number of transitions; defaults to ``self.batch_size``.

        Returns:
            dict with ``obs``, ``actions``, ``obs_next``, ``rewards``, ``terminals`` and ``batch_size``.
        """
        bs = self.batch_size if batch_size is None else batch_size
        env_choices = np.random.choice(self.n_envs, bs)
        # Only the first ``size`` steps of each environment hold data.
        step_choices = np.random.choice(self.size, bs)
        index = tuple([env_choices, step_choices])
        samples_dict = {
            'obs': sample_batch(self.observations, index),
            'actions': sample_batch(self.actions, index),
            'obs_next': sample_batch(self.next_observations, index),
            'rewards': sample_batch(self.rewards, index),
            'terminals': sample_batch(self.terminals, index),
            'batch_size': bs,
        }
        return samples_dict
[docs]
class RecurrentOffPolicyBuffer(Buffer):
    """
    Replay buffer for DRQN-based algorithms.

    Stores whole episodes (objects supporting ``len()`` and ``sample(lookup_step=..., idx=...)``,
    such as ``EpisodeBuffer``) and samples equal-length windows from them.

    Args:
        observation_space: the observation space of the environment.
        action_space: the action space of the environment.
        auxiliary_shape: data shape of auxiliary information (if exists).
        n_envs: number of parallel environments.
        buffer_size: the size of replay buffer that stores episodes of data.
        batch_size: batch size of transition data for a sample.
        episode_length: data length for an episode.
        lookup_length: the length of history data.
    """

    def __init__(self,
                 observation_space: Space,
                 action_space: Space,
                 auxiliary_shape: Optional[dict],
                 n_envs: int,
                 buffer_size: int,
                 batch_size: int,
                 episode_length: int,
                 lookup_length: int):
        super(RecurrentOffPolicyBuffer, self).__init__(observation_space, action_space, auxiliary_shape,
                                                       n_envs, buffer_size)
        self.episode_length, self.batch_size = episode_length, batch_size
        assert buffer_size % self.n_envs == 0, "buffer_size must be divisible by the number of envs (parallels)"
        self.n_size = self.buffer_size // self.n_envs
        self.lookup_length = lookup_length
        # The deque drops the oldest episode automatically once n_size episodes are stored.
        self.memory = deque(maxlen=self.n_size)

    @property
    def full(self):
        """Whether ``n_size`` episodes have been stored."""
        return self.size >= self.n_size

    def can_sample(self):
        """Whether at least ``batch_size`` episodes are stored."""
        return self.size >= self.batch_size

    def clear(self, *args):
        """Drop every stored episode and reset the write pointer and size."""
        self.memory = deque(maxlen=self.n_size)
        # Without this reset, can_sample() would still report True for an empty buffer.
        self.ptr, self.size = 0, 0

    def store(self, episode):
        """Append one finished episode."""
        self.memory.append(episode)
        self.ptr = (self.ptr + 1) % self.n_size
        self.size = min(self.size + 1, self.n_size)

    def sample(self):
        """
        Sample ``batch_size`` episodes uniformly with replacement and cut an equal-length window
        from each.

        The window length is ``lookup_length``, shortened to the shortest sampled episode (and to
        ``episode_length``) so that all windows can be stacked into one array.

        Returns:
            dict with ``obs``, ``actions``, ``rewards``, ``terminals`` arrays and ``batch_size``.

        Raises:
            ValueError: if no episode has been stored.
        """
        if not self.memory:
            raise ValueError("cannot sample from an empty RecurrentOffPolicyBuffer")
        # Index-based selection avoids converting the deque of episode objects to a numpy array.
        episodes = list(self.memory)
        picks = np.random.randint(0, len(episodes), self.batch_size)
        episode_choices = [episodes[k] for k in picks]
        length_min = min([self.episode_length] + [len(episode) for episode in episode_choices])
        window = min(length_min, self.lookup_length)
        obs_batch, act_batch, rew_batch, terminal_batch = [], [], [], []
        for episode in episode_choices:
            start_idx = np.random.randint(0, len(episode) - window + 1)
            sampled_data = episode.sample(lookup_step=window, idx=start_idx)
            obs_batch.append(sampled_data["obs"])
            act_batch.append(sampled_data["acts"])
            rew_batch.append(sampled_data["rews"])
            terminal_batch.append(sampled_data["done"])
        samples_dict = {
            'obs': np.array(obs_batch),
            'actions': np.array(act_batch),
            'rewards': np.array(rew_batch),
            'terminals': np.array(terminal_batch),
            'batch_size': self.batch_size,
        }
        return samples_dict
[docs]
class PerOffPolicyBuffer(Buffer):
    """
    Prioritized Replay Buffer.

    Keeps one sum segment tree and one min segment tree of ``priority ** alpha`` per environment and
    samples ``batch_size // n_envs`` transitions from each environment in proportion to them.

    Args:
        observation_space: the observation space of the environment.
        action_space: the action space of the environment.
        auxiliary_shape: data shape of auxiliary information (if exists).
        n_envs: number of parallel environments.
        buffer_size: the total size of the replay buffer.
        batch_size: batch size of transition data for a sample.
        alpha: prioritized factor.
    """

    def __init__(self,
                 observation_space: Space,
                 action_space: Space,
                 auxiliary_shape: Optional[dict],
                 n_envs: int,
                 buffer_size: int,
                 batch_size: int,
                 alpha: float = 0.6):
        super(PerOffPolicyBuffer, self).__init__(observation_space, action_space, auxiliary_shape, n_envs, buffer_size)
        self.batch_size = batch_size
        assert buffer_size % self.n_envs == 0, "buffer_size must be divisible by the number of envs (parallels)"
        self.n_size = buffer_size // self.n_envs
        self._alpha = alpha
        # Segment trees need a power-of-two capacity that covers n_size.
        it_capacity = 1
        while it_capacity < self.n_size:
            it_capacity *= 2
        self._it_capacity = it_capacity
        self.clear()

    def _sample_proportional(self, env_idx, batch_size):
        """
        Split the total priority mass of environment ``env_idx`` into ``batch_size`` equal segments
        and draw one index per segment, uniformly by mass within it.
        """
        res = []
        # NOTE(review): whether ``sum(0, self.size - 1)`` covers the last stored slot depends on the
        # end-exclusivity of SumSegmentTree.sum in segtree_tool -- confirm.
        p_total = self._it_sum[env_idx].sum(0, self.size - 1)
        every_range_len = p_total / batch_size
        for i in range(batch_size):
            mass = random.random() * every_range_len + i * every_range_len
            idx = self._it_sum[env_idx].find_prefixsum_idx(mass)
            res.append(int(idx))
        return res

    def clear(self):
        """Allocate empty storage, fresh segment trees, and reset pointer, size and max priorities."""
        self.ptr, self.size = 0, 0
        self.observations = create_memory(space2shape(self.observation_space), self.n_envs, self.n_size)
        self.next_observations = create_memory(space2shape(self.observation_space), self.n_envs, self.n_size)
        self.actions = create_memory(space2shape(self.action_space), self.n_envs, self.n_size)
        self.rewards = create_memory((), self.n_envs, self.n_size)
        self.terminals = create_memory((), self.n_envs, self.n_size)
        # Rebuild (not empty) the trees: store() indexes them per environment.
        self._it_sum = [SumSegmentTree(self._it_capacity) for _ in range(self.n_envs)]
        self._it_min = [MinSegmentTree(self._it_capacity) for _ in range(self.n_envs)]
        self._max_priority = np.ones(self.n_envs)

    def store(self, obs, acts, rews, terminals, next_obs):
        """Write one transition per environment and give it the current maximum priority."""
        store_element(obs, self.observations, self.ptr)
        store_element(acts, self.actions, self.ptr)
        store_element(rews, self.rewards, self.ptr)
        store_element(terminals, self.terminals, self.ptr)
        store_element(next_obs, self.next_observations, self.ptr)
        # prioritized process
        for i in range(self.n_envs):
            self._it_sum[i][self.ptr] = self._max_priority[i] ** self._alpha
            self._it_min[i][self.ptr] = self._max_priority[i] ** self._alpha
        self.ptr = (self.ptr + 1) % self.n_size
        self.size = min(self.size + 1, self.n_size)

    def sample(self, beta):
        """
        Sample ``batch_size // n_envs`` transitions from every environment by priority.

        Args:
            beta: importance-sampling exponent (must be > 0).

        Returns:
            dict with the sampled transitions, ``weights`` of shape ``(n_envs, batch_size // n_envs)``
            (importance-sampling weights divided by the largest possible one, so they lie in (0, 1]),
            ``step_choices`` (pass back to update_priorities) and ``batch_size``.
        """
        assert beta > 0
        per_env = self.batch_size // self.n_envs
        env_choices = np.arange(self.n_envs).repeat(per_env)
        step_choices = np.zeros((self.n_envs, per_env))
        weights = np.zeros((self.n_envs, per_env))
        for i in range(self.n_envs):
            idxes = self._sample_proportional(i, per_env)
            p_total = self._it_sum[i].sum()
            p_min = self._it_min[i].min() / p_total
            # w_j = (N * P(j)) ** -beta. The parentheses matter: ``**`` binds tighter than ``*``,
            # and ``p * N ** -beta`` would yield weights that grow with priority.
            max_weight = (p_min * self.size) ** (-beta)
            weights[i] = [((self._it_sum[i][idx] / p_total) * self.size) ** (-beta) / max_weight
                          for idx in idxes]
            step_choices[i] = idxes
        step_choices = step_choices.astype(np.int64)
        flat_steps = step_choices.flatten()
        index = tuple([env_choices, flat_steps])
        samples_dict = {
            'obs': sample_batch(self.observations, index),
            'actions': sample_batch(self.actions, index),
            'obs_next': sample_batch(self.next_observations, index),
            'rewards': sample_batch(self.rewards, index),
            'terminals': sample_batch(self.terminals, index),
            'weights': weights,
            'step_choices': step_choices,
            'batch_size': self.batch_size,
        }
        return samples_dict

    def update_priorities(self, idxes, priorities):
        """
        Args:
            idxes: per-environment step indices, shaped like ``step_choices`` returned by sample.
            priorities: new priorities, reshapeable to ``(n_envs, batch_size // n_envs)``.
        """
        priorities = priorities.reshape((self.n_envs, self.batch_size // self.n_envs))
        for i in range(self.n_envs):
            for idx, priority in zip(idxes[i], priorities[i]):
                if priority == 0:
                    # Avoid zero priorities: a zero minimum would make max_weight in sample() infinite.
                    priority += 1e-8
                assert 0 <= idx < self.size
                self._it_sum[i][idx] = priority ** self._alpha
                self._it_min[i][idx] = priority ** self._alpha
                self._max_priority[i] = max(self._max_priority[i], priority)
[docs]
class DummyOffPolicyBuffer_Atari(DummyOffPolicyBuffer):
    """
    Replay buffer for off-policy DRL algorithms and Atari tasks.

    Identical to ``DummyOffPolicyBuffer`` except that observations and next observations are
    stored as ``uint8``.

    Args:
        observation_space: the observation space of the environment.
        action_space: the action space of the environment.
        auxiliary_shape: data shape of auxiliary information (if exists).
        n_envs: number of parallel environments.
        buffer_size: the total size of the replay buffer.
        batch_size: batch size of transition data for a sample.
    """

    def __init__(self,
                 observation_space: Space,
                 action_space: Space,
                 auxiliary_shape: Optional[dict],
                 n_envs: int,
                 buffer_size: int,
                 batch_size: int):
        super(DummyOffPolicyBuffer_Atari, self).__init__(observation_space, action_space, auxiliary_shape,
                                                         n_envs, buffer_size, batch_size)

    def clear(self):
        """
        Allocate empty storage (observations as ``uint8``, 1 byte per element instead of float32) and
        reset the write pointer and size.
        """
        # Without this reset, ``sample`` would keep drawing steps from the zeroed arrays.
        self.ptr, self.size = 0, 0
        self.observations = create_memory(space2shape(self.observation_space), self.n_envs, self.n_size, np.uint8)
        self.next_observations = create_memory(space2shape(self.observation_space), self.n_envs, self.n_size, np.uint8)
        self.actions = create_memory(space2shape(self.action_space), self.n_envs, self.n_size)
        self.auxiliary_infos = create_memory(self.auxiliary_shape, self.n_envs, self.n_size)
        self.rewards = create_memory((), self.n_envs, self.n_size)
        self.terminals = create_memory((), self.n_envs, self.n_size)
[docs]
class SequentialReplayBuffer(Buffer):
    """
    Sequential Replay buffer for Dreamerv3

    Args:
        observation_space: the observation space of the environment.
        action_space: the action space of the environment.
        auxiliary_shape: data shape of auxiliary information (if exists).
        n_envs: number of parallel environments.
        buffer_size: the total size of the replay buffer.
        batch_size: size of transition data for a batch of sample.
    """

    def __init__(self,
                 observation_space: Space,
                 action_space: Space,
                 auxiliary_shape: Optional[dict],
                 n_envs: int,
                 buffer_size: int,
                 batch_size: int):
        # Buffer.__init__ requires num_envs and buffer_size as well.
        super(SequentialReplayBuffer, self).__init__(observation_space, action_space, auxiliary_shape,
                                                     n_envs, buffer_size)
        self.batch_size = batch_size
        assert buffer_size % self.n_envs == 0, "buffer_size must be divisible by the number of envs (parallels)"
        self.n_size = buffer_size // self.n_envs
        self.clear()

    def clear(self):
        """Allocate empty storage and reset the write pointer and size."""
        self.ptr, self.size = 0, 0
        self.obs = create_memory(space2shape(self.observation_space), self.n_envs, self.n_size)
        self.acts = create_memory(space2shape(self.action_space), self.n_envs, self.n_size)
        self.rews = create_memory((), self.n_envs, self.n_size)
        self.terms = create_memory((), self.n_envs, self.n_size)
        self.truncs = create_memory((), self.n_envs, self.n_size)
        self.is_first = create_memory((), self.n_envs, self.n_size)

    def store(self, obs, acts, rews, terms, truncs, is_first):
        """
        Args:
            all arguments are numpy arrays, shape: [envs, ~] if ~ != 1 else [envs, ]
        Returns:
        """
        store_element(obs, self.obs, self.ptr)
        store_element(acts, self.acts, self.ptr)
        store_element(rews, self.rews, self.ptr)
        store_element(terms, self.terms, self.ptr)
        store_element(truncs, self.truncs, self.ptr)
        store_element(is_first, self.is_first, self.ptr)
        self.ptr = (self.ptr + 1) % self.n_size
        self.size = min(self.size + 1, self.n_size)

    def sample(self, seq_len: int):
        """
        Sample elements from the replay buffer in a sequential manner, without considering the episode boundaries.
        Sequences never run across the write head, i.e. they never jump from the newest stored step
        to the oldest one.

        Args:
            seq_len (int): length of every sampled sequence; must not exceed the stored size.
        Returns:
            Dict[str, np.ndarray]: the sampled dictionary with a shape of
                [envs, sequence_length, batch_size, ...].
        """
        # make sure there are more than seq_len of data stored
        assert self.size - seq_len >= 0
        # Starts in [0, ptr - seq_len] keep the whole sequence at or before the newest step (ptr - 1).
        first_range_end = self.ptr - seq_len + 1
        indices = np.arange(0, first_range_end)
        if self.size == self.n_size:
            # When full, slot ``ptr`` holds the oldest step: sequences may also start in
            # [ptr, n_size) and wrap around, provided they end at or before ptr - 1, i.e. start before
            # n_size + first_range_end. Either way there are n_size - seq_len + 1 valid starts.
            second_range_end = self.n_size + first_range_end if first_range_end < 0 else self.n_size
            indices = np.concatenate([
                indices,
                np.arange(self.ptr, second_range_end)
            ])
        starts = np.random.choice(indices, size=(self.n_envs, self.batch_size))  # (envs, batch)
        offsets = np.arange(seq_len, dtype=np.intp)  # (seq,)
        idxes = (starts[:, None, :] + offsets[None, :, None]) % self.n_size  # (envs, seq, batch)
        envs = np.broadcast_to(np.arange(self.n_envs)[:, None, None], idxes.shape)  # (envs, seq, batch)
        envs, idxes = envs.ravel(), idxes.ravel()
        lead = (self.n_envs, seq_len, self.batch_size)
        samples_dict = {  # (envs, seq, batch, ~)
            'obs': self.obs[envs, idxes].reshape(*lead, *space2shape(self.observation_space)),
            'acts': self.acts[envs, idxes].reshape(*lead, *space2shape(self.action_space)),
            'rews': self.rews[envs, idxes].reshape(*lead, 1),
            'terms': self.terms[envs, idxes].reshape(*lead, 1),
            'truncs': self.truncs[envs, idxes].reshape(*lead, 1),
            'is_first': self.is_first[envs, idxes].reshape(*lead, 1),
        }
        return samples_dict