Commit 7fef6fd
1 Parent(s): 0702aa5
modified clip.py
Files changed:
- extract_tools.py (+4 -4)
- tool_utils/clip_segmentation.py (+29 -9)
extract_tools.py CHANGED

@@ -187,7 +187,7 @@ def get_all_tools():
     clipseg_tool = Tool(
         name = 'ClipSegmentation-tool',
         func = clipsegmentation_mask,
-        description="""Use this tool when user ask to
+        description="""Use this tool when user ask to extract the objects from the image .
         The input to the tool is the path of the image and list of objects for which Segmenation mask is to generated.
         For example :
         Query :Provide a segmentation mask of all road car and dog in the image

@@ -212,10 +212,10 @@ def get_all_tools():
     )

     object_extractor = Tool(
-        name = "Object
+        name = "Object description Tool",
         func = object_extraction,
-        description = " The Tool is used to
-        what are the objects I can view in the image or identify the objects within the image
+        description = " The Tool is used to describe the objects within the image . Use this tool if user specifically ask to identify \
+        what are the objects I can view in the image or identify the objects within the image. "
     )

     image_parameters_tool = Tool(
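For context, the sketch below reassembles the two tool definitions as they read after this commit, assuming the standard langchain.agents.Tool interface (name/func/description) that the calls above suggest. The callables clipsegmentation_mask and object_extraction are stubbed here because their implementations live elsewhere in the repo, and the truncated pre-change lines are not reconstructed; this is an approximation, not the file itself.

# Sketch only: reassembled from the diff hunks above, with stubbed callables.
from langchain.agents import Tool

def clipsegmentation_mask(query: str) -> str:
    """Stub standing in for the repo's CLIPSeg-based mask generator."""
    return "Segmentation image created : final_mask.png"

def object_extraction(query: str) -> str:
    """Stub standing in for the repo's object-description function."""
    return "objects: ..."

clipseg_tool = Tool(
    name="ClipSegmentation-tool",
    func=clipsegmentation_mask,
    description=(
        "Use this tool when the user asks to extract the objects from the image. "
        "The input is the path of the image and the list of objects for which a "
        "segmentation mask is to be generated. Example query: Provide a segmentation "
        "mask of all road, car and dog in the image."
    ),
)

object_extractor = Tool(
    name="Object description Tool",
    func=object_extraction,
    description=(
        "The tool is used to describe the objects within the image. Use this tool if "
        "the user specifically asks what objects are visible in the image."
    ),
)

Since a LangChain agent routes requests by reading these descriptions, the wording change in this commit ("extract the objects" vs. "describe the objects") is what steers segmentation requests to the CLIPSeg tool instead of the object-description tool.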
tool_utils/clip_segmentation.py CHANGED

@@ -14,11 +14,22 @@ class CLIPSEG:
         self.threshould = threshould
         self.clip_model.to('cpu')

+    @ staticmethod
+    def create_single_mask(predicted_masks , color = None ):
+
+        if len(predicted_masks)>0:
+            mask_image = np.zeros_like(predicted_masks[0])
+        else:
+            mask_image = np.zeros(shape=(352,352),dtype=np.unit8)
+        for masks in predicted_masks:
+            mask_image = np.bitwise_or(mask_image,masks)
+        return mask_image
+
     @staticmethod
     def create_rgb_mask(mask,color=None):
-        color = tuple(np.random.choice(range(
+        color = tuple(np.random.choice(range(128,255), size=3))
         gray_3_channel = cv2.merge((mask, mask, mask))
-        gray_3_channel[mask==255] = color
+        gray_3_channel[mask==255] = 255 # for orignial color
         return gray_3_channel.astype(np.uint8)

     def get_segmentation_mask(self,image_path:str,object_prompts:List):

@@ -41,16 +52,25 @@ class CLIPSEG:
             predicted_mask = torch.sigmoid(preds[i][0]).detach().cpu().numpy()
             predicted_mask = np.where(predicted_mask>self.threshould, 255,0)
             predicted_masks.append(predicted_mask)
-
+
+        final_mask = self.create_single_mask(predicted_masks)
+        rgb_predicted_mask = self.create_rgb_mask(final_mask)
+
         resize_image = cv2.resize(image,(352,352))
-        mask_labels = [f"{prompt}_{i}" for i,prompt in enumerate(object_prompts)]
-        cmap = plt.cm.tab20(np.arange(len(mask_labels)))[..., :-1]
+        rgb_mask_img = cv2.bitwise_and(resize_image,rgb_predicted_mask )
+
+        # mask_labels = [f"{prompt}_{i}" for i,prompt in enumerate(object_prompts)]
+        # cmap = plt.cm.tab20(np.arange(len(mask_labels)))[..., :-1]

-        bool_masks = [predicted_mask.astype('bool') for predicted_mask in predicted_masks]
-        final_mask = overlay_masks(resize_image,np.stack(bool_masks,-1),labels=mask_labels,colors=cmap,alpha=0.5,beta=0.7)
+        # bool_masks = [predicted_mask.astype('bool') for predicted_mask in predicted_masks]
+        # final_mask = overlay_masks(resize_image,np.stack(bool_masks,-1),labels=mask_labels,colors=cmap,alpha=0.5,beta=0.7)
         try:
-            cv2.imwrite('final_mask.png',
+            cv2.imwrite('final_mask.png',rgb_mask_img)
             return 'Segmentation image created : final_mask.png'
         except Exception as e:
             logging.error("Error while saving the final mask :",e)
-            return "unable to create a mask image "
+            return "unable to create a mask image "
+
+if __name__=="__main__":
+    clip = CLIPSEG()
+    obj = clip.get_segmentation_mask(image_path="../image_store/demo.jpg",object_prompts=['sand','dog'])
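The commit swaps the previous overlay_masks visualization for a simple union-and-cut-out: the per-prompt binary masks are OR-ed into one mask (create_single_mask), lifted to three channels (create_rgb_mask), and applied to the resized image with cv2.bitwise_and so only the requested objects keep their original colors. A standalone sketch of that idea follows. Note two likely issues in the committed hunk: np.unit8 appears to be a typo for np.uint8, and np.where(..., 255, 0) yields int64 masks, so an explicit uint8 cast is needed before the OpenCV calls. The helper names mirror the diff, but the code below is an approximation, not the repository file.

# Sketch of the masking approach introduced in this commit (not the repo file itself):
# OR the per-prompt binary masks together, lift the result to three channels,
# and keep only the masked pixels of the resized image.
import cv2
import numpy as np

def create_single_mask(predicted_masks, size=(352, 352)):
    """Union of binary {0, 255} masks; returns uint8 (the diff's np.unit8 is a typo)."""
    if len(predicted_masks) > 0:
        mask_image = np.zeros_like(predicted_masks[0], dtype=np.uint8)
    else:
        mask_image = np.zeros(shape=size, dtype=np.uint8)
    for mask in predicted_masks:
        # np.where(..., 255, 0) defaults to int64, so cast before OR-ing
        mask_image = np.bitwise_or(mask_image, mask.astype(np.uint8))
    return mask_image

def create_rgb_mask(mask):
    """Replicate a single-channel mask into HxWx3 so it can gate a BGR image."""
    return cv2.merge((mask, mask, mask)).astype(np.uint8)

def cut_out_objects(image_bgr, predicted_masks):
    """Resize to the CLIPSeg output resolution and black out everything outside the union mask."""
    final_mask = create_single_mask(predicted_masks)
    rgb_mask = create_rgb_mask(final_mask)
    resized = cv2.resize(image_bgr, (352, 352))
    return cv2.bitwise_and(resized, rgb_mask)

if __name__ == "__main__":
    # Tiny self-contained demo with a synthetic image and two fake masks
    image = np.full((480, 640, 3), 200, dtype=np.uint8)
    m1 = np.zeros((352, 352), dtype=np.uint8); m1[50:150, 50:150] = 255
    m2 = np.zeros((352, 352), dtype=np.uint8); m2[200:300, 100:250] = 255
    cv2.imwrite("final_mask.png", cut_out_objects(image, [m1, m2]))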
|