0

I have created a colab (link: https://colab.research.google.com/drive/1gg57PS7KMLKvvx9wgDKLMDiyjhRICplp#scrollTo=zG2D7JO2OdEC) to play with the gpt2 fine-tuning.

I was trying to practice DDP in PyTorch on a GPU-based Google Colab kernel; the kernel information is as follows:

Number of GPUs available: 1 --- GPU 0 --- Device Name: Tesla T4 Total Memory: 14.74 GB Python 3 Google Compute Engine backend (GPU)

To simplify debugging, I created a very simple worker function, shown below:

# --- 2. Setup and Cleanup Functions for Distributed Training ---
def setup_distributed(rank, world_size, backend='nccl',
                      master_addr='localhost', master_port='12355'):
    """Initialize the torch.distributed process group for this rank.

    Args:
        rank: Integer rank of this process (0 .. world_size - 1).
        world_size: Total number of processes in the group.
        backend: Distributed backend; 'nccl' for GPU, 'gloo' for CPU-only runs.
        master_addr: Rendezvous host; 'localhost' suffices for single-node runs.
        master_port: Rendezvous port (string); must be unused on master_addr.
    """
    os.environ['MASTER_ADDR'] = master_addr
    os.environ['MASTER_PORT'] = master_port

    # NCCL requires each rank to own a distinct CUDA device; bind it before
    # init_process_group so collectives do not all land on GPU 0.
    if backend == 'nccl' and torch.cuda.is_available():
        torch.cuda.set_device(rank % torch.cuda.device_count())

    # Initialize the process group (blocks until all ranks have joined).
    print(f"Rank {rank}/{world_size}: Initializing process group with backend '{backend}'...")
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    print(f"Rank {rank}/{world_size}: Process group initialized.")

def cleanup_distributed():
    """Tear down the distributed process group if one is active.

    Safe to call unconditionally: when no group is initialized (or it was
    already destroyed), this only prints a notice instead of raising.
    """
    if dist.is_initialized():
        # Capture the rank before destroying the group; dist.get_rank()
        # is no longer valid once the group is gone.
        current_rank = dist.get_rank()
        dist.destroy_process_group()
        print(f"Rank {current_rank}: Cleaned up process group.")
    else:
        print("Process group was not initialized or already destroyed.")

def worker_function_1(rank, world_size):
    """Minimal DDP worker: bring the process group up, then tear it down.

    mp.spawn supplies `rank`; `world_size` arrives via its `args` tuple.
    Any failure is printed with a full traceback and re-raised so the
    parent process sees the child's error.
    """
    try:
        print('enter worker function')
        # Bring up the NCCL process group for this rank.
        setup_distributed(rank, world_size, backend='nccl')
        print(f"Rank {rank}/{world_size}: Process group initialized.")
        # Immediately release the group again — this worker only exercises
        # the init/teardown plumbing.
        cleanup_distributed()
        print(f"---- Worker rank {rank} finished. ----")
    except Exception as err:
        print(f"ERROR during worker_function: {err}")
        traceback.print_exc()
        raise err

Then I wrote the following main block to execute the worker function with DDP:

import traceback

# The __main__ guard is required so child processes created by mp.spawn do
# not re-execute this driver code when they import the main module.
if __name__ == '__main__':
    print("--- Starting DDP Demo for Google Colab ---")

    num_gpus = torch.cuda.device_count()
    print(f"Number of GPUs available in Colab: {num_gpus}")

    # Decide how many processes to launch and whether they target the GPU.
    use_gpu_for_ddp = True
    if num_gpus == 0:
        print("No GPUs detected by PyTorch. Running DDP on CPU.")
        use_gpu_for_ddp = False
        world_size = 2  # two CPU processes, just to exercise the DDP plumbing
    elif num_gpus == 1:
        print("1 GPU detected. Will run DDP with world_size=1 on this GPU.")
        print("(This tests DDP logic but won't provide speedup over non-DDP operation.)")
        world_size = 1
        # (To exercise multi-process logic on CPU despite having a GPU, set
        # world_size = 2 and use_gpu_for_ddp = False here instead.)
    else:
        # Multiple GPUs (rare on free Colab, possible on Pro): one process
        # per device.
        print(f"{num_gpus} GPUs detected. Running DDP with world_size={num_gpus}.")
        world_size = num_gpus

    print(f"Target world_size: {world_size}, Using GPU: {use_gpu_for_ddp}")
    print("Spawning DDP processes...")

    try:
        # 'spawn' gives every child a fresh interpreter — the safe choice
        # once CUDA may have been initialized in the parent.
        mp.set_start_method('spawn', force=True)
        # mp.spawn passes the rank (0 .. world_size-1) as the first argument
        # to worker_function_1; 'args' supplies the remaining positionals.
        # join=True blocks until every child has exited.
        mp.spawn(worker_function_1, args=(world_size,), nprocs=world_size, join=True)
    except Exception as err:
        print(f"ERROR during mp.spawn: {err}")
        traceback.print_exc()
        raise err

    print("--- DDP Demo Finished ---")

However, I got the following error:

--- Starting DDP Demo for Google Colab ---
Number of GPUs available in Colab: 1
1 GPU detected. Will run DDP with world_size=1 on this GPU.
(This tests DDP logic but won't provide speedup over non-DDP operation.)
Target world_size: 1, Using GPU: True
Spawning DDP processes...
ERROR during mp.spawn: process 0 terminated with exit code 1
Traceback (most recent call last):
  File "<ipython-input-74-b21bda05768f>", line 34, in <cell line: 0>
    mp.spawn(worker_function_1,
  File "/usr/local/lib/python3.11/dist-packages/torch/multiprocessing/spawn.py", line 340, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/multiprocessing/spawn.py", line 296, in start_processes
    while not context.join():
              ^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/multiprocessing/spawn.py", line 204, in join
    raise ProcessExitedException(
torch.multiprocessing.spawn.ProcessExitedException: process 0 terminated with exit code 1
---------------------------------------------------------------------------
ProcessExitedException                    Traceback (most recent call last)
<ipython-input-74-b21bda05768f> in <cell line: 0>()
     39         print(f"ERROR during mp.spawn: {e}")
     40         traceback.print_exc()
---> 41         raise e
     42 
     43 

3 frames
<ipython-input-74-b21bda05768f> in <cell line: 0>()
     32     try:
     33         mp.set_start_method('spawn', force=True) # Good practice for spawn
---> 34         mp.spawn(worker_function_1,
     35                  args=(world_size,), # Arguments for worker_function after rank
     36                  nprocs=world_size,        # Number of processes to spawn

/usr/local/lib/python3.11/dist-packages/torch/multiprocessing/spawn.py in spawn(fn, args, nprocs, join, daemon, start_method)
    338         )
    339         warnings.warn(msg, FutureWarning, stacklevel=2)
--> 340     return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")

/usr/local/lib/python3.11/dist-packages/torch/multiprocessing/spawn.py in start_processes(fn, args, nprocs, join, daemon, start_method)
    294 
    295     # Loop on join until it returns True or raises an exception.
--> 296     while not context.join():
    297         pass
    298 

/usr/local/lib/python3.11/dist-packages/torch/multiprocessing/spawn.py in join(self, timeout, grace_period)
    202                 )
    203             else:
--> 204                 raise ProcessExitedException(
    205                     "process %d terminated with exit code %d" % (error_index, exitcode),
    206                     error_index=error_index,

ProcessExitedException: process 0 terminated with exit code 1

Can someone help figure out the root cause?

0

Your Answer

By clicking “Post Your Answer”, you agree to our terms of service and acknowledge you have read our privacy policy.

Start asking to get answers

Find the answer to your question by asking.

Ask question

Explore related questions

See similar questions with these tags.