yizhangliu commited on
Commit
1ef3fca
β€’
1 Parent(s): 26b428d

update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -24
app.py CHANGED
@@ -4,7 +4,7 @@ warnings.filterwarnings('ignore')
4
 
5
  import subprocess, io, os, sys, time
6
 
7
- os.system("pip install gradio==3.50.2")
8
 
9
  import gradio as gr
10
  from loguru import logger
@@ -123,6 +123,8 @@ ram_model = None
123
  kosmos_model = None
124
  kosmos_processor = None
125
 
 
 
126
  def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
127
  args = SLConfig.fromfile(model_config_path)
128
  model = build_model(args)
@@ -593,6 +595,17 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
593
  run_task_time = 0
594
  time_cost_str = ''
595
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
 
 
 
 
 
 
 
 
 
 
 
596
 
597
  if (task_type == 'Kosmos-2'):
598
  global kosmos_model, kosmos_processor
@@ -605,20 +618,20 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
605
 
606
  kosmos_image, kosmos_text, kosmos_entities = kosmos_generate_predictions(image_pil, kosmos_input, kosmos_model, kosmos_processor)
607
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
608
- return None, None, time_cost_str, kosmos_image, gr.Textbox.update(visible=(time_cost_str !='')), kosmos_text, kosmos_entities
609
 
610
  if (task_type == 'relate anything'):
611
  output_images = relate_anything(input_image['image'], num_relation)
