Add review feedback

This commit is contained in:
AdilZouitine
2025-05-16 14:25:21 +02:00
parent fa72aed5b6
commit 1df2a7b2da
4 changed files with 26 additions and 9 deletions

View File

@@ -271,9 +271,11 @@ class ReplayBuffer:
# Split the augmented images back to their sources
for i, key in enumerate(image_keys):
# State images are at even indices (0, 2, 4...)
# Calculate offsets for the current image key:
# For each key, we have 2*batch_size images (batch_size for states, batch_size for next_states)
# States start at index i*2*batch_size and take up batch_size slots
batch_state[key] = augmented_images[i * 2 * batch_size : (i * 2 + 1) * batch_size]
# Next state images are at odd indices (1, 3, 5...)
# Next states start after the states at index (i*2+1)*batch_size and also take up batch_size slots
batch_next_state[key] = augmented_images[(i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size]
# Sample other tensors