"""Source code for ``xuance.common.memory_tools``: rollout and replay buffers used by DRL agents."""

import random
import numpy as np
from gymnasium import Space
from abc import ABC, abstractmethod
from typing import Optional, Union, Dict
from collections import deque
from xuance.common.common_tools import discount_cumsum
from xuance.common.segtree_tool import SumSegmentTree, MinSegmentTree
from xuance.environment.utils import space2shape


def create_memory(shape: Optional[Union[tuple, dict]],
                  n_envs: int,
                  n_size: int,
                  dtype: type = np.float32):
    """
    Allocate zero-initialised numpy storage for one kind of memory data.

    Args:
        shape: per-step data shape. Either a tuple, a dict mapping keys to per-step shapes
            (a ``None`` value allocates an object-typed array for that key), or ``None``.
        n_envs: number of parallel environments.
        n_size: number of steps stored for each environment.
        dtype: numpy dtype used for the numeric arrays.

    Returns:
        ``None`` when ``shape`` is ``None``; a dict of arrays when ``shape`` is a dict;
        otherwise a zero array of shape ``[n_envs, n_size, *shape]``.

    Raises:
        NotImplementedError: if ``shape`` is neither ``None``, a dict, nor a tuple.
    """
    if shape is None:
        return None
    if isinstance(shape, tuple):
        return np.zeros([n_envs, n_size, *shape], dtype)
    if not isinstance(shape, dict):
        raise NotImplementedError
    leading_dims = [n_envs, n_size]
    return {
        key: (np.zeros(leading_dims, dtype=object) if sub_shape is None
              else np.zeros(leading_dims + list(sub_shape), dtype=dtype))
        for key, sub_shape in shape.items()
    }
def store_element(data: Optional[Union[np.ndarray, dict, float]], memory: Union[dict, np.ndarray], ptr: int):
    """
    Write one step of data into ``memory`` at step position ``ptr``.

    Args:
        data: the step data to store; ``None`` means nothing is written. A dict writes each
            value into the memory entry with the same key.
        memory: target storage, written as ``memory[:, ptr]`` (or ``memory[key][:, ptr]`` for dicts).
        ptr: step index (second axis) to write to.
    """
    if data is None:
        return
    if not isinstance(data, dict):
        memory[:, ptr] = data
        return
    for key, value in data.items():
        memory[key][:, ptr] = value
def sample_batch(memory: Optional[Union[np.ndarray, dict]], index: Optional[Union[np.ndarray, tuple]]):
    """
    Gather the entries of ``memory`` selected by ``index``.

    Args:
        memory: an array, a dict of arrays, or ``None``.
        index: any numpy index (e.g. a tuple of ``(env_indices, step_indices)`` arrays).

    Returns:
        ``None`` if ``memory`` is ``None``; a dict with every array indexed by ``index`` if
        ``memory`` is a dict; otherwise ``memory[index]``.
    """
    if memory is None:
        return None
    if isinstance(memory, dict):
        return {key: stored[index] for key, stored in memory.items()}
    return memory[index]
