File size: 16,731 Bytes
5d99e98 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 |
# Copyright (c) 2025 Hansheng Chen
import os
from typing import Union, Optional
import torch
import accelerate
import diffusers
from diffusers.models import AutoModel
from diffusers.models.modeling_utils import (
load_state_dict,
_LOW_CPU_MEM_USAGE_DEFAULT,
no_init_weights,
ContextManagers
)
from diffusers.utils import (
SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_NAME,
_add_variant,
_get_model_file,
is_accelerate_available,
is_torch_version,
logging,
)
from diffusers.loaders.peft import _SET_ADAPTER_SCALE_FN_MAPPING
from diffusers.quantizers import DiffusersAutoQuantizer
from diffusers.utils.torch_utils import empty_device_cache
from lakonlab.models.architecture.gmflow.gmflux2 import _GMFlux2Transformer2DModel
# Project-local model classes, keyed by the `_class_name` value found in a model config.
LOCAL_CLASS_MAPPING = dict(GMFlux2Transformer2DModel=_GMFlux2Transformer2DModel)

# Register an adapter-scale function for the local transformer class; it hands the
# given weights back unchanged.
_SET_ADAPTER_SCALE_FN_MAPPING["_GMFlux2Transformer2DModel"] = lambda model_cls, weights: weights

logger = logging.get_logger(__name__)
def assign_param(module, tensor_name: str, param: torch.nn.Parameter):
    """Attach ``param`` to ``module`` under the (possibly dotted) name ``tensor_name``.

    The dotted path is walked attribute by attribute to reach the owning submodule, and
    the tensor is then stored directly in that submodule's registry without any copy,
    dtype cast or device move.

    Names that are already registered as buffers on the owning submodule are written to
    ``_buffers``; every other name is written to ``_parameters``. This keeps buffers
    (which also appear in ``state_dict()``) from being duplicated into ``_parameters``
    while the original buffer entry is left untouched.

    Args:
        module: Root ``torch.nn.Module`` to resolve ``tensor_name`` against.
        tensor_name: Name of the parameter or buffer, e.g. ``"blocks.0.attn.weight"``.
        param: Tensor (or ``torch.nn.Parameter``) to install.

    Raises:
        AttributeError: If an intermediate attribute on the path does not exist.
        ValueError: If an intermediate attribute on the path is ``None``.
    """
    if "." in tensor_name:
        *parent_names, tensor_name = tensor_name.split(".")
        for name in parent_names:
            child = getattr(module, name)
            if child is None:
                raise ValueError(f"{module} has no attribute {name}.")
            module = child

    if tensor_name in module._buffers:
        module._buffers[tensor_name] = param
    else:
        module._parameters[tensor_name] = param
