Add "is_first" and "is_terminal" for envs

This commit is contained in:
NM512
2023-04-29 07:34:27 +09:00
parent 3d0e2c8b5a
commit 12cccd8475
2 changed files with 11 additions and 4 deletions

View File

@@ -98,7 +98,7 @@ class Atari:
if not self._repeat:
self._buffer[1][:] = self._buffer[0][:]
self._screen(self._buffer[0])
self._done = over or (self._length and self._step >= self._length) or dead
self._done = over or (self._length and self._step >= self._length)
return self._obs(
total,
is_last=self._done or (dead and self._lives == "reset"),
@@ -137,7 +137,12 @@ class Atari:
weights = [0.299, 0.587, 1 - (0.299 + 0.587)]
image = np.tensordot(image, weights, (-1, 0)).astype(image.dtype)
image = image[:, :, None]
return {"image": image, "is_terminal": is_terminal}, reward, is_last, {}
return (
{"image": image, "is_terminal": is_terminal, "is_first": is_first},
reward,
is_last,
{},
)
def _screen(self, array):
self._ale.getScreenRGB2(array)