class Buffer(ABC):
    """
    Basic buffer for single-agent DRL algorithms.

    Args:
        observation_space: the space for observation data.
        action_space: the space for action data.
        auxiliary_info_shape: the shape for auxiliary data if needed.
        num_envs: number of parallel environments.
        buffer_size: total number of steps held across all environments; must be divisible by ``num_envs``.
    """

    def __init__(self,
                 observation_space: Space,
                 action_space: Space,
                 auxiliary_info_shape: Optional[dict],
                 num_envs: int,
                 buffer_size: int):
        assert buffer_size % num_envs == 0, "buffer_size must be divisible by the number of envs (parallels)"
        self.observation_space = observation_space
        self.action_space = action_space
        self.auxiliary_shape = auxiliary_info_shape
        # Pre-define the data that might be stored in replay buffer for training.
        self.observations: Optional[np.ndarray] = None
        self.next_observations: Optional[np.ndarray] = None
        self.actions: Optional[np.ndarray] = None
        # Auxiliary info may be a single array or a dict of arrays (one per auxiliary key).
        self.auxiliary_infos: Optional[Union[np.ndarray, dict]] = None
        self.rewards: Optional[np.ndarray] = None
        self.terminals: Optional[np.ndarray] = None
        self.returns: Optional[np.ndarray] = None
        self.values: Optional[np.ndarray] = None
        self.n_envs = num_envs
        self.buffer_size = buffer_size
        self.n_size = self.buffer_size // self.n_envs  # number of steps stored per environment
        self.advantages: Optional[np.ndarray] = None
        self.ptr = 0  # last data pointer
        self.size = 0  # current buffer size per environment.

    @property
    def full(self) -> bool:
        """Whether every environment's slot already holds ``n_size`` steps."""
        return self.size >= self.n_size

    @abstractmethod
    def store(self, *args):
        """Store one step of data; implemented by subclasses."""
        raise NotImplementedError

    @abstractmethod
    def clear(self, *args):
        """Empty the buffer; implemented by subclasses."""
        raise NotImplementedError

    @abstractmethod
    def sample(self, *args):
        """Draw a batch of data; implemented by subclasses."""
        raise NotImplementedError

    def finish_path(self, *args):
        """Hook called at the end of a trajectory; a no-op by default."""
        pass
class EpisodeBuffer:
    """
    Episode buffer for DRQN agent.

    Transitions are appended one at a time with :meth:`put`; :meth:`sample` converts the stored
    lists to numpy arrays and returns a window of them.
    """

    def __init__(self):
        self.obs, self.action, self.reward, self.done = [], [], [], []

    def put(self, transition):
        """Append ``transition[0..3]`` to the observation, action, reward and done lists, in that order."""
        for target, position in ((self.obs, 0), (self.action, 1), (self.reward, 2), (self.done, 3)):
            target.append(transition[position])

    def sample(self, lookup_step=None, idx=None) -> Dict[str, np.ndarray]:
        """
        Return the window of ``lookup_step`` steps that starts at ``idx``.

        Observations include one extra entry (``lookup_step + 1`` items); actions, rewards and
        done flags contain ``lookup_step`` items each.
        """
        stop = idx + lookup_step
        return dict(obs=np.array(self.obs)[idx:stop + 1],
                    acts=np.array(self.action)[idx:stop],
                    rews=np.array(self.reward)[idx:stop],
                    done=np.array(self.done)[idx:stop])

    def __len__(self) -> int:
        """Number of stored actions (i.e. transitions)."""
        return len(self.action)
