Spaces:
Runtime error
Runtime error
File size: 1,673 Bytes
c3a1897 eb902b3 c3a1897 eb902b3 c3a1897 eb902b3 c3a1897 eb902b3 c3a1897 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
from models.segment_models.semgent_anything_model import SegmentAnything
from models.segment_models.semantic_segment_anything_model import SemanticSegment
class RegionSemantic:
    """Produce a text "semantic prompt" describing the largest regions of an image.

    Combines a SegmentAnything mask generator with a SemanticSegment classifier:
    masks are generated, labeled with class names, and the largest regions are
    serialized as ``"class_name: bbox; "`` entries.
    """

    def __init__(self, device):
        """Store the target device and build both underlying models.

        Args:
            device: Forwarded unchanged to the SegmentAnything and SemanticSegment
                constructors (presumably a torch device or device string such as
                "cuda" -- TODO confirm against those wrappers).
        """
        self.device = device
        self.init_models()

    def init_models(self):
        """Instantiate the mask generator and the semantic classifier on ``self.device``."""
        self.segment_model = SegmentAnything(self.device)
        self.semantic_segment_model = SemanticSegment(self.device)

    def semantic_prompt_gen(self, anns, top_k=10):
        """Build a prompt string from the ``top_k`` largest annotated regions.

        Regions are ranked by their ``'area'`` value, largest first; regions with
        equal area keep their input order because ``sorted`` is stable. No
        filtering by minimum size or stability score is performed here.

        Args:
            anns: Iterable of dicts, each providing at least ``'class_name'``,
                ``'bbox'`` and ``'area'``; other keys are ignored. Example:
                ``[{'class_name': 'person', 'bbox': [0.0, 0.0, 0.0, 0.0], 'area': 0}, ...]``
            top_k: Maximum number of regions to include (default 10, the value
                that used to be hard-coded). Must be non-negative.

        Returns:
            A string such as ``"person: [0.0, 0.0, 0.0, 0.0]; ..."`` in which every
            entry, including the last, ends with ``"; "``. Empty if ``anns`` is empty.

        Raises:
            ValueError: If ``top_k`` is negative.
            KeyError: If an annotation lacks ``'area'``, ``'class_name'`` or ``'bbox'``.
        """
        if top_k < 0:
            # A negative slice bound would silently keep "all but the smallest"
            # regions instead of limiting the count, so reject it explicitly.
            raise ValueError(f"top_k must be non-negative, got {top_k}")

        largest_regions = sorted(anns, key=lambda ann: ann['area'], reverse=True)[:top_k]

        print('\033[1;35m' + '*' * 100 + '\033[0m')
        print("\nStep3, Semantic Prompt:")
        # join() avoids the quadratic cost of repeated string concatenation.
        semantic_prompt = "".join(
            f"{region['class_name']}: {region['bbox']!s}; " for region in largest_regions
        )
        print(semantic_prompt)
        print('\033[1;35m' + '*' * 100 + '\033[0m')
        return semantic_prompt

    def region_semantic(self, img_src):
        """Segment ``img_src``, label each mask with a class, and return the semantic prompt.

        Args:
            img_src: Image input passed straight to both models (its format is not
                visible here -- presumably a path or image array; confirm against
                the model wrappers).

        Returns:
            The string produced by ``semantic_prompt_gen`` with its default ``top_k``.
        """
        anns = self.segment_model.generate_mask(img_src)
        anns_w_class = self.semantic_segment_model.semantic_class_w_mask(img_src, anns)
        return self.semantic_prompt_gen(anns_w_class)

    def region_semantic_debug(self, img_src):
        """Debug stub: skip both models and return a fixed placeholder string."""
        return "region_semantic_debug"