File size: 16,731 Bytes
5d99e98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
# Copyright (c) 2025 Hansheng Chen

import os
from typing import Union, Optional

import torch
import accelerate
import diffusers
from diffusers.models import AutoModel
from diffusers.models.modeling_utils import (
    load_state_dict,
    _LOW_CPU_MEM_USAGE_DEFAULT,
    no_init_weights,
    ContextManagers
)
from diffusers.utils import (
    SAFETENSORS_WEIGHTS_NAME,
    WEIGHTS_NAME,
    _add_variant,
    _get_model_file,
    is_accelerate_available,
    is_torch_version,
    logging,
)
from diffusers.loaders.peft import _SET_ADAPTER_SCALE_FN_MAPPING
from diffusers.quantizers import DiffusersAutoQuantizer
from diffusers.utils.torch_utils import empty_device_cache
from lakonlab.models.architecture.gmflow.gmflux2 import _GMFlux2Transformer2DModel


# Project-local model classes, keyed by the `_class_name` value an adapter config may carry.
# Names found here are resolved directly instead of through diffusers' importable classes.
LOCAL_CLASS_MAPPING = dict(
    GMFlux2Transformer2DModel=_GMFlux2Transformer2DModel,
)

# Register a pass-through LoRA scale expansion function for the local transformer
# class: the adapter weights are returned unchanged.
_SET_ADAPTER_SCALE_FN_MAPPING["_GMFlux2Transformer2DModel"] = (
    lambda model_cls, weights: weights
)

logger = logging.get_logger(__name__)


def assign_param(module, tensor_name: str, param: torch.nn.Parameter):
    """Store ``param`` directly in the ``_parameters`` dict of the sub-module named by ``tensor_name``.

    ``tensor_name`` is a dotted path such as ``"blocks.0.attn.weight"``: every component
    except the last names a child attribute to descend into, and the last component is
    the key written into that child's ``_parameters`` dict. Writing to ``_parameters``
    bypasses ``nn.Module.__setattr__``, so ``param`` is stored as-is (no type check).

    Args:
        module: Root module the dotted path is resolved from.
        tensor_name: Dotted path of the parameter, relative to ``module``.
        param: Tensor/parameter object to store.

    Raises:
        ValueError: If an intermediate path component is missing or is ``None``.
    """
    *path, leaf = tensor_name.split(".")
    for attr in path:
        # Use a default so a missing attribute reaches the ValueError below instead of
        # escaping as an AttributeError.
        child = getattr(module, attr, None)
        if child is None:
            raise ValueError(f"{module} has no attribute {attr}.")
        module = child
    module._parameters[leaf] = param