class DummyOnPolicyBuffer(Buffer):
    """
    Replay buffer for on-policy DRL algorithms.

    Args:
        observation_space: the observation space of the environment.
        action_space: the action space of the environment.
        auxiliary_shape: data shape of auxiliary information (if exists).
        n_envs: number of parallel environments.
        horizon_size: max length of steps to store for one environment.
        use_gae: if use GAE trick.
        use_advnorm: if use Advantage normalization trick.
        gamma: discount factor.
        gae_lam: gae lambda.
    """

    def __init__(self,
                 observation_space: Space,
                 action_space: Space,
                 auxiliary_shape: Optional[dict],
                 n_envs: int,
                 horizon_size: int,
                 use_gae: bool = True,
                 use_advnorm: bool = True,
                 gamma: float = 0.99,
                 gae_lam: float = 0.95):
        self.buffer_size = horizon_size * n_envs
        super(DummyOnPolicyBuffer, self).__init__(observation_space, action_space, auxiliary_shape, n_envs,
                                                  self.buffer_size)
        self.horizon_size = horizon_size
        self.n_size = self.horizon_size
        self.use_gae, self.use_advnorm = use_gae, use_advnorm
        self.gamma, self.gae_lam = gamma, gae_lam
        # Per environment: step index where the not-yet-finished path starts.
        self.start_ids = np.zeros(self.n_envs, np.int64)
        self.clear()

    @property
    def full(self):
        return self.size >= self.n_size

    def clear(self):
        """Reset pointers and path starts, and reallocate zeroed storage for every field."""
        self.ptr, self.size = 0, 0
        # Stale start indices would make the next finish_path skip the first steps of the new rollout.
        self.start_ids = np.zeros(self.n_envs, np.int64)
        self.observations = create_memory(space2shape(self.observation_space), self.n_envs, self.n_size)
        self.actions = create_memory(space2shape(self.action_space), self.n_envs, self.n_size)
        self.rewards = create_memory((), self.n_envs, self.n_size)
        self.returns = create_memory((), self.n_envs, self.n_size)
        self.values = create_memory((), self.n_envs, self.n_size)
        self.terminals = create_memory((), self.n_envs, self.n_size)
        self.advantages = create_memory((), self.n_envs, self.n_size)
        self.auxiliary_infos = create_memory(self.auxiliary_shape, self.n_envs, self.n_size)

    def store(self, obs, acts, rews, value, terminals, aux_info=None):
        """Write one step for all environments at ``self.ptr`` and advance the pointer (wrapping at n_size)."""
        store_element(obs, self.observations, self.ptr)
        store_element(acts, self.actions, self.ptr)
        store_element(rews, self.rewards, self.ptr)
        store_element(value, self.values, self.ptr)
        store_element(terminals, self.terminals, self.ptr)
        store_element(aux_info, self.auxiliary_infos, self.ptr)
        self.ptr = (self.ptr + 1) % self.n_size
        self.size = min(self.size + 1, self.n_size)

    def finish_path(self, val, i):
        """
        Compute returns and advantages for the unfinished path of environment ``i``.

        The path spans steps ``start_ids[i]`` up to (excluding) ``ptr``, or up to ``n_size`` when
        the buffer is full.

        Args:
            val: bootstrap value appended after the last step of the path.
            i: index of the environment whose path is finished.
        """
        path_end = self.n_size if self.full else self.ptr
        path_slice = np.arange(self.start_ids[i], path_end).astype(np.int32)
        vs = np.append(np.array(self.values[i, path_slice]), [val], axis=0)
        if self.use_gae:  # use gae
            rewards = np.array(self.rewards[i, path_slice])
            advantages = np.zeros_like(rewards)
            dones = np.array(self.terminals[i, path_slice])
            last_gae_lam = 0
            for t in reversed(range(len(path_slice))):
                # A terminal step masks out the bootstrap from the following step.
                delta = rewards[t] + (1 - dones[t]) * self.gamma * vs[t + 1] - vs[t]
                advantages[t] = last_gae_lam = delta + (1 - dones[t]) * self.gamma * self.gae_lam * last_gae_lam
            returns = advantages + vs[:-1]
        else:
            rewards = np.append(np.array(self.rewards[i, path_slice]), [val], axis=0)
            returns = discount_cumsum(rewards, self.gamma)[:-1]
            advantages = rewards[:-1] + self.gamma * vs[1:] - vs[:-1]
        self.returns[i, path_slice] = returns
        self.advantages[i, path_slice] = advantages
        self.start_ids[i] = self.ptr

    def sample(self, indexes):
        """
        Gather a mini-batch from a full buffer.

        Args:
            indexes: flat indices; index ``k`` maps to environment ``k // n_size`` and step ``k % n_size``.

        Returns:
            dict with observations, actions, returns, values, auxiliary info, batch size and
            advantages (normalized with batch statistics when ``use_advnorm`` is set).
        """
        assert self.full, "Not enough transitions for on-policy buffer to random sample"
        env_choices, step_choices = divmod(indexes, self.n_size)
        index = (env_choices, step_choices)
        samples_dict = {
            'obs': sample_batch(self.observations, index),
            'actions': sample_batch(self.actions, index),
            'returns': sample_batch(self.returns, index),
            'values': sample_batch(self.values, index),
            'aux_batch': sample_batch(self.auxiliary_infos, index),
            'batch_size': len(indexes),
        }
        adv_batch = sample_batch(self.advantages, index)
        if self.use_advnorm:
            adv_batch = (adv_batch - np.mean(adv_batch)) / (np.std(adv_batch) + 1e-8)
        samples_dict.update({'advantages': adv_batch})
        return samples_dict
