lehduong committed on
Commit 39732f4
1 Parent(s): d4ffd77

Update app.py

Files changed (1): app.py +30 -30
app.py CHANGED
@@ -8,7 +8,7 @@ import base64
 import io
 from PIL import Image
 from transformers import (
-    LlavaNextProcessor, LlavaNextForConditionalGeneration,
+    # LlavaNextProcessor, LlavaNextForConditionalGeneration,
     T5EncoderModel, T5Tokenizer
 )
 from transformers import (
@@ -53,34 +53,34 @@ TASK2SPECIAL_TOKENS = {
 NEGATIVE_PROMPT = "monochrome, greyscale, low-res, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation"


-class LlavaCaptionProcessor:
-    def __init__(self):
-        model_name = "llava-hf/llama3-llava-next-8b-hf"
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
-        self.processor = LlavaNextProcessor.from_pretrained(model_name)
-        self.model = LlavaNextForConditionalGeneration.from_pretrained(
-            model_name, torch_dtype=dtype, low_cpu_mem_usage=True
-        ).to(device)
-        self.SPECIAL_TOKENS = "assistant\n\n\n"
-
-    def generate_response(self, image: Image.Image, msg: str) -> str:
-        conversation = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": msg}]}]
-        with torch.no_grad():
-            prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)
-            inputs = self.processor(prompt, image, return_tensors="pt").to(self.model.device)
-            output = self.model.generate(**inputs, max_new_tokens=200)
-            response = self.processor.decode(output[0], skip_special_tokens=True)
-        return response.split(msg)[-1].strip()[len(self.SPECIAL_TOKENS):]
-
-    def process(self, images: List[Image.Image], msg: str = None) -> List[str]:
-        if msg is None:
-            msg = f"Describe the contents of the photo in 150 words or fewer."
-        try:
-            return [self.generate_response(img, msg) for img in images]
-        except Exception as e:
-            print(f"Error in process: {str(e)}")
-            raise
+# class LlavaCaptionProcessor:
+#     def __init__(self):
+#         model_name = "llava-hf/llama3-llava-next-8b-hf"
+#         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+#         dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+#         self.processor = LlavaNextProcessor.from_pretrained(model_name)
+#         self.model = LlavaNextForConditionalGeneration.from_pretrained(
+#             model_name, torch_dtype=dtype, low_cpu_mem_usage=True
+#         ).to(device)
+#         self.SPECIAL_TOKENS = "assistant\n\n\n"
+#
+#     def generate_response(self, image: Image.Image, msg: str) -> str:
+#         conversation = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": msg}]}]
+#         with torch.no_grad():
+#             prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)
+#             inputs = self.processor(prompt, image, return_tensors="pt").to(self.model.device)
+#             output = self.model.generate(**inputs, max_new_tokens=200)
+#             response = self.processor.decode(output[0], skip_special_tokens=True)
+#         return response.split(msg)[-1].strip()[len(self.SPECIAL_TOKENS):]
+#
+#     def process(self, images: List[Image.Image], msg: str = None) -> List[str]:
+#         if msg is None:
+#             msg = f"Describe the contents of the photo in 150 words or fewer."
+#         try:
+#             return [self.generate_response(img, msg) for img in images]
+#         except Exception as e:
+#             print(f"Error in process: {str(e)}")
+#             raise


 class MolmoCaptionProcessor:
@@ -756,7 +756,7 @@ def delete_all_images():

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='Start the Gradio demo with specified captioner.')
-    parser.add_argument('--captioner', type=str, choices=['molmo', 'llava', 'disable'], default='disable', help='Captioner to use: molmo, llava, disable.')
+    parser.add_argument('--captioner', type=str, choices=['molmo', 'llava', 'disable'], default='molmo', help='Captioner to use: molmo, llava, disable.')
     args = parser.parse_args()

     # Initialize models with the specified captioner
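For context, the net effect of this commit is that MolmoCaptionProcessor is the only captioner class left in app.py: the LlavaCaptionProcessor import and class body are commented out, and the CLI default moves from 'disable' to 'molmo'. A minimal sketch of the selection logic this implies; the helper name build_captioner and the branch bodies are illustrative assumptions, since the actual initialization code is not part of this diff:

    def build_captioner(name: str):
        # 'molmo' is the only captioner still defined in app.py after this commit
        if name == 'molmo':
            return MolmoCaptionProcessor()
        if name == 'llava':
            # LlavaCaptionProcessor is commented out above, so this choice
            # can no longer be honored and should fail loudly
            raise ValueError("The llava captioner is disabled in this build.")
        return None  # 'disable': run the demo without a captioner

    captioner = build_captioner(args.captioner)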
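On the command line, the flipped default means a bare launch now enables Molmo captioning (usage sketch; the exact launch command depends on the deployment):

    python app.py                        # now equivalent to: python app.py --captioner molmo
    python app.py --captioner disable    # previous default behavior

Note that 'llava' remains in the argparse choices even though the class is commented out, so selecting it would presumably fail at initialization.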