Clean the code

This commit is contained in:
AdilZouitine
2025-04-24 17:22:54 +02:00
parent b8c2b0bb93
commit a8da4a347e
4 changed files with 56 additions and 39 deletions

View File

@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import io
import pickle # nosec B403: Safe usage of pickle
@@ -194,6 +195,10 @@ class ReplayBuffer:
optimize_memory: bool = False,
):
"""
Replay buffer for storing transitions.
It will allocate tensors on the specified device, when the first transition is added.
NOTE: If you encounter memory issues, you can try to use the `optimize_memory` flag to save memory or
and use the `storage_device` flag to store the buffer on a different device.
Args:
capacity (int): Maximum number of transitions to store in the buffer.
device (str): The device where the tensors will be moved when sampling ("cuda:0" or "cpu").
@@ -368,7 +373,7 @@ class ReplayBuffer:
all_images.append(batch_state[key])
all_images.append(batch_next_state[key])
# Batch all images and apply augmentation once
# Optimization: Batch all images and apply augmentation once
all_images_tensor = torch.cat(all_images, dim=0)
augmented_images = self.image_augmentation_function(all_images_tensor)