class DummyOnPolicyBuffer_Atari(DummyOnPolicyBuffer):
    """
    Replay buffer for on-policy DRL algorithms and Atari tasks.

    Identical to ``DummyOnPolicyBuffer`` except that observations are stored as ``np.uint8``.

    Args:
        observation_space: the observation space of the environment.
        action_space: the action space of the environment.
        auxiliary_shape: data shape of auxiliary information (if exists).
        n_envs: number of parallel environments.
        horizon_size: max length of steps to store for one environment.
        use_gae: if use GAE trick.
        use_advnorm: if use Advantage normalization trick.
        gamma: discount factor.
        gae_lam: gae lambda.
    """

    def __init__(self,
                 observation_space: Space,
                 action_space: Space,
                 auxiliary_shape: Optional[dict],
                 n_envs: int,
                 horizon_size: int,
                 use_gae: bool = True,
                 use_advnorm: bool = True,
                 gamma: float = 0.99,
                 gae_lam: float = 0.95):
        super(DummyOnPolicyBuffer_Atari, self).__init__(observation_space, action_space, auxiliary_shape, n_envs,
                                                        horizon_size, use_gae, use_advnorm, gamma, gae_lam)

    def clear(self):
        """Reset pointers and path starts, and reallocate zeroed storage (uint8 observations)."""
        self.ptr, self.size = 0, 0
        # Stale start indices would make the next finish_path skip the first steps of the new rollout.
        self.start_ids = np.zeros(self.n_envs, np.int64)
        self.observations = create_memory(space2shape(self.observation_space), self.n_envs, self.n_size, np.uint8)
        self.actions = create_memory(space2shape(self.action_space), self.n_envs, self.n_size)
        self.rewards = create_memory((), self.n_envs, self.n_size)
        self.returns = create_memory((), self.n_envs, self.n_size)
        self.values = create_memory((), self.n_envs, self.n_size)
        self.terminals = create_memory((), self.n_envs, self.n_size)
        self.advantages = create_memory((), self.n_envs, self.n_size)
        self.auxiliary_infos = create_memory(self.auxiliary_shape, self.n_envs, self.n_size)
class DummyOffPolicyBuffer(Buffer):
    """
    Replay buffer for off-policy DRL algorithms.

    Args:
        observation_space: the observation space of the environment.
        action_space: the action space of the environment.
        auxiliary_shape: data shape of auxiliary information (if exists).
        n_envs: number of parallel environments.
        buffer_size: the total size of the replay buffer.
        batch_size: size of transition data for a batch of sample.
    """

    def __init__(self,
                 observation_space: Space,
                 action_space: Space,
                 auxiliary_shape: Optional[dict],
                 n_envs: int,
                 buffer_size: int,
                 batch_size: int):
        super(DummyOffPolicyBuffer, self).__init__(observation_space, action_space, auxiliary_shape, n_envs,
                                                   buffer_size)
        self.batch_size = batch_size
        assert buffer_size % self.n_envs == 0, "buffer_size must be divisible by the number of envs (parallels)"
        self.n_size = buffer_size // self.n_envs
        self.clear()

    def clear(self):
        """Reallocate zeroed storage and reset the write pointer and stored size."""
        self.observations = create_memory(space2shape(self.observation_space), self.n_envs, self.n_size)
        self.next_observations = create_memory(space2shape(self.observation_space), self.n_envs, self.n_size)
        self.actions = create_memory(space2shape(self.action_space), self.n_envs, self.n_size)
        self.auxiliary_infos = create_memory(self.auxiliary_shape, self.n_envs, self.n_size)
        self.rewards = create_memory((), self.n_envs, self.n_size)
        self.terminals = create_memory((), self.n_envs, self.n_size)
        # Without this, sample() would keep drawing from the freshly zeroed storage.
        self.ptr, self.size = 0, 0

    def store(self, obs, acts, rews, terminals, next_obs):
        """Write one transition per environment at ``self.ptr`` and advance the pointer (wrapping at n_size)."""
        store_element(obs, self.observations, self.ptr)
        store_element(acts, self.actions, self.ptr)
        store_element(rews, self.rewards, self.ptr)
        store_element(terminals, self.terminals, self.ptr)
        store_element(next_obs, self.next_observations, self.ptr)
        self.ptr = (self.ptr + 1) % self.n_size
        self.size = min(self.size + 1, self.n_size)

    def sample(self, batch_size=None):
        """
        Draw ``batch_size`` transitions uniformly (with replacement) over environments and stored steps.

        Args:
            batch_size: number of transitions; defaults to ``self.batch_size``.

        Returns:
            dict of observations, actions, next observations, rewards, terminals and the batch size.
        """
        bs = self.batch_size if batch_size is None else batch_size
        env_choices = np.random.choice(self.n_envs, bs)
        step_choices = np.random.choice(self.size, bs)
        index = (env_choices, step_choices)
        samples_dict = {
            'obs': sample_batch(self.observations, index),
            'actions': sample_batch(self.actions, index),
            'obs_next': sample_batch(self.next_observations, index),
            'rewards': sample_batch(self.rewards, index),
            'terminals': sample_batch(self.terminals, index),
            'batch_size': bs,
        }
        return samples_dict
