diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index a443c1899..ab25a21fe 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -445,7 +445,10 @@ def test_add_frame_audio(audio_dataset): dataset.save_episode() assert dataset[0]["audio"].shape == torch.Size( - (int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE), DUMMY_AUDIO_CHANNELS) + ( + DUMMY_AUDIO_CHANNELS, + int(DEFAULT_AUDIO_CHUNK_DURATION * DEFAULT_SAMPLE_RATE), + ) # Match pytorch channel-first format )