mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-01 03:11:29 +00:00
Clean the code
This commit is contained in:
@@ -13,6 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
import io
|
||||
import pickle # nosec B403: Safe usage of pickle
|
||||
@@ -194,6 +195,10 @@ class ReplayBuffer:
|
||||
optimize_memory: bool = False,
|
||||
):
|
||||
"""
|
||||
Replay buffer for storing transitions.
|
||||
It will allocate tensors on the specified device, when the first transition is added.
|
||||
NOTE: If you encounter memory issues, you can try to use the `optimize_memory` flag to save memory or
|
||||
and use the `storage_device` flag to store the buffer on a different device.
|
||||
Args:
|
||||
capacity (int): Maximum number of transitions to store in the buffer.
|
||||
device (str): The device where the tensors will be moved when sampling ("cuda:0" or "cpu").
|
||||
@@ -368,7 +373,7 @@ class ReplayBuffer:
|
||||
all_images.append(batch_state[key])
|
||||
all_images.append(batch_next_state[key])
|
||||
|
||||
# Batch all images and apply augmentation once
|
||||
# Optimization: Batch all images and apply augmentation once
|
||||
all_images_tensor = torch.cat(all_images, dim=0)
|
||||
augmented_images = self.image_augmentation_function(all_images_tensor)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user