class RecurrentOffPolicyBuffer(Buffer):
    """
    Replay buffer for DRQN-based algorithms.

    Args:
        observation_space: the observation space of the environment.
        action_space: the action space of the environment.
        auxiliary_shape: data shape of auxiliary information (if exists).
        n_envs: number of parallel environments.
        buffer_size: the size of replay buffer that stores episodes of data.
        batch_size: batch size of transition data for a sample.
        episode_length: data length for an episode.
        lookup_length: the length of history data.
    """

    def __init__(self,
                 observation_space: Space,
                 action_space: Space,
                 auxiliary_shape: Optional[dict],
                 n_envs: int,
                 buffer_size: int,
                 batch_size: int,
                 episode_length: int,
                 lookup_length: int):
        super(RecurrentOffPolicyBuffer, self).__init__(observation_space, action_space, auxiliary_shape, n_envs,
                                                       buffer_size)
        self.episode_length, self.batch_size = episode_length, batch_size
        assert buffer_size % self.n_envs == 0, "buffer_size must be divisible by the number of envs (parallels)"
        self.n_size = self.buffer_size // self.n_envs
        self.lookup_length = lookup_length
        # Oldest episodes are dropped automatically once n_size episodes are stored.
        self.memory = deque(maxlen=self.n_size)

    @property
    def full(self):
        return self.size >= self.n_size

    def can_sample(self):
        """Whether at least ``batch_size`` episodes have been stored."""
        return self.size >= self.batch_size

    def clear(self, *args):
        """Drop all stored episodes and reset the counters."""
        self.memory = deque(maxlen=self.n_size)
        # Keep size consistent with the empty deque so can_sample() does not report stale data.
        self.ptr, self.size = 0, 0

    def store(self, episode):
        """Append one episode object and update the counters."""
        self.memory.append(episode)
        self.ptr = (self.ptr + 1) % self.n_size
        self.size = min(self.size + 1, self.n_size)

    def sample(self):
        """
        Sample ``batch_size`` episodes (with replacement) and cut an equal-length window from each.

        The window length is ``lookup_length``, shortened to the shortest sampled episode
        (and to ``episode_length``) when needed.

        Returns:
            dict of stacked observations, actions, rewards, terminals and the batch size.
        """
        # Choose by position rather than passing the deque to numpy, which would coerce every stored
        # episode into an array first; the random draw is the same.
        chosen = np.random.choice(len(self.memory), self.batch_size)
        episode_choices = [self.memory[k] for k in chosen]
        length_min = min([self.episode_length] + [len(episode) for episode in episode_choices])
        seq_len = min(self.lookup_length, length_min)

        obs_batch, act_batch, rew_batch, terminal_batch = [], [], [], []
        for episode in episode_choices:
            start_idx = np.random.randint(0, len(episode) - seq_len + 1)
            sampled_data = episode.sample(lookup_step=seq_len, idx=start_idx)
            obs_batch.append(sampled_data["obs"])
            act_batch.append(sampled_data["acts"])
            rew_batch.append(sampled_data["rews"])
            terminal_batch.append(sampled_data["done"])

        samples_dict = {
            'obs': np.array(obs_batch),
            'actions': np.array(act_batch),
            'rewards': np.array(rew_batch),
            'terminals': np.array(terminal_batch),
            'batch_size': self.batch_size,
        }
        return samples_dict