class PiFlowMixin:
    """Mixin that adds PiFlow adapter loading and policy-based rollout to a pipeline.

    The host object must expose the target module (``transformer`` by default) as an
    attribute for :meth:`load_piflow_adapter`, and a ``scheduler`` attribute providing
    ``warp_t`` / ``unwarp_t`` for :meth:`policy_rollout`.
    """

    def load_piflow_adapter(
        self,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        target_module_name: str = "transformer",
        adapter_name: Optional[str] = None,
        **kwargs
    ):
        r"""
        Load a PiFlow adapter from a pretrained model repository into the target module.

        A new module is built from the adapter's config, filled with the weights of the current target module
        (with non-LoRA adapter weights overwriting matching entries), optionally extended with the adapter's LoRA
        weights, and finally assigned back to `self` under `target_module_name`.

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                Can be either:

                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model
                      hosted on the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights
                      saved with [`~ModelMixin.save_pretrained`].
            target_module_name (`str`, *optional*, defaults to `"transformer"`):
                The module name in the model to load the PiFlow adapter into.
            adapter_name (`str`, *optional*):
                The name to assign to the loaded adapter. If not provided, it defaults to
                `"{target_module_name}_piflow"`.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard
                cache is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding
                the cached versions if they exist.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the
                model won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any
                identifier allowed by Git.
            subfolder (`str`, *optional*, defaults to `""`):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
                Speed up model loading by only loading the pretrained weights and not initializing the weights.
                This also tries to not use more than 1x model size in CPU memory (including peak memory) while
                loading the model. Only supported for PyTorch >= 1.9.0. If you are using an older version of
                PyTorch, setting this argument to `True` will raise an error.
            variant (`str`, *optional*):
                Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`.
            use_safetensors (`bool`, *optional*, defaults to `None`):
                If set to `None`, the `safetensors` weights are tried first and pickled weights are used as a
                fallback. If set to `True`, the model is forcibly loaded from `safetensors` weights. If set to
                `False`, `safetensors` weights are not loaded.
            disable_mmap (`bool`, *optional*, defaults to `False`):
                Whether to disable mmap when loading a Safetensors model. This option can perform better when the
                model is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.

        Returns:
            `str` or `None`: The name assigned to the loaded adapter, or `None` if no LoRA weights were found.

        Raises:
            NotImplementedError: If `low_cpu_mem_usage=True` is requested with torch < 1.9.0.
            ValueError: If no model class matches the config, or if `low_cpu_mem_usage` is disabled while
                quantization or `_keep_in_fp32_modules` requires it.
            OSError: If the adapter weight file cannot be located.
        """
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        token = kwargs.pop("token", None)
        local_files_only = kwargs.pop("local_files_only", False)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
        variant = kwargs.pop("variant", None)
        use_safetensors = kwargs.pop("use_safetensors", None)
        disable_mmap = kwargs.pop("disable_mmap", False)

        # Falling back to pickled weights is only allowed when the caller did not
        # state a preference for safetensors.
        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = True
            allow_pickle = True

        if low_cpu_mem_usage and not is_accelerate_available():
            low_cpu_mem_usage = False
            logger.warning(
                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
                " install accelerate\n```\n."
            )

        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
            raise NotImplementedError(
                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
                " `low_cpu_mem_usage=False`."
            )

        user_agent = {
            "diffusers": diffusers.__version__,
            "file_type": "model",
            "framework": "pytorch",
        }

        # 1. Determine model class from config
        load_config_kwargs = {
            "cache_dir": cache_dir,
            "force_download": force_download,
            "proxies": proxies,
            "token": token,
            "local_files_only": local_files_only,
            "revision": revision,
        }
        config = AutoModel.load_config(pretrained_model_name_or_path, subfolder=subfolder, **load_config_kwargs)
        orig_class_name = config["_class_name"]

        if orig_class_name in LOCAL_CLASS_MAPPING:
            model_cls = LOCAL_CLASS_MAPPING[orig_class_name]
        else:
            from diffusers.pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates

            model_cls, _ = get_class_obj_and_candidates(
                library_name="diffusers",
                class_name=orig_class_name,
                importable_classes=ALL_IMPORTABLE_CLASSES,
                pipelines=None,
                is_pipeline_module=False,
            )

        if model_cls is None:
            raise ValueError(f"Can't find a model linked to {orig_class_name}.")

        # 2. Get model file
        model_file_kwargs = {
            "cache_dir": cache_dir,
            "force_download": force_download,
            "proxies": proxies,
            "local_files_only": local_files_only,
            "token": token,
            "revision": revision,
            "subfolder": subfolder,
            "user_agent": user_agent,
        }
        model_file = None
        if use_safetensors:
            try:
                model_file = _get_model_file(
                    pretrained_model_name_or_path,
                    weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
                    **model_file_kwargs,
                )
            except IOError as e:
                logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
                if not allow_pickle:
                    raise
                logger.warning(
                    "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
                )

        if model_file is None:
            model_file = _get_model_file(
                pretrained_model_name_or_path,
                weights_name=_add_variant(WEIGHTS_NAME, variant),
                **model_file_kwargs,
            )

        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        if model_file is None:
            raise EnvironmentError(f"Could not find adapter weights for {pretrained_model_name_or_path}.")

        # 3. Initialize model
        base_module = getattr(self, target_module_name)
        torch_dtype = base_module.dtype
        device = base_module.device

        # `_set_default_torch_dtype` changes the process-wide default dtype. Everything up to
        # model construction runs inside try/finally so the previous default is restored even
        # when loading or validation fails.
        dtype_orig = model_cls._set_default_torch_dtype(torch_dtype)
        try:
            # Load the state dict early: non-LoRA entries overwrite base weights (and are
            # added to `keep_in_fp32_modules` below), LoRA entries are loaded as an adapter.
            key_prefix = f"{target_module_name}."
            overwrite_state_dict = {}
            lora_state_dict = {}
            adapter_state_dict = load_state_dict(model_file, disable_mmap=disable_mmap)
            for key, tensor in adapter_state_dict.items():
                tensor = tensor.to(dtype=torch_dtype, device=device)
                destination = lora_state_dict if "lora" in key else overwrite_state_dict
                destination[key.removeprefix(key_prefix)] = tensor

            # Determine initial quantization config: reuse the base module's quantization, if any.
            pre_quantized = (
                "quantization_config" in base_module.config
                and base_module.config["quantization_config"] is not None
            )
            if pre_quantized:
                config["quantization_config"] = base_module.config.quantization_config
                hf_quantizer = DiffusersAutoQuantizer.from_config(
                    config["quantization_config"], pre_quantized=True
                )
                hf_quantizer.validate_environment(torch_dtype=torch_dtype)
                torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
                user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value

                # Force-set to `True` for more mem efficiency
                if low_cpu_mem_usage is None:
                    low_cpu_mem_usage = True
                    logger.info("Set `low_cpu_mem_usage` to True as `hf_quantizer` is not None.")
                elif not low_cpu_mem_usage:
                    raise ValueError("`low_cpu_mem_usage` cannot be False or None when using quantization.")
            else:
                hf_quantizer = None

            # Check if `_keep_in_fp32_modules` is not None
            use_keep_in_fp32_modules = model_cls._keep_in_fp32_modules is not None and (
                hf_quantizer is None or getattr(hf_quantizer, "use_keep_in_fp32_modules", False)
            )
            if use_keep_in_fp32_modules:
                keep_in_fp32_modules = model_cls._keep_in_fp32_modules
                if isinstance(keep_in_fp32_modules, list):
                    # Copy: entries are appended below and must not leak into the
                    # class-level attribute shared by every later load.
                    keep_in_fp32_modules = list(keep_in_fp32_modules)
                else:
                    keep_in_fp32_modules = [keep_in_fp32_modules]

                if low_cpu_mem_usage is None:
                    low_cpu_mem_usage = True
                    logger.info("Set `low_cpu_mem_usage` to True as `_keep_in_fp32_modules` is not None.")
                elif not low_cpu_mem_usage:
                    raise ValueError("`low_cpu_mem_usage` cannot be False when `keep_in_fp32_modules` is True.")
            else:
                keep_in_fp32_modules = []

            # Append the modules owning overwritten weights to keep_in_fp32_modules
            for key in overwrite_state_dict:
                module_name = key.rsplit('.', 1)[0]
                if module_name and module_name not in keep_in_fp32_modules:
                    keep_in_fp32_modules.append(module_name)

            init_contexts = [no_init_weights()]
            if low_cpu_mem_usage:
                init_contexts.append(accelerate.init_empty_weights())

            with ContextManagers(init_contexts):
                piflow_module = model_cls.from_config(config).eval()
        finally:
            torch.set_default_dtype(dtype_orig)

        if hf_quantizer is not None:
            hf_quantizer.preprocess_model(
                model=piflow_module, device_map=None, keep_in_fp32_modules=keep_in_fp32_modules
            )

        # 4. Load model weights: base weights first, then adapter overrides on top
        base_state_dict = base_module.state_dict()
        base_state_dict.update(overwrite_state_dict)
        empty_state_dict = piflow_module.state_dict()
        for param_name, param in base_state_dict.items():
            if param_name not in empty_state_dict:
                continue
            if hf_quantizer is not None and (
                    hf_quantizer.check_if_quantized_param(
                        piflow_module, param, param_name, base_state_dict, param_device=device)):
                hf_quantizer.create_quantized_param(
                    piflow_module, param, param_name, device, base_state_dict, dtype=torch_dtype
                )
            else:
                assign_param(piflow_module, param_name, param)

        empty_device_cache()

        if hf_quantizer is not None:
            hf_quantizer.postprocess_model(piflow_module)
            piflow_module.hf_quantizer = hf_quantizer

        # 5. Load LoRA weights, if the adapter ships any
        if lora_state_dict:
            if adapter_name is None:
                adapter_name = f"{target_module_name}_piflow"
            piflow_module.load_lora_adapter(
                lora_state_dict, prefix=None, adapter_name=adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
        else:
            adapter_name = None

        if adapter_name is None:
            logger.warning(
                f"No LoRA weights were found in {pretrained_model_name_or_path}."
            )

        setattr(self, target_module_name, piflow_module)

        return adapter_name

    def policy_rollout(
        self,
        x_t_start: torch.Tensor,  # (B, C, *, H, W)
        sigma_t_start: torch.Tensor,
        sigma_t_end: torch.Tensor,
        total_substeps: int,
        policy,
        **kwargs):
        """Integrate ``x_t`` from ``sigma_t_start`` to ``sigma_t_end`` by following ``policy``.

        Both sigmas are mapped to raw time with ``self.scheduler.unwarp_t``. The raw-time
        interval is divided into ``round(delta_raw_t * total_substeps)`` equal substeps (at
        least one). Each substep queries ``policy.pi(x_t, sigma_t)`` and applies
        ``x_t + u * (sigma_next - sigma_t)``, where ``sigma_next`` is obtained from the next
        raw time (clamped at 0) via ``self.scheduler.warp_t``.

        Args:
            x_t_start: State at ``sigma_t_start``.
            sigma_t_start: Single-element tensor with the starting sigma.
            sigma_t_end: Single-element tensor with the final sigma.
            total_substeps: Number of substeps a raw-time interval of length 1 would use.
            policy: Object exposing ``pi(x_t, sigma_t)``.
            **kwargs: Forwarded to ``self.scheduler.warp_t`` and ``self.scheduler.unwarp_t``.

        Returns:
            The state after the final substep.

        Raises:
            ValueError: If ``sigma_t_start`` or ``sigma_t_end`` is not a single-element tensor.
        """
        if sigma_t_start.numel() != 1 or sigma_t_end.numel() != 1:
            raise ValueError("Only supports scalar sigma_t_start and sigma_t_end.")

        raw_t_start = self.scheduler.unwarp_t(sigma_t_start, **kwargs)
        raw_t_end = self.scheduler.unwarp_t(sigma_t_end, **kwargs)

        delta_raw_t = raw_t_start - raw_t_end
        num_substeps = (delta_raw_t * total_substeps).round().to(torch.long).clamp(min=1)
        substep_size = delta_raw_t / num_substeps

        raw_t = raw_t_start
        sigma_t = sigma_t_start
        x_t = x_t_start

        for _ in range(num_substeps.item()):
            u = policy.pi(x_t, sigma_t)

            raw_t = (raw_t - substep_size).clamp(min=0)
            sigma_t_minus = self.scheduler.warp_t(raw_t, **kwargs)
            x_t = x_t + u * (sigma_t_minus - sigma_t)
            sigma_t = sigma_t_minus

        return x_t
|