@@ -161,6 +161,8 @@ int main(int argc, char* argv[]) {
161161 MPI_CALL (MPI_Comm_rank (MPI_COMM_WORLD, &rank));
162162 int size;
163163 MPI_CALL (MPI_Comm_size (MPI_COMM_WORLD, &size));
164+ int num_devices = 0 ;
165+ CUDA_RT_CALL (cudaGetDeviceCount (&num_devices));
164166
165167 ncclUniqueId nccl_uid;
166168 if (rank == 0 ) NCCL_CALL (ncclGetUniqueId (&nccl_uid));
@@ -173,17 +175,29 @@ int main(int argc, char* argv[]) {
173175 const bool csv = get_arg (argv, argv + argc, " -csv" );
174176
175177 int local_rank = -1 ;
178+ int local_size = 1 ;
176179 {
177180 MPI_Comm local_comm;
178181 MPI_CALL (MPI_Comm_split_type (MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, rank, MPI_INFO_NULL,
179182 &local_comm));
180183
181184 MPI_CALL (MPI_Comm_rank (local_comm, &local_rank));
185+ MPI_CALL (MPI_Comm_size (local_comm, &local_size));
182186
183187 MPI_CALL (MPI_Comm_free (&local_comm));
184188 }
185-
186- CUDA_RT_CALL (cudaSetDevice (local_rank));
189+ if ( 1 < num_devices && num_devices < local_size )
190+ {
191+ fprintf (stderr," ERROR Number of visible devices (%d) is less than number of ranks on the node (%d)!\n " , num_devices, local_size);
192+ MPI_CALL (MPI_Finalize ());
193+ return 1 ;
194+ }
195+ if ( 1 == num_devices ) {
196+ // Only 1 device visbile assuming GPU affinity is handled via CUDA_VISIBLE_DEVICES
197+ CUDA_RT_CALL (cudaSetDevice (0 ));
198+ } else {
199+ CUDA_RT_CALL (cudaSetDevice (local_rank));
200+ }
187201 CUDA_RT_CALL (cudaFree (0 ));
188202
189203 ncclComm_t nccl_comm;
0 commit comments