class PerOffPolicyBuffer(Buffer):
    """
    Prioritized Replay Buffer.

    Each environment keeps its own sum/min segment trees over step priorities.

    Args:
        observation_space: the observation space of the environment.
        action_space: the action space of the environment.
        auxiliary_shape: data shape of auxiliary information (if exists).
        n_envs: number of parallel environments.
        buffer_size: the total size of the replay buffer.
        batch_size: batch size of transition data for a sample.
        alpha: prioritized factor.
    """

    def __init__(self,
                 observation_space: Space,
                 action_space: Space,
                 auxiliary_shape: Optional[dict],
                 n_envs: int,
                 buffer_size: int,
                 batch_size: int,
                 alpha: float = 0.6):
        super(PerOffPolicyBuffer, self).__init__(observation_space, action_space, auxiliary_shape, n_envs,
                                                 buffer_size)
        self.batch_size = batch_size
        assert buffer_size % self.n_envs == 0, "buffer_size must be divisible by the number of envs (parallels)"
        self.n_size = buffer_size // self.n_envs
        self._alpha = alpha
        # Segment tree capacity: the smallest power of two >= n_size.
        it_capacity = 1
        while it_capacity < self.n_size:
            it_capacity *= 2
        self._it_capacity = it_capacity
        self.clear()

    def _sample_proportional(self, env_idx, batch_size):
        """Stratified proportional sampling: draw one step index from each of ``batch_size`` equal priority-mass segments."""
        res = []
        # NOTE(review): whether SumSegmentTree.sum's end argument is inclusive is defined in segtree_tool; confirm
        # that (0, self.size - 1) covers every stored step.
        p_total = self._it_sum[env_idx].sum(0, self.size - 1)
        every_range_len = p_total / batch_size
        for i in range(batch_size):
            mass = random.random() * every_range_len + i * every_range_len
            idx = self._it_sum[env_idx].find_prefixsum_idx(mass)
            res.append(int(idx))
        return res

    def clear(self):
        """Reallocate storage, rebuild empty priority trees, and reset pointers and max priorities."""
        self.observations = create_memory(space2shape(self.observation_space), self.n_envs, self.n_size)
        self.next_observations = create_memory(space2shape(self.observation_space), self.n_envs, self.n_size)
        self.actions = create_memory(space2shape(self.action_space), self.n_envs, self.n_size)
        self.rewards = create_memory((), self.n_envs, self.n_size)
        self.terminals = create_memory((), self.n_envs, self.n_size)
        # The trees must be rebuilt (not emptied): store() writes into self._it_sum[i] for every env.
        self._it_sum = [SumSegmentTree(self._it_capacity) for _ in range(self.n_envs)]
        self._it_min = [MinSegmentTree(self._it_capacity) for _ in range(self.n_envs)]
        self._max_priority = np.ones(self.n_envs)
        self.ptr, self.size = 0, 0

    def store(self, obs, acts, rews, terminals, next_obs):
        """Write one transition per environment and give it the current maximum priority of that environment."""
        store_element(obs, self.observations, self.ptr)
        store_element(acts, self.actions, self.ptr)
        store_element(rews, self.rewards, self.ptr)
        store_element(terminals, self.terminals, self.ptr)
        store_element(next_obs, self.next_observations, self.ptr)
        # prioritized process
        for i in range(self.n_envs):
            self._it_sum[i][self.ptr] = self._max_priority[i] ** self._alpha
            self._it_min[i][self.ptr] = self._max_priority[i] ** self._alpha
        self.ptr = (self.ptr + 1) % self.n_size
        self.size = min(self.size + 1, self.n_size)

    def sample(self, beta):
        """
        Sample ``batch_size // n_envs`` transitions per environment proportionally to priority.

        Args:
            beta: importance-sampling exponent (> 0).

        Returns:
            dict with the batch, importance weights of shape ``(n_envs, batch_size // n_envs)``
            normalized by their maximum, and the sampled ``step_choices`` for ``update_priorities``.
        """
        assert beta > 0
        per_env = int(self.batch_size / self.n_envs)
        env_choices = np.arange(self.n_envs).repeat(per_env)
        step_choices = np.zeros((self.n_envs, per_env))
        weights = np.zeros((self.n_envs, per_env))
        for i in range(self.n_envs):
            idxes = self._sample_proportional(i, per_env)
            p_total = self._it_sum[i].sum()
            # Importance weight w = (N * P(i)) ** -beta, normalized by the largest weight (lowest priority).
            p_min = self._it_min[i].min() / p_total
            max_weight = (p_min * self.size) ** (-beta)
            weights[i] = [((self._it_sum[i][idx] / p_total) * self.size) ** (-beta) / max_weight for idx in idxes]
            step_choices[i] = idxes
        step_choices = step_choices.astype(np.int64)
        index = (env_choices, step_choices.flatten())
        samples_dict = {
            'obs': sample_batch(self.observations, index),
            'actions': sample_batch(self.actions, index),
            'obs_next': sample_batch(self.next_observations, index),
            'rewards': sample_batch(self.rewards, index),
            'terminals': sample_batch(self.terminals, index),
            'weights': weights,
            'step_choices': step_choices,
            'batch_size': self.batch_size,
        }
        return samples_dict

    def update_priorities(self, idxes, priorities):
        """
        Set new priorities for previously sampled steps.

        Args:
            idxes: per-environment step indices, as returned in ``step_choices`` by ``sample``.
            priorities: new priorities, reshaped to ``(n_envs, batch_size // n_envs)``.
        """
        priorities = priorities.reshape((self.n_envs, int(self.batch_size / self.n_envs)))
        for i in range(self.n_envs):
            for idx, priority in zip(idxes[i], priorities[i]):
                if priority == 0:
                    priority += 1e-8  # keep zero-error transitions sampleable
                assert 0 <= idx < self.size
                self._it_sum[i][idx] = priority ** self._alpha
                self._it_min[i][idx] = priority ** self._alpha
                self._max_priority[i] = max(self._max_priority[i], priority)
