I have a class with methods to simulate sources across 16 detectors using the Gelsa package. In my main script, I call the method generate.sources. I am trying to use multiprocessing to speed up the simulation, ideally simulating one detector per CPU (I'm working on a cluster). I have allocated 16 CPUs on the cluster, but I notice no speed-up when simulating all 16 detectors in parallel compared to simulating a single detector. The functions I’m using for parallelization are poolcontext, process_detector_wrapper, and run_parallel. Can you help me understand why I’m not seeing any improvement? Thank you.
n=10
import numpy as np
from gelsa import Gelsa
from gelsa import galaxy
from astropy.io import fits
import os
from multiprocessing import Pool
from itertools import product
from contextlib import contextmanager
@contextmanager
def poolcontext(*args, **kwargs):
pool = Pool(*args, **kwargs)
yield pool
pool.terminate()
class Spectra_Generator:
"""
- "pointings" is the table where the pointings are described
- "data" is the Astropy Table where the data are, after adding the AXIS_RATIO column
- "continuum wavelength" is the array of wavelengths you want in the spectra
- "pt_n" is the pointing index in "pointings", not the ID. (18 for 30477)
- "fluxes" is an array containing the fluxes at each wavelength for each source. You find this with a linear fit in two NISP bands.
- "fit_coeffs" is the (16,2) matrix of coefficients for the SEMI-MAJOR AXIS vs FWHM relation
"""
def __init__(self, pointings, data, continuum_wavelength, pt_n, fluxes, fit_coeffs):
self.pointings = pointings
self.data = data
self.continuum_wavelength = continuum_wavelength
self.pt_n = pt_n
self.fluxes = fluxes
self.fit_coeffs = fit_coeffs
def find_fwhm(self, data_, det_n):
def fit_line(x, coeffs):
return coeffs[0]*x + coeffs[1]
fwhm_arcsec = fit_line(data_["SEMIMAJOR_AXIS"], coeffs=self.fit_coeffs[det_n]) * 0.1
return fwhm_arcsec
def new_frame(self, G):
new_frame = G.new_spec_frame(ra=self.pointings["Data.AdjustedPointing.RA"][self.pt_n],
dec=self.pointings["Data.AdjustedPointing.Dec"][self.pt_n],
pa=self.pointings["Data.AdjustedPointing.PositionAngle"][self.pt_n],
grism_name=self.pointings["Data.Grism"][self.pt_n],
tilt=self.pointings["Data.GrismWheelTilt"][self.pt_n],
clear=True
)
return new_frame
def process_detector(self, frame, det_n, temp_table, temp_flux, det1, det2):
galaxies = []
mask_detector = (det1 == det_n) | (det2 == det_n) # if no detector is specified then I will generate a full FOV (16 detectors)
print("Simulating {} sources in detector {}\n".format(np.sum(mask_detector), det_n))
temp_table_ondet = temp_table[mask_detector]
temp_flux_ondet = temp_flux[mask_detector]
fwhm_ondet = self.find_fwhm(temp_table_ondet[:n], det_n)
for k in range(len(temp_table_ondet[:n])):
gal = galaxy.Galaxy(
ra=temp_table_ondet[:n]["RIGHT_ASCENSION"][k],
dec=temp_table_ondet[:n]["DECLINATION"][k],
redshift=0,
fwhm_arcsec=fwhm_ondet[:n][k],
axis_ratio=temp_table_ondet[:n]["AXIS_RATIO"][k],
continuum_params=(15000, -1e-5, -17),
obs_wavelength_range=(12000, 19000)
)
continuum_flux = 10**(temp_flux_ondet[:n][k])
gal.set_sed(self.continuum_wavelength, continuum_flux)
galaxies.append(gal)
frame.add_sources(galaxies, noise=False)
return frame
def process_detector_wrapper(self, args):
"""
Multiprocessing object Pool wants a function with a single argument, hence we need this.
"""
self, G, det_id, temp_table, temp_flux, det1, det2 = args
frame = self.new_frame(G)
self.process_detector(frame, det_id, temp_table, temp_flux, det1, det2)
return frame
def run_parallel(self, G, temp_table, temp_flux, det1, det2, det_n=None):
if det_n is not None:
print("Currently simulating sources in detector ", det_n)
frame = self.new_frame(G)
self.process_detector(frame, det_n, temp_table, temp_flux, det1, det2)
return [frame]
else:
print("Currently simulating sources in the entire FOV")
d_range = range(16)
with poolcontext(processes=16) as p:
results = p.map(self.process_detector_wrapper, [(self, G, d, temp_table, temp_flux, det1, det2) for d in d_range])
return results
def generate_sources(self, det_n=None):
"""
Generate spectra for one specific detector (0–15) or the entire FOV (all 16).
"""
if det_n is not None and not (0 <= det_n < 16):
print("Detector number not valid\n")
return
# Initialize Gelsa
G = Gelsa(config_file="/scratch/astro/nicolo.fiaba/gelsa-spectra/calib/gelsa_config.json",
calibdir="/scratch/astro/nicolo.fiaba/gelsa-spectra/calib/",
zero_order_catalog=None
)
frame = self.new_frame(G)
# Select the galaxies within 1 degrees radius from the center of every pointing
galaxies = []
ra_c = self.pointings["Data.AdjustedPointing.RA"][self.pt_n]
dec_c = self.pointings["Data.AdjustedPointing.Dec"][self.pt_n]
selection_mask = np.sqrt((ra_c - self.data["RIGHT_ASCENSION"])**2 +
(dec_c - self.data["DECLINATION"])**2) < 1
print("{} sources in pointing number {} (1 degree from center)".format(sum(selection_mask), self.pt_n))
temp_table = self.data[selection_mask]
temp_flux = self.fluxes[selection_mask]
# Check which galaxies are falling outside the detectors
x1, y1, det1 = frame.radec_to_pixel(temp_table["RIGHT_ASCENSION"], temp_table["DECLINATION"], wavelength=12000)
x2, y2, det2 = frame.radec_to_pixel(temp_table["RIGHT_ASCENSION"], temp_table["DECLINATION"], wavelength=19000)
frames = self.run_parallel(G, temp_table, temp_flux, det1, det2, det_n=det_n)
return frames
def export_images_to_file(self, frames, pt_id, det_n=None):
output_path=f"/scratch/astro/nicolo.fiaba/simulated_images/simulated_images_{pt_id}.fits"
hdus = [fits.PrimaryHDU()]
hdus[0].header["NDET"] = 1 if det_n is not None else 16
if det_n is not None:
im, _, _ = frames[0].get_detector(det_n)
hdu_im = fits.ImageHDU(im.astype(np.float32), name=f"DET{det_n:02d}_IM")
hdus.append(hdu_im)
print(f"Added detector {det_n} to fits file")
else:
for i, f in enumerate(frames):
for d in f._data.keys():
if d == i:
im, _, _ = f.get_detector(d)
hdu_im = fits.ImageHDU(im.astype(np.float32), name=f"DET{d:02d}_IM")
hdus.append(hdu_im)
print(f"Added detector {d} to fits file")
fits.HDUList(hdus).writeto(output_path, overwrite=True)
print(f"\n Saved images to {output_path}")