612
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
613
- return output_images, gr.Gallery.update(label='relate images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
614
 
615
  text_prompt = text_prompt.strip()
616
  if not ((task_type in ['inpainting', 'outpainting'] or task_type == 'remove') and mask_source_radio == mask_source_draw):
617
  if text_prompt == '':
618
- return [], gr.Gallery.update(label='Detection prompt is not found!πŸ˜‚πŸ˜‚πŸ˜‚πŸ˜‚'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
619
 
620
  if input_image is None:
621
- return [], gr.Gallery.update(label='Please upload a image!πŸ˜‚πŸ˜‚πŸ˜‚πŸ˜‚'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
622
 
623
  file_temp = int(time.time())
624
  logger.info(f'run_anything_task_002/{device}_[{file_temp}]_{task_type}/{inpaint_mode}/[{mask_source_radio}]/{remove_mode}/{remove_mask_extend}_[{text_prompt}]/[{inpaint_prompt}]___1_')
@@ -661,7 +674,7 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
661
  )
662
  if boxes_filt.size(0) == 0:
663
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_[{text_prompt}]_1___{groundingdino_device}/[No objects detected, please try others.]_')
664
- return [], gr.Gallery.update(label='No objects detected, please try others.πŸ˜‚πŸ˜‚πŸ˜‚πŸ˜‚'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
665
  boxes_filt_ori = copy.deepcopy(boxes_filt)
666
 
667
  pred_dict = {
@@ -726,7 +739,7 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
726
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_3_')
727
  if task_type == 'detection' or task_type == 'segment':
728
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
729
- return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
730
  elif task_type in ['inpainting', 'outpainting'] or task_type == 'remove':
731
  if inpaint_prompt.strip() == '' and mask_source_radio == mask_source_segment:
732
  task_type = 'remove'
@@ -804,11 +817,11 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
804
  output_images.append(image_inpainting)
805
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
806
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
807
- return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
808
  else:
809
  logger.info(f"task_type:{task_type} error!")
810
  logger.info(f'run_anything_task_[{file_temp}]_9_9_')
811
- return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
812
 
813
  def change_radio_display(task_type, mask_source_radio):
814
  text_prompt_visible = True
@@ -839,14 +852,14 @@ def change_radio_display(task_type, mask_source_radio):
839
  text_prompt_visible = False
840
  num_relation_visible = True
841
 
842
- return (gr.Textbox.update(visible=text_prompt_visible),
843
- gr.Textbox.update(visible=inpaint_prompt_visible),
844
- gr.Radio.update(visible=mask_source_radio_visible),
845
- gr.Slider.update(visible=num_relation_visible),
846
- gr.Gallery.update(visible=image_gallery_visible),
847
- gr.Radio.update(visible=kosmos_input_visible),
848
- gr.Image.update(visible=kosmos_output_visible),
849
- gr.HighlightedText.update(visible=kosmos_text_output_visible))
850
 
851
  def get_model_device(module):
852
  try:
@@ -883,9 +896,12 @@ def main_gradio(args):
883
  task_types.append("relate anything")
884
  if kosmos_enable:
885
  task_types.append("Kosmos-2")
886
-
887
- input_image = gr.Image(source='upload', elem_id="image_upload", tool='sketch', type='pil', label="Upload",
888
- height=512, brush_color='#00FFFF', mask_opacity=0.6)
 
 
 
889
  task_type = gr.Radio(task_types, value="detection",
890
  label='Task type', visible=True)
891
  mask_source_radio = gr.Radio([mask_source_draw, mask_source_segment],
@@ -897,7 +913,7 @@ def main_gradio(args):
897
 
898
  kosmos_input = gr.Radio(["Brief", "Detailed"], label="Kosmos Description Type", value="Brief", visible=False)
899
 
900
- run_button = gr.Button(label="Run", visible=True)
901
  with gr.Accordion("Advanced options", open=False) as advanced_options:
902
  box_threshold = gr.Slider(
903
  label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.001
@@ -917,7 +933,7 @@ def main_gradio(args):
917
 
918
  with gr.Column():
919
  image_gallery = gr.Gallery(label="result images", show_label=True, elem_id="gallery", height=512, visible=True
920
- ).style(preview=True, columns=[5], object_fit="scale-down", height="auto")
921
  time_cost = gr.Textbox(label="Time cost by step (ms):", visible=False, interactive=False)
922
 
923
  kosmos_output = gr.Image(type="pil", label="result images", visible=False)
@@ -926,9 +942,9 @@ def main_gradio(args):
926
  combine_adjacent=False,
927
  show_legend=True,
928
  visible=False,
929
- ).style(color_map=color_map)
930
  # record which text span (label) is selected
931
- selected = gr.Number(-1, show_label=False, placeholder="Selected", visible=False)
932
 
933
  # record the current `entities`
934
  entity_output = gr.Textbox(visible=False)
 
4
 
5
  import subprocess, io, os, sys, time
6
 
7
+ # os.system("pip install gradio==3.50.2")
8
 
9
  import gradio as gr
10
  from loguru import logger
 
123
  kosmos_model = None
124
  kosmos_processor = None
125
 
126
+ brush_color = "#00FFFF"
127
+
128
  def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
129
  args = SLConfig.fromfile(model_config_path)
130
  model = build_model(args)
 
595
  run_task_time = 0
596
  time_cost_str = ''
597
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
598
+
599
+ logger.info(f"input_image==={input_image}")
600
+ if 'background' in input_image.keys():
601
+ input_image['image'] = input_image['background']
602
+ if len(input_image['layers']) > 0:
603
+ # input_image['mask'] = input_image['layers'][0] #brush_color
604
+ img_arr = np.array(input_image['layers'][0].convert("L"))
605
+ logger.info(f"img_arr==={img_arr.shape}, {img_arr[760][640]}, {img_arr[0][0]}")
606
+ img_arr = np.where(img_arr > 0, 1, img_arr)
607
+ # img_arr = 1 - img_arr
608
+ input_image['mask'] = Image.fromarray(255*img_arr.astype('uint8'))
609
 
610
  if (task_type == 'Kosmos-2'):
611
  global kosmos_model, kosmos_processor
 
618
 
619
  kosmos_image, kosmos_text, kosmos_entities = kosmos_generate_predictions(image_pil, kosmos_input, kosmos_model, kosmos_processor)
620
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
621
+ return None, None, time_cost_str, kosmos_image, gr.update(visible=(time_cost_str !='')), kosmos_text, kosmos_entities
622
 
623
  if (task_type == 'relate anything'):
624
  output_images = relate_anything(input_image['image'], num_relation)
625
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
626
+ return output_images, gr.update(label='relate images'), time_cost_str, gr.update(visible=(time_cost_str !='')), None, None, None
627
 
628
  text_prompt = text_prompt.strip()
629
  if not ((task_type in ['inpainting', 'outpainting'] or task_type == 'remove') and mask_source_radio == mask_source_draw):
630
  if text_prompt == '':
631
+ return [], gr.update(label='Detection prompt is not found!πŸ˜‚πŸ˜‚πŸ˜‚πŸ˜‚'), time_cost_str, gr.update(visible=(time_cost_str !='')), None, None, None
632
 
633
  if input_image is None:
634
+ return [], gr.update(label='Please upload a image!πŸ˜‚πŸ˜‚πŸ˜‚πŸ˜‚'), time_cost_str, gr.update(visible=(time_cost_str !='')), None, None, None
635
 
636
  file_temp = int(time.time())
637
  logger.info(f'run_anything_task_002/{device}_[{file_temp}]_{task_type}/{inpaint_mode}/[{mask_source_radio}]/{remove_mode}/{remove_mask_extend}_[{text_prompt}]/[{inpaint_prompt}]___1_')
 
674
  )
675
  if boxes_filt.size(0) == 0:
676
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_[{text_prompt}]_1___{groundingdino_device}/[No objects detected, please try others.]_')
677
+ return [], gr.update(label='No objects detected, please try others.πŸ˜‚πŸ˜‚πŸ˜‚πŸ˜‚'), time_cost_str, gr.update(visible=(time_cost_str !='')), None, None, None
678
  boxes_filt_ori = copy.deepcopy(boxes_filt)
679
 
680
  pred_dict = {
 
739
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_3_')
740
  if task_type == 'detection' or task_type == 'segment':
741
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
742
+ return output_images, gr.update(label='result images'), time_cost_str, gr.update(visible=(time_cost_str !='')), None, None, None
743
  elif task_type in ['inpainting', 'outpainting'] or task_type == 'remove':
744
  if inpaint_prompt.strip() == '' and mask_source_radio == mask_source_segment:
745
  task_type = 'remove'
 
817
  output_images.append(image_inpainting)
818
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
819
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
820
+ return output_images, gr.update(label='result images'), time_cost_str, gr.update(visible=(time_cost_str !='')), None, None, None
821
  else:
822
  logger.info(f"task_type:{task_type} error!")
823
  logger.info(f'run_anything_task_[{file_temp}]_9_9_')
824
+ return output_images, gr.update(label='result images'), time_cost_str, gr.update(visible=(time_cost_str !='')), None, None, None
825
 
826
  def change_radio_display(task_type, mask_source_radio):
827
  text_prompt_visible = True
 
852
  text_prompt_visible = False
853
  num_relation_visible = True
854
 
855
+ return (gr.update(visible=text_prompt_visible),
856
+ gr.update(visible=inpaint_prompt_visible),
857
+ gr.update(visible=mask_source_radio_visible),
858
+ gr.update(visible=num_relation_visible),
859
+ gr.update(visible=image_gallery_visible),
860
+ gr.update(visible=kosmos_input_visible),
861
+ gr.update(visible=kosmos_output_visible),
862
+ gr.update(visible=kosmos_text_output_visible))
863
 
864
  def get_model_device(module):
865
  try:
 
896
  task_types.append("relate anything")
897
  if kosmos_enable:
898
  task_types.append("Kosmos-2")
899
+
900
+ # input_image = gr.Image(source='upload', elem_id="image_upload", tool='sketch', type='pil', label="Upload",
901
+ # height=512, brush_color='#00FFFF', mask_opacity=0.6)
902
+
903
+ input_image = gr.ImageMask(sources='upload', elem_id="image_upload", type='pil', label="Upload",
904
+ brush=gr.Brush(colors=[brush_color], color_mode="fixed"))
905
  task_type = gr.Radio(task_types, value="detection",
906
  label='Task type', visible=True)
907
  mask_source_radio = gr.Radio([mask_source_draw, mask_source_segment],
 
913
 
914
  kosmos_input = gr.Radio(["Brief", "Detailed"], label="Kosmos Description Type", value="Brief", visible=False)
915
 
916
+ run_button = gr.Button(value="Run", visible=True)
917
  with gr.Accordion("Advanced options", open=False) as advanced_options:
918
  box_threshold = gr.Slider(
919
  label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.001
 
933
 
934
  with gr.Column():
935
  image_gallery = gr.Gallery(label="result images", show_label=True, elem_id="gallery", height=512, visible=True
936
+ ) #.style(preview=True, columns=[5], object_fit="scale-down", height="auto")
937
  time_cost = gr.Textbox(label="Time cost by step (ms):", visible=False, interactive=False)
938
 
939
  kosmos_output = gr.Image(type="pil", label="result images", visible=False)
 
942
  combine_adjacent=False,
943
  show_legend=True,
944
  visible=False,
945
+ ) # .style(color_map=color_map)
946
  # record which text span (label) is selected
947
+ selected = gr.Number(-1, show_label=False, visible=False)
948
 
949
  # record the current `entities`
950
  entity_output = gr.Textbox(visible=False)