class PiFlowMixin:
    r"""
    Mixin that replaces a pipeline sub-module with a PiFlow model and rolls out PiFlow policies.

    The host class must expose the module to be replaced as an attribute (``"transformer"`` by default) and, for
    :meth:`policy_rollout`, a ``scheduler`` attribute providing ``warp_t`` and ``unwarp_t``.
    """

    def load_piflow_adapter(
        self,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        target_module_name: str = "transformer",
        adapter_name: Optional[str] = None,
        **kwargs
    ):
        r"""
        Load a PiFlow adapter from a pretrained model repository into the target module.

        Loading works in four steps:

        1. A new module is built from the adapter's config.
        2. Its weights are taken directly from the current target module's `state_dict()`; no copy is made here.
        3. Non-LoRA tensors stored in the adapter checkpoint override the corresponding base tensors. A leading
           `"<target_module_name>."` prefix is stripped from checkpoint keys.
        4. Checkpoint tensors whose key contains `"lora"` are loaded as a LoRA adapter, and the new module replaces
           `self.<target_module_name>`.

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                Can be either:

                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                      with [`~ModelMixin.save_pretrained`].

            target_module_name (`str`, *optional*, defaults to `"transformer"`):
                The module name in the model to load the PiFlow adapter into.
            adapter_name (`str`, *optional*):
                The name to assign to the loaded adapter. If not provided, it defaults to
                `"{target_module_name}_piflow"`.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only(`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            subfolder (`str`, *optional*, defaults to `""`):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
                Speed up model loading by only loading the pretrained weights and not initializing the weights. This
                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
                model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
                argument to `True` will raise an error.
            variant (`str`, *optional*):
                Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`.
            use_safetensors (`bool`, *optional*, defaults to `None`):
                If set to `None`, the `safetensors` weights are tried first, with a fallback to the non-safetensors
                weights file if that fails. If set to `True`, the model is forcibly loaded from `safetensors`
                weights. If set to `False`, `safetensors` weights are not loaded.
            disable_mmap ('bool', *optional*, defaults to 'False'):
                Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
                is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.

        Returns:
            `str` or `None`: The name assigned to the loaded adapter, or `None` if no LoRA weights were found (in
            which case `adapter_name` is ignored, but the target module is still replaced).

        Raises:
            ValueError: If the model class named in the adapter config cannot be resolved, or if quantization or
                `_keep_in_fp32_modules` require `low_cpu_mem_usage` but it is disabled.
            NotImplementedError: If `low_cpu_mem_usage=True` with torch < 1.9.0.
        """
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        token = kwargs.pop("token", None)
        local_files_only = kwargs.pop("local_files_only", False)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
        variant = kwargs.pop("variant", None)
        use_safetensors = kwargs.pop("use_safetensors", None)
        disable_mmap = kwargs.pop("disable_mmap", False)

        # With `use_safetensors=None`, try safetensors first but allow falling back to the pickle-based weights
        # file; an explicit `use_safetensors=True` turns a failed safetensors lookup into an error.
        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = True
            allow_pickle = True

        if low_cpu_mem_usage and not is_accelerate_available():
            low_cpu_mem_usage = False
            logger.warning(
                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
                " install accelerate\n```\n."
            )

        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
            raise NotImplementedError(
                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
                " `low_cpu_mem_usage=False`."
            )

        user_agent = {
            "diffusers": diffusers.__version__,
            "file_type": "model",
            "framework": "pytorch",
        }

        # 1. Determine model class from config

        load_config_kwargs = {
            "cache_dir": cache_dir,
            "force_download": force_download,
            "proxies": proxies,
            "token": token,
            "local_files_only": local_files_only,
            "revision": revision,
        }

        config = AutoModel.load_config(pretrained_model_name_or_path, subfolder=subfolder, **load_config_kwargs)

        orig_class_name = config["_class_name"]

        if orig_class_name in LOCAL_CLASS_MAPPING:
            # Project-local classes are resolved first ...
            model_cls = LOCAL_CLASS_MAPPING[orig_class_name]
        else:
            # ... otherwise the name is looked up among diffusers' importable classes.
            from diffusers.pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates

            model_cls, _ = get_class_obj_and_candidates(
                library_name="diffusers",
                class_name=orig_class_name,
                importable_classes=ALL_IMPORTABLE_CLASSES,
                pipelines=None,
                is_pipeline_module=False,
            )

        if model_cls is None:
            raise ValueError(f"Can't find a model linked to {orig_class_name}.")

        # 2. Get model file

        model_file = None

        if use_safetensors:
            try:
                model_file = _get_model_file(
                    pretrained_model_name_or_path,
                    weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    token=token,
                    revision=revision,
                    subfolder=subfolder,
                    user_agent=user_agent,
                )

            except IOError as e:
                logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
                if not allow_pickle:
                    raise
                # `allow_pickle` is not a caller-facing option here; `use_safetensors=True` is what disables
                # this fallback.
                logger.warning(
                    "Defaulting to unsafe serialization. Pass `use_safetensors=True` to raise an error instead."
                )

        if model_file is None:
            model_file = _get_model_file(
                pretrained_model_name_or_path,
                weights_name=_add_variant(WEIGHTS_NAME, variant),
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                local_files_only=local_files_only,
                token=token,
                revision=revision,
                subfolder=subfolder,
                user_agent=user_agent,
            )

        assert model_file is not None, \
            f"Could not find adapter weights for {pretrained_model_name_or_path}."

        # 3. Initialize model

        base_module = getattr(self, target_module_name)

        # The base module's dtype is used as the torch default while building the new module; `torch_dtype`
        # starts from it but may be updated by the quantizer below.
        base_dtype = base_module.dtype
        torch_dtype = base_dtype
        device = base_module.device

        # Load the adapter state dict early so its keys can extend `keep_in_fp32_modules`.
        # Keys containing "lora" go to the LoRA adapter; everything else overrides base weights.
        prefix = f"{target_module_name}."
        overwrite_state_dict = {}
        lora_state_dict = {}

        adapter_state_dict = load_state_dict(model_file, disable_mmap=disable_mmap)
        for key, tensor in adapter_state_dict.items():
            tensor = tensor.to(dtype=torch_dtype, device=device)
            target = lora_state_dict if "lora" in key else overwrite_state_dict
            target[key.removeprefix(prefix)] = tensor
        del adapter_state_dict  # drop references to the unconverted tensors

        # Determine the quantization config: reuse the base module's, if it has one.
        pre_quantized = ("quantization_config" in base_module.config
                         and base_module.config["quantization_config"] is not None)
        if pre_quantized:
            config["quantization_config"] = base_module.config.quantization_config
            hf_quantizer = DiffusersAutoQuantizer.from_config(
                config["quantization_config"], pre_quantized=True
            )

            hf_quantizer.validate_environment(torch_dtype=torch_dtype)
            torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)

            user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value

            # Force-set to `True` for more mem efficiency
            if low_cpu_mem_usage is None:
                low_cpu_mem_usage = True
                logger.info("Set `low_cpu_mem_usage` to True as `hf_quantizer` is not None.")
            elif not low_cpu_mem_usage:
                raise ValueError("`low_cpu_mem_usage` cannot be False or None when using quantization.")

        else:
            hf_quantizer = None

        # Check if `_keep_in_fp32_modules` is not None
        use_keep_in_fp32_modules = model_cls._keep_in_fp32_modules is not None and (
            hf_quantizer is None or getattr(hf_quantizer, "use_keep_in_fp32_modules", False)
        )

        if use_keep_in_fp32_modules:
            class_keep = model_cls._keep_in_fp32_modules
            # Copy the class-level list: the loop below appends to it, and mutating the class attribute
            # would leak module names into every later load/instantiation of `model_cls`.
            keep_in_fp32_modules = list(class_keep) if isinstance(class_keep, list) else [class_keep]

            if low_cpu_mem_usage is None:
                low_cpu_mem_usage = True
                logger.info("Set `low_cpu_mem_usage` to True as `_keep_in_fp32_modules` is not None.")
            elif not low_cpu_mem_usage:
                raise ValueError("`low_cpu_mem_usage` cannot be False when `keep_in_fp32_modules` is True.")
        else:
            keep_in_fp32_modules = []

        # Append the owning module of every overridden tensor to keep_in_fp32_modules. Keys without a dot belong
        # to the root module, whose module name is "" and is skipped.
        for key in overwrite_state_dict:
            module_name = key.rpartition(".")[0]
            if module_name and module_name not in keep_in_fp32_modules:
                keep_in_fp32_modules.append(module_name)

        init_contexts = [no_init_weights()]
        if low_cpu_mem_usage:
            init_contexts.append(accelerate.init_empty_weights())

        # `_set_default_torch_dtype` changes the process-wide torch default dtype; restore it even if
        # construction raises so the failure does not leak global state.
        dtype_orig = model_cls._set_default_torch_dtype(base_dtype)
        try:
            with ContextManagers(init_contexts):
                piflow_module = model_cls.from_config(config).eval()
        finally:
            torch.set_default_dtype(dtype_orig)

        if hf_quantizer is not None:
            hf_quantizer.preprocess_model(
                model=piflow_module, device_map=None, keep_in_fp32_modules=keep_in_fp32_modules
            )

        # 4. Load model weights: base weights, overridden by the adapter's non-LoRA tensors. Only entries that
        # exist in the new module are used.

        base_state_dict = base_module.state_dict()
        base_state_dict.update(overwrite_state_dict)
        empty_state_dict = piflow_module.state_dict()
        for param_name, param in base_state_dict.items():
            if param_name not in empty_state_dict:
                continue
            if hf_quantizer is not None and (
                    hf_quantizer.check_if_quantized_param(
                        piflow_module, param, param_name, base_state_dict, param_device=device)):
                hf_quantizer.create_quantized_param(
                    piflow_module, param, param_name, device, base_state_dict, dtype=torch_dtype
                )
            else:
                assign_param(piflow_module, param_name, param)

        empty_device_cache()

        if hf_quantizer is not None:
            hf_quantizer.postprocess_model(piflow_module)
            piflow_module.hf_quantizer = hf_quantizer

        if len(lora_state_dict) == 0:
            adapter_name = None
        else:
            if adapter_name is None:
                adapter_name = f"{target_module_name}_piflow"
            piflow_module.load_lora_adapter(
                lora_state_dict, prefix=None, adapter_name=adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
        if adapter_name is None:
            logger.warning(
                f"No LoRA weights were found in {pretrained_model_name_or_path}."
            )

        setattr(self, target_module_name, piflow_module)

        return adapter_name

    def policy_rollout(
            self,
            x_t_start: torch.Tensor,  # (B, C, *, H, W)
            sigma_t_start: torch.Tensor,
            sigma_t_end: torch.Tensor,
            total_substeps: int,
            policy,
            **kwargs):
        """
        Integrate ``x`` from ``sigma_t_start`` to ``sigma_t_end`` with Euler steps along ``policy.pi``.

        Both sigmas are mapped to raw time with ``self.scheduler.unwarp_t``. The interval is split into
        ``round((raw_t_start - raw_t_end) * total_substeps)`` equal raw-time substeps, with a minimum of one. So
        ``total_substeps`` is the substep count that a raw-time interval of length 1 would receive.

        Each substep does three things:

        1. Evaluates ``u = policy.pi(x_t, sigma_t)``.
        2. Maps the next raw time, clamped at 0, back to sigma with ``self.scheduler.warp_t``.
        3. Updates ``x_t <- x_t + u * (sigma_next - sigma_t)``.

        Args:
            x_t_start: Starting sample.
            sigma_t_start: Single-element tensor, the starting sigma.
            sigma_t_end: Single-element tensor, the target sigma.
            total_substeps: Substeps per unit of raw time.
            policy: Object with a ``pi(x_t, sigma_t)`` method whose result is multiplied by a sigma difference
                and added to ``x_t`` (so it must broadcast against ``x_t``).
            **kwargs: Forwarded to ``scheduler.unwarp_t`` and ``scheduler.warp_t``.

        Returns:
            torch.Tensor: The sample after the last substep.
        """
        assert sigma_t_start.numel() == 1 and sigma_t_end.numel() == 1, \
            "Only supports scalar sigma_t_start and sigma_t_end."
        raw_t_start = self.scheduler.unwarp_t(sigma_t_start, **kwargs)
        raw_t_end = self.scheduler.unwarp_t(sigma_t_end, **kwargs)

        delta_raw_t = raw_t_start - raw_t_end
        num_substeps = (delta_raw_t * total_substeps).round().to(torch.long).clamp(min=1)
        substep_size = delta_raw_t / num_substeps

        raw_t, sigma_t, x_t = raw_t_start, sigma_t_start, x_t_start
        for _ in range(num_substeps.item()):
            u = policy.pi(x_t, sigma_t)

            # Clamp so accumulated rounding cannot push raw time below 0.
            raw_t_minus = (raw_t - substep_size).clamp(min=0)
            sigma_t_minus = self.scheduler.warp_t(raw_t_minus, **kwargs)
            x_t = x_t + u * (sigma_t_minus - sigma_t)

            raw_t, sigma_t = raw_t_minus, sigma_t_minus

        return x_t