Alexander Bagus commited on
Commit
27d5a9a
·
1 Parent(s): 5b96bb2
Files changed (1) hide show
  1. app.py +12 -12
app.py CHANGED
@@ -117,26 +117,26 @@ def inference(
117
 
118
  # process image
119
  print("DEBUG: process image")
120
- if edit_dict is None:
121
- print("Error: edit_dict is empty.")
122
  return None
123
 
124
- width, height = edit_dict['background'].size
125
-
126
- print("DEBUG: control_image_torch")
127
  sample_size = [height, width]
128
-
129
- if mask_image is not None:
130
- mask_image = get_image_latent(mask_image, sample_size=sample_size)[:, :1, 0]
131
- else:
132
- mask_image = torch.ones([1, 1, sample_size[0], sample_size[1]]) * 255
133
 
134
- inpaint_image = edit_dict['background']
135
  if inpaint_image is not None:
136
  inpaint_image = get_image_latent(inpaint_image, sample_size=sample_size)[:, :, 0]
137
  else:
138
  inpaint_image = torch.zeros([1, 3, sample_size[0], sample_size[1]])
139
-
 
 
 
 
 
140
 
141
  # generation
142
  if randomize_seed: seed = random.randint(0, MAX_SEED)
 
117
 
118
  # process image
119
  print("DEBUG: process image")
120
+ if edit_dict is None or mask_image is None:
121
+ print("Error: edit_dict or mask_image is empty.")
122
  return None
123
 
124
+ # rescale to prevent OOM
125
+ inpaint_image = edit_dict['background']
126
+ inpaint_image, width, height = image_utils.rescale_image(inpaint_image, 1, 8)
127
  sample_size = [height, width]
 
 
 
 
 
128
 
129
+ print("DEBUG: control_image_torch")
130
  if inpaint_image is not None:
131
  inpaint_image = get_image_latent(inpaint_image, sample_size=sample_size)[:, :, 0]
132
  else:
133
  inpaint_image = torch.zeros([1, 3, sample_size[0], sample_size[1]])
134
+
135
+ if mask_image is not None:
136
+ mask_image, w, h = image_utils.rescale_image(mask_image, 1, 8)
137
+ mask_image = get_image_latent(mask_image, sample_size=sample_size)[:, :1, 0]
138
+ else:
139
+ mask_image = torch.ones([1, 1, sample_size[0], sample_size[1]]) * 255
140
 
141
  # generation
142
  if randomize_seed: seed = random.randint(0, MAX_SEED)