class DummyOffPolicyBuffer_Atari(DummyOffPolicyBuffer):
    """
    Replay buffer for off-policy DRL algorithms and Atari tasks.

    Identical to ``DummyOffPolicyBuffer`` except that observations are stored as ``np.uint8``.

    Args:
        observation_space: the observation space of the environment.
        action_space: the action space of the environment.
        auxiliary_shape: data shape of auxiliary information (if exists).
        n_envs: number of parallel environments.
        buffer_size: the total size of the replay buffer.
        batch_size: batch size of transition data for a sample.
    """

    def __init__(self,
                 observation_space: Space,
                 action_space: Space,
                 auxiliary_shape: Optional[dict],
                 n_envs: int,
                 buffer_size: int,
                 batch_size: int):
        super(DummyOffPolicyBuffer_Atari, self).__init__(observation_space, action_space, auxiliary_shape, n_envs,
                                                         buffer_size, batch_size)

    def clear(self):
        """Reallocate zeroed storage (uint8 observations) and reset the write pointer and stored size."""
        self.observations = create_memory(space2shape(self.observation_space), self.n_envs, self.n_size, np.uint8)
        self.next_observations = create_memory(space2shape(self.observation_space), self.n_envs, self.n_size,
                                               np.uint8)
        self.actions = create_memory(space2shape(self.action_space), self.n_envs, self.n_size)
        self.auxiliary_infos = create_memory(self.auxiliary_shape, self.n_envs, self.n_size)
        self.rewards = create_memory((), self.n_envs, self.n_size)
        self.terminals = create_memory((), self.n_envs, self.n_size)
        # Without this, sample() would keep drawing from the freshly zeroed storage.
        self.ptr, self.size = 0, 0
