Discussion
HuggingFace accelerate's init_empty_weights() properly loads every text encoder I tested onto the PyTorch meta device, consuming no apparent memory or disk space while loaded.
However, it behaved differently with both HuggingFace diffusers models I tried (Flux and Stable Diffusion XL). They loaded to either the CPU or CUDA device, and the memory consumption was visible in the Windows 11 Performance Monitor.
Is init_empty_weights() implemented differently, incorrectly, or not at all for diffusers as compared with text encoders?
Code:
init_empty_weights() Works For Text Encoder
import torch
from accelerate import init_empty_weights
from transformers import T5EncoderModel

# Load the T5 text encoder under init_empty_weights()
with init_empty_weights():
    text_encoder_2 = T5EncoderModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="text_encoder_2",
        torch_dtype=torch.float32,
    )
text_encoder_2.device
Jupyter Notebook Response:
device(type='meta')
As expected, the model is loaded only to the meta device and Windows 11 Performance Monitor shows no additional RAM or VRAM usage.
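For reference, the meta-device behavior can be reproduced with plain torch, without accelerate: using torch.device("meta") as a context manager is a rough stand-in for init_empty_weights() (this sketch assumes PyTorch >= 2.0; the layer size is arbitrary):

```python
import torch
import torch.nn as nn

# Constructing a module under the meta device creates parameters with
# shapes and dtypes but no backing storage, so no RAM or VRAM is consumed.
with torch.device("meta"):
    layer = nn.Linear(4096, 4096)

print(layer.weight.device)  # meta
```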
init_empty_weights() Doesn't Seem to Work For Diffusers
init_empty_weights() Doesn't Seem to Work For Flux
import torch
from accelerate import init_empty_weights
from diffusers import FluxTransformer2DModel

# Load the Flux transformer under init_empty_weights()
with init_empty_weights():
    transformer = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
    )
transformer.device
Jupyter Notebook Response:
device(type='cpu')
Unexpectedly (to me), the model was loaded to the CPU device rather than meta, and Windows 11 Performance Monitor showed a corresponding increase in RAM usage.
init_empty_weights() Doesn't Seem to Work For SDXL
import torch
from accelerate import init_empty_weights
from diffusers import StableDiffusionXLPipeline

# Load the full SDXL pipeline under init_empty_weights()
with init_empty_weights():
    pipeline = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float16,
        variant="fp16",
        use_safetensors=True,
    )
pipeline.unet.device
Jupyter Notebook Response:
device(type='cpu')
Unexpectedly (to me), the model was loaded to the CPU device rather than meta, and Windows 11 Performance Monitor showed a corresponding increase in RAM usage.
Background
If it helps answerers: I want to initialize models with empty weights so I can pass them to HuggingFace accelerate's infer_auto_device_map(), which makes a best guess at which device each model layer should be loaded on. Loading the full model merely to obtain its layer shapes is slow. It is possible, though inconvenient, to load the full model once, obtain its inferred device map, write that device map out as text, restart the Python kernel, reconstruct the device map from that text, and finally use it when loading the model a second time. That is an awkward workaround.
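To illustrate why empty weights are sufficient for device-map planning, here is a sketch in plain torch (using torch.device("meta") as a stand-in for init_empty_weights(); the model and layer sizes are arbitrary): meta parameters still carry shape and dtype, which is all that is needed to compute per-layer memory requirements.

```python
import torch
import torch.nn as nn

# Meta-device parameters have shape and dtype but no data, which is all
# a device-map planner needs to compute per-layer memory requirements.
with torch.device("meta"):
    model = nn.Sequential(nn.Linear(4096, 4096), nn.Linear(4096, 4096))

# Bytes required per top-level submodule (the kind of accounting
# infer_auto_device_map performs when packing layers onto devices):
sizes = {
    name: sum(p.numel() * p.element_size() for p in module.parameters())
    for name, module in model.named_children()
}
print(sizes)  # {'0': 67125248, '1': 67125248}
```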