Add review feedback

This commit is contained in:
AdilZouitine
2025-05-16 14:25:21 +02:00
parent fa72aed5b6
commit 1df2a7b2da
4 changed files with 26 additions and 9 deletions

View File

@@ -271,9 +271,11 @@ class ReplayBuffer:
# Split the augmented images back to their sources
for i, key in enumerate(image_keys):
# State images are at even indices (0, 2, 4...)
# Calculate offsets for the current image key:
# For each key, we have 2*batch_size images (batch_size for states, batch_size for next_states)
# States start at index i*2*batch_size and take up batch_size slots
batch_state[key] = augmented_images[i * 2 * batch_size : (i * 2 + 1) * batch_size]
# Next state images are at odd indices (1, 3, 5...)
# Next states start after the states at index (i*2+1)*batch_size and also take up batch_size slots
batch_next_state[key] = augmented_images[(i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size]
# Sample other tensors

View File

@@ -111,7 +111,9 @@ def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes:
def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]:
buffer = io.BytesIO(buffer)
buffer.seek(0)
return torch.load(buffer, weights_only=False) # nosec B614: Safe usage of torch.load
return torch.load(buffer, weights_only=False) # nosec B614: Using weights_only=False relies on pickle which has security implications.
# This is currently safe as we only deserialize trusted internal data.
# TODO: Verify if weights_only=True would work for our use case (safer default in torch 2.6+)
def python_object_to_bytes(python_object: Any) -> bytes: