A100 produces very large checkpoints and trains slowly

System Info

transformers 4.36.0
Python 3.8.10
Driver Version: 535.104.12
CUDA Version: 12.2

I've run an audio classification script on an RTX 4090, an H100, and an A100.
On the 4090 and the H100, the code ran well: the checkpoints were a normal size and training was fast.

When I ran it on 4× A100, it produced very large checkpoints and training was very slow. I thought this was because my code has no parallel (multi-GPU) processing.
So I made a container with a single A100 GPU,
but it caused the same problems.

Here are the links to my code and data.

Can you help me get this running properly on the A100?

Data
https://drive.google.com/file/d/1tKNgHiy-b9_oL8hWG4vpDKePqAv8GHNc/view?usp=drive_link
Code
https://drive.google.com/file/d/1zU0UziwtI8SN7PJD35E7-Tg1NSKnYbSr/view?usp=drive_link