class SequentialReplayBuffer(Buffer):
    """
    Sequential Replay buffer for Dreamerv3

    Args:
        observation_space: the observation space of the environment.
        action_space: the action space of the environment.
        auxiliary_shape: data shape of auxiliary information (if exists).
        n_envs: number of parallel environments.
        buffer_size: the total size of the replay buffer.
        batch_size: size of transition data for a batch of sample.
    """

    def __init__(self,
                 observation_space: Space,
                 action_space: Space,
                 auxiliary_shape: Optional[dict],
                 n_envs: int,
                 buffer_size: int,
                 batch_size: int):
        # Buffer.__init__ needs num_envs and buffer_size; omitting them raised TypeError on construction.
        super(SequentialReplayBuffer, self).__init__(observation_space, action_space, auxiliary_shape, n_envs,
                                                     buffer_size)
        self.batch_size = batch_size
        assert buffer_size % self.n_envs == 0, "buffer_size must be divisible by the number of envs (parallels)"
        self.n_size = buffer_size // self.n_envs
        self.clear()

    def clear(self):
        """Reallocate zeroed storage and reset the write pointer and stored size."""
        self.obs = create_memory(space2shape(self.observation_space), self.n_envs, self.n_size)
        self.acts = create_memory(space2shape(self.action_space), self.n_envs, self.n_size)
        self.rews = create_memory((), self.n_envs, self.n_size)
        self.terms = create_memory((), self.n_envs, self.n_size)
        self.truncs = create_memory((), self.n_envs, self.n_size)
        self.is_first = create_memory((), self.n_envs, self.n_size)
        self.ptr, self.size = 0, 0

    def store(self, obs, acts, rews, terms, truncs, is_first):
        """
        Args:
            all arguments are numpy arrays, shape: [envs, ~] if ~ != 1 else [envs, ]
        Returns:
        """
        store_element(obs, self.obs, self.ptr)
        store_element(acts, self.acts, self.ptr)
        store_element(rews, self.rews, self.ptr)
        store_element(terms, self.terms, self.ptr)
        store_element(truncs, self.truncs, self.ptr)
        store_element(is_first, self.is_first, self.ptr)
        self.ptr = (self.ptr + 1) % self.n_size
        self.size = min(self.size + 1, self.n_size)

    def sample(self, seq_len: int):
        """
        Sample elements from the replay buffer in a sequential manner, without considering the episode boundaries.

        A sequence never crosses the write boundary (from the newest step at ``ptr - 1`` to the
        oldest step at ``ptr``).

        Args:
            seq_len (int)
        Returns:
            Dict[str, np.ndarray]: the sampled dictionary with a shape of [envs, sequence_length, batch_size, ...].
        """
        # make sure there are more than seq_len of data stored
        assert self.size - seq_len >= 0
        # Starts in [0, first_range_end) end at or before the newest step (ptr - 1) without wrapping.
        first_range_end = self.ptr - seq_len + 1
        indices = np.arange(0, first_range_end)
        if self.size == self.n_size:
            # Once wrapped, a start s in [ptr, n_size) is valid iff its sequence ends at or before ptr - 1
            # after wrapping: s + seq_len - 1 <= n_size + ptr - 1, i.e. s < n_size + first_range_end.
            second_range_end = self.n_size if first_range_end >= 0 else self.n_size + first_range_end
            indices = np.concatenate([indices, np.arange(self.ptr, second_range_end)])
        li = []
        for _ in range(self.n_envs):
            start = np.random.choice(indices, size=self.batch_size).reshape(-1, 1)  # (batch, 1)
            seq_arange = np.arange(seq_len, dtype=np.intp).reshape(1, -1)  # (1, seq)
            idxes = (start + seq_arange) % self.n_size  # (batch, seq)
            li.append(np.swapaxes(idxes, 0, 1))  # (seq, batch)
        idxes = np.stack(li)  # (envs, seq, batch)
        envs = np.broadcast_to(np.arange(self.n_envs)[:, None, None], idxes.shape)  # (env, seq, batch)
        envs, idxes = envs.ravel(), idxes.ravel()
        lead = (self.n_envs, seq_len, self.batch_size)
        samples_dict = {  # (envs, seq, batch, ~)
            'obs': self.obs[envs, idxes].reshape(*lead, *space2shape(self.observation_space)),
            'acts': self.acts[envs, idxes].reshape(*lead, *space2shape(self.action_space)),
            'rews': self.rews[envs, idxes].reshape(*lead, 1),
            'terms': self.terms[envs, idxes].reshape(*lead, 1),
            'truncs': self.truncs[envs, idxes].reshape(*lead, 1),
            'is_first': self.is_first[envs, idxes].reshape(*lead, 1),
        }
        return samples_dict