Alexander Bagus commited on
Commit
dd2a1dd
·
1 Parent(s): 6afc27b
Files changed (1) hide show
  1. app.py +19 -19
app.py CHANGED
@@ -35,25 +35,25 @@ has_merged = repo_utils.check_dir_exist(TRANSFORMER_MERGED)
35
  # load transformer
36
  config = OmegaConf.load(TRANSFORMER_CONFIG)
37
 
38
- if not has_merged:
39
- print('load transformer from base')
40
- transformer = ZImageControlTransformer2DModel.from_pretrained(
41
- MODEL_LOCAL,
42
- subfolder="transformer",
43
- transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']),
44
- ).to("cuda", torch.bfloat16)
45
- print('load state_dict')
46
- state_dict = load_file(TRANSFORMER_LOCAL)
47
- state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
48
- m, u = transformer.load_state_dict(state_dict, strict=False)
49
- print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
50
- transformer.save_pretrained(TRANSFORMER_MERGED)
51
- else:
52
- print('load transformer from merged to bypass calculation')
53
- transformer = ZImageControlTransformer2DModel.from_pretrained(
54
- TRANSFORMER_MERGED,
55
- transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']),
56
- ).to("cuda", torch.bfloat16)
57
 
58
  print("transformer ready.")
59
 
 
35
  # load transformer
36
  config = OmegaConf.load(TRANSFORMER_CONFIG)
37
 
38
+ # if not has_merged:
39
+ print('load transformer from base')
40
+ transformer = ZImageControlTransformer2DModel.from_pretrained(
41
+ MODEL_LOCAL,
42
+ subfolder="transformer",
43
+ transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']),
44
+ ).to("cuda", torch.bfloat16)
45
+ print('load state_dict')
46
+ state_dict = load_file(TRANSFORMER_LOCAL)
47
+ state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
48
+ m, u = transformer.load_state_dict(state_dict, strict=False)
49
+ print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
50
+ transformer.save_pretrained(TRANSFORMER_MERGED)
51
+ # else:
52
+ # print('load transformer from merged to bypass calculation')
53
+ # transformer = ZImageControlTransformer2DModel.from_pretrained(
54
+ # TRANSFORMER_MERGED,
55
+ # transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']),
56
+ # ).to("cuda", torch.bfloat16)
57
 
58
  print("transformer ready.")
59