Ffftdtd5dtft committed on
Commit 485b791
1 Parent(s): 0427664

Create akn.py

Files changed (1)
  1. akn.py +823 -0
akn.py ADDED
@@ -0,0 +1,823 @@
+ import os
+ import pickle
+ import torch
+ import torchaudio
+ import numpy as np
+ from PIL import Image
+ from diffusers import (
+     StableDiffusionPipeline,
+     StableDiffusionImg2ImgPipeline,
+     FluxPipeline,
+     DiffusionPipeline,
+     DPMSolverMultistepScheduler,
+ )
+ from diffusers.utils import export_to_video
+ from transformers import (
+     pipeline as transformers_pipeline,
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     GPT2Tokenizer,
+     GPT2LMHeadModel,
+     AutoModel,
+ )
+ from audiocraft.models import musicgen
+ import gradio as gr
+ from huggingface_hub import HfFolder
+ import io
+ import tempfile
+ from tqdm import tqdm
+ from google.cloud import storage
+ import json
+ 
+ # Required environment variables: HF_TOKEN, GCS_CREDENTIALS (a service-account
+ # JSON string), and GCS_BUCKET_NAME.
+ hf_token = os.getenv("HF_TOKEN")
+ gcs_credentials = json.loads(os.getenv("GCS_CREDENTIALS"))
+ gcs_bucket_name = os.getenv("GCS_BUCKET_NAME")
+ 
+ HfFolder.save_token(hf_token)
+ 
+ storage_client = storage.Client.from_service_account_info(gcs_credentials)
+ bucket = storage_client.bucket(gcs_bucket_name)
+ 
+ 
+ def load_object_from_gcs(blob_name):
+     """Return the unpickled object stored at blob_name, or None if absent."""
+     blob = bucket.blob(blob_name)
+     if blob.exists():
+         return pickle.loads(blob.download_as_bytes())
+     return None
+ 
+ 
+ def save_object_to_gcs(blob_name, obj):
+     """Pickle obj and upload it to blob_name in the configured bucket."""
+     blob = bucket.blob(blob_name)
+     blob.upload_from_string(pickle.dumps(obj))
+ 
+ 
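+ # Usage sketch (illustrative only): any picklable value round-trips through
+ # the bucket, which is how the generate_* functions below memoize results:
+ #
+ #     save_object_to_gcs("demo/answer", {"answer": 42})
+ #     load_object_from_gcs("demo/answer")  # -> {"answer": 42}
+ 
+ 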
+ def get_model_or_download(model_id, blob_name, loader_func):
+     """Load a model from the GCS cache, or download it and cache the pickle."""
+     model = load_object_from_gcs(blob_name)
+     if model:
+         return model
+     try:
+         with tqdm(total=1, desc=f"Downloading {model_id}") as pbar:
+             model = loader_func(model_id, torch_dtype=torch.float16)
+             pbar.update(1)
+         save_object_to_gcs(blob_name, model)
+         return model
+     except Exception as e:
+         print(f"Failed to load or save model: {e}")
+         return None
+ 
+ 
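+ # Usage sketch (illustrative; mirrors the real calls further down):
+ #
+ #     pipe = get_model_or_download(
+ #         "stabilityai/stable-diffusion-2",
+ #         "diffusers/text_to_image_model",
+ #         StableDiffusionPipeline.from_pretrained,
+ #     )
+ 
+ 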
+ def generate_image(prompt):
+     blob_name = f"diffusers/generated_image:{prompt}"
+     image_bytes = load_object_from_gcs(blob_name)
+     if not image_bytes:
+         try:
+             with tqdm(total=1, desc="Generating image") as pbar:
+                 image = text_to_image_pipeline(prompt).images[0]
+                 pbar.update(1)
+             buffered = io.BytesIO()
+             image.save(buffered, format="JPEG")
+             image_bytes = buffered.getvalue()
+             save_object_to_gcs(blob_name, image_bytes)
+         except Exception as e:
+             print(f"Failed to generate image: {e}")
+             return None
+     # The Gradio output is gr.Image(type="pil"), so decode the cached bytes.
+     return Image.open(io.BytesIO(image_bytes))
+ 
+ 
+ def edit_image_with_prompt(image, prompt, strength=0.75):
+     # `image` arrives as a PIL image from gr.Image(type="pil").
+     blob_name = f"diffusers/edited_image:{prompt}:{strength}"
+     edited_image_bytes = load_object_from_gcs(blob_name)
+     if not edited_image_bytes:
+         try:
+             with tqdm(total=1, desc="Editing image") as pbar:
+                 edited_image = img2img_pipeline(
+                     prompt=prompt, image=image, strength=strength
+                 ).images[0]
+                 pbar.update(1)
+             buffered = io.BytesIO()
+             edited_image.save(buffered, format="JPEG")
+             edited_image_bytes = buffered.getvalue()
+             save_object_to_gcs(blob_name, edited_image_bytes)
+         except Exception as e:
+             print(f"Failed to edit image: {e}")
+             return None
+     return Image.open(io.BytesIO(edited_image_bytes))
+ 
+ 
+ def generate_song(prompt, duration=10):
+     blob_name = f"music/generated_song:{prompt}:{duration}"
+     song = load_object_from_gcs(blob_name)
+     if not song:
+         try:
+             with tqdm(total=1, desc="Generating song") as pbar:
+                 music_gen.set_generation_params(duration=duration)
+                 wav = music_gen.generate([prompt])
+                 pbar.update(1)
+             # gr.Audio(type="numpy") expects a (sample_rate, samples) tuple.
+             song = (music_gen.sample_rate, wav[0].cpu().numpy().squeeze())
+             save_object_to_gcs(blob_name, song)
+         except Exception as e:
+             print(f"Failed to generate song: {e}")
+             return None
+     return song
+ 
+ 
+ def generate_text(prompt):
+     blob_name = f"transformers/generated_text:{prompt}"
+     text = load_object_from_gcs(blob_name)
+     if not text:
+         try:
+             with tqdm(total=1, desc="Generating text") as pbar:
+                 text = text_gen_pipeline(prompt, max_new_tokens=256)[0][
+                     "generated_text"
+                 ].strip()
+                 pbar.update(1)
+             save_object_to_gcs(blob_name, text)
+         except Exception as e:
+             print(f"Failed to generate text: {e}")
+             return None
+     return text
+ 
+ 
+ def generate_flux_image(prompt):
+     blob_name = f"diffusers/generated_flux_image:{prompt}"
+     flux_image_bytes = load_object_from_gcs(blob_name)
+     if not flux_image_bytes:
+         try:
+             with tqdm(total=1, desc="Generating FLUX image") as pbar:
+                 flux_image = flux_pipeline(
+                     prompt,
+                     guidance_scale=0.0,
+                     num_inference_steps=4,
+                     max_sequence_length=256,
+                     generator=torch.Generator("cpu").manual_seed(0),
+                 ).images[0]
+                 pbar.update(1)
+             buffered = io.BytesIO()
+             flux_image.save(buffered, format="JPEG")
+             flux_image_bytes = buffered.getvalue()
+             save_object_to_gcs(blob_name, flux_image_bytes)
+         except Exception as e:
+             print(f"Failed to generate flux image: {e}")
+             return None
+     return Image.open(io.BytesIO(flux_image_bytes))
+ 
+ 
+ def generate_code(prompt):
+     blob_name = f"transformers/generated_code:{prompt}"
+     code = load_object_from_gcs(blob_name)
+     if not code:
+         try:
+             with tqdm(total=1, desc="Generating code") as pbar:
+                 inputs = starcoder_tokenizer.encode(prompt, return_tensors="pt")
+                 outputs = starcoder_model.generate(inputs, max_new_tokens=256)
+                 code = starcoder_tokenizer.decode(
+                     outputs[0], skip_special_tokens=True
+                 )
+                 pbar.update(1)
+             save_object_to_gcs(blob_name, code)
+         except Exception as e:
+             print(f"Failed to generate code: {e}")
+             return None
+     return code
+ 
+ 
+ def test_model_meta_llama():
+     blob_name = "transformers/meta_llama_test_response"
+     response = load_object_from_gcs(blob_name)
+     if not response:
+         try:
+             messages = [
+                 {
+                     "role": "system",
+                     "content": "You are a pirate chatbot who always responds in pirate speak!",
+                 },
+                 {"role": "user", "content": "Who are you?"},
+             ]
+             with tqdm(total=1, desc="Testing Meta-Llama") as pbar:
+                 # With chat-style input the pipeline returns the whole
+                 # conversation; the assistant reply is the last message.
+                 outputs = meta_llama_pipeline(messages, max_new_tokens=256)
+                 response = outputs[0]["generated_text"][-1]["content"].strip()
+                 pbar.update(1)
+             save_object_to_gcs(blob_name, response)
+         except Exception as e:
+             print(f"Failed to test Meta-Llama: {e}")
+             return None
+     return response
+ 
+ 
+ def generate_image_sdxl(prompt):
+     blob_name = f"diffusers/generated_image_sdxl:{prompt}"
+     image_bytes = load_object_from_gcs(blob_name)
+     if not image_bytes:
+         try:
+             with tqdm(total=1, desc="Generating SDXL image") as pbar:
+                 # The base model produces latents; the refiner finishes the
+                 # final 20% of the denoising schedule.
+                 image = base(
+                     prompt=prompt,
+                     num_inference_steps=40,
+                     denoising_end=0.8,
+                     output_type="latent",
+                 ).images
+                 image = refiner(
+                     prompt=prompt,
+                     num_inference_steps=40,
+                     denoising_start=0.8,
+                     image=image,
+                 ).images[0]
+                 pbar.update(1)
+             buffered = io.BytesIO()
+             image.save(buffered, format="JPEG")
+             image_bytes = buffered.getvalue()
+             save_object_to_gcs(blob_name, image_bytes)
+         except Exception as e:
+             print(f"Failed to generate SDXL image: {e}")
+             return None
+     return Image.open(io.BytesIO(image_bytes))
+ 
+ 
+ def generate_musicgen_melody(prompt):
+     blob_name = f"music/generated_musicgen_melody:{prompt}"
+     song = load_object_from_gcs(blob_name)
+     if not song:
+         try:
+             with tqdm(total=1, desc="Generating MusicGen melody") as pbar:
+                 # Condition three variants on the melody of a local reference
+                 # clip, then keep the first.
+                 melody, sr = torchaudio.load("./assets/bach.mp3")
+                 wav = music_gen_melody.generate_with_chroma(
+                     [prompt] * 3, melody[None].expand(3, -1, -1), sr
+                 )
+                 pbar.update(1)
+             song = (
+                 music_gen_melody.sample_rate,
+                 wav[0].cpu().numpy().squeeze(),
+             )
+             save_object_to_gcs(blob_name, song)
+         except Exception as e:
+             print(f"Failed to generate MusicGen melody: {e}")
+             return None
+     return song
+ 
+ 
+ def generate_musicgen_large(prompt):
+     blob_name = f"music/generated_musicgen_large:{prompt}"
+     song = load_object_from_gcs(blob_name)
+     if not song:
+         try:
+             with tqdm(total=1, desc="Generating MusicGen large") as pbar:
+                 wav = music_gen_large.generate([prompt])
+                 pbar.update(1)
+             song = (
+                 music_gen_large.sample_rate,
+                 wav[0].cpu().numpy().squeeze(),
+             )
+             save_object_to_gcs(blob_name, song)
+         except Exception as e:
+             print(f"Failed to generate MusicGen large: {e}")
+             return None
+     return song
+ 
+ 
+ def transcribe_audio(audio_sample):
+     # gr.Audio(type="numpy") delivers a (sample_rate, samples) tuple.
+     sample_rate, samples = audio_sample
+     blob_name = f"transformers/transcribed_audio:{hash(samples.tobytes())}"
+     text = load_object_from_gcs(blob_name)
+     if not text:
+         try:
+             with tqdm(total=1, desc="Transcribing audio") as pbar:
+                 # Gradio records 16-bit PCM; scale to the [-1, 1] float range
+                 # the ASR pipeline expects.
+                 raw = samples.astype(np.float32) / 32768.0
+                 text = whisper_pipeline(
+                     {"sampling_rate": sample_rate, "raw": raw}, batch_size=8
+                 )["text"]
+                 pbar.update(1)
+             save_object_to_gcs(blob_name, text)
+         except Exception as e:
+             print(f"Failed to transcribe audio: {e}")
+             return None
+     return text
+ 
+ 
+ def generate_mistral_instruct(prompt):
+     blob_name = f"transformers/generated_mistral_instruct:{prompt}"
+     response = load_object_from_gcs(blob_name)
+     if not response:
+         try:
+             conversation = [{"role": "user", "content": prompt}]
+             with tqdm(total=1, desc="Generating Mistral Instruct response") as pbar:
+                 inputs = mistral_instruct_tokenizer.apply_chat_template(
+                     conversation,
+                     tools=tools,
+                     add_generation_prompt=True,
+                     return_dict=True,
+                     return_tensors="pt",
+                 )
+                 outputs = mistral_instruct_model.generate(
+                     **inputs, max_new_tokens=1000
+                 )
+                 response = mistral_instruct_tokenizer.decode(
+                     outputs[0], skip_special_tokens=True
+                 )
+                 pbar.update(1)
+             save_object_to_gcs(blob_name, response)
+         except Exception as e:
+             print(f"Failed to generate Mistral Instruct response: {e}")
+             return None
+     return response
+ 
+ 
+ def generate_mistral_nemo(prompt):
+     blob_name = f"transformers/generated_mistral_nemo:{prompt}"
+     response = load_object_from_gcs(blob_name)
+     if not response:
+         try:
+             conversation = [{"role": "user", "content": prompt}]
+             with tqdm(total=1, desc="Generating Mistral Nemo response") as pbar:
+                 inputs = mistral_nemo_tokenizer.apply_chat_template(
+                     conversation,
+                     tools=tools,
+                     add_generation_prompt=True,
+                     return_dict=True,
+                     return_tensors="pt",
+                 )
+                 outputs = mistral_nemo_model.generate(**inputs, max_new_tokens=1000)
+                 response = mistral_nemo_tokenizer.decode(
+                     outputs[0], skip_special_tokens=True
+                 )
+                 pbar.update(1)
+             save_object_to_gcs(blob_name, response)
+         except Exception as e:
+             print(f"Failed to generate Mistral Nemo response: {e}")
+             return None
+     return response
+ 
+ 
+ def generate_gpt2_xl(prompt):
+     blob_name = f"transformers/generated_gpt2_xl:{prompt}"
+     response = load_object_from_gcs(blob_name)
+     if not response:
+         try:
+             with tqdm(total=1, desc="Generating GPT-2 XL response") as pbar:
+                 inputs = gpt2_xl_tokenizer(prompt, return_tensors="pt")
+                 # A bare GPT2Model only returns hidden states; generation
+                 # needs the LM head, so gpt2_xl_model is a GPT2LMHeadModel.
+                 outputs = gpt2_xl_model.generate(**inputs, max_new_tokens=256)
+                 response = gpt2_xl_tokenizer.decode(
+                     outputs[0], skip_special_tokens=True
+                 )
+                 pbar.update(1)
+             save_object_to_gcs(blob_name, response)
+         except Exception as e:
+             print(f"Failed to generate GPT-2 XL response: {e}")
+             return None
+     return response
+ 
+ 
+ def store_user_question(question):
+     """Append the question to user_questions.txt in the bucket."""
+     blob_name = "user_questions.txt"
+     blob = bucket.blob(blob_name)
+     if blob.exists():
+         blob.download_to_filename("user_questions.txt")
+     with open("user_questions.txt", "a") as f:
+         f.write(question + "\n")
+     blob.upload_from_filename("user_questions.txt")
+ 
+ 
+ def retrain_models():
+     # Placeholder: retraining is not implemented yet.
+     pass
+ 
+ 
+ def generate_text_to_video_ms_1_7b(prompt, num_frames=200):
+     blob_name = f"diffusers/text_to_video_ms_1_7b:{prompt}:{num_frames}"
+     video_bytes = load_object_from_gcs(blob_name)
+     if not video_bytes:
+         try:
+             with tqdm(total=1, desc="Generating video") as pbar:
+                 video_frames = text_to_video_ms_1_7b_pipeline(
+                     prompt, num_inference_steps=25, num_frames=num_frames
+                 ).frames
+                 pbar.update(1)
+             video_path = export_to_video(video_frames)
+             with open(video_path, "rb") as f:
+                 video_bytes = f.read()
+             save_object_to_gcs(blob_name, video_bytes)
+             os.remove(video_path)
+         except Exception as e:
+             print(f"Failed to generate video: {e}")
+             return None
+     # gr.Video() renders a file path, so rehydrate the cached bytes to disk.
+     out_path = os.path.join(
+         tempfile.gettempdir(), f"video_{abs(hash(blob_name))}.mp4"
+     )
+     with open(out_path, "wb") as f:
+         f.write(video_bytes)
+     return out_path
+ 
+ 
+ def generate_text_to_video_ms_1_7b_short(prompt):
+     blob_name = f"diffusers/text_to_video_ms_1_7b_short:{prompt}"
+     video_bytes = load_object_from_gcs(blob_name)
+     if not video_bytes:
+         try:
+             with tqdm(total=1, desc="Generating short video") as pbar:
+                 video_frames = text_to_video_ms_1_7b_short_pipeline(
+                     prompt, num_inference_steps=25
+                 ).frames
+                 pbar.update(1)
+             video_path = export_to_video(video_frames)
+             with open(video_path, "rb") as f:
+                 video_bytes = f.read()
+             save_object_to_gcs(blob_name, video_bytes)
+             os.remove(video_path)
+         except Exception as e:
+             print(f"Failed to generate short video: {e}")
+             return None
+     out_path = os.path.join(
+         tempfile.gettempdir(), f"video_{abs(hash(blob_name))}.mp4"
+     )
+     with open(out_path, "wb") as f:
+         f.write(video_bytes)
+     return out_path
+ 
+ 
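+ # Model and pipeline initialization: everything below loads eagerly at module
+ # import, before the Gradio app is assembled.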
+ text_to_image_pipeline = get_model_or_download(
+     "stabilityai/stable-diffusion-2",
+     "diffusers/text_to_image_model",
+     StableDiffusionPipeline.from_pretrained,
+ )
+ img2img_pipeline = get_model_or_download(
+     "CompVis/stable-diffusion-v1-4",
+     "diffusers/img2img_model",
+     StableDiffusionImg2ImgPipeline.from_pretrained,
+ )
+ flux_pipeline = get_model_or_download(
+     "black-forest-labs/FLUX.1-schnell",
+     "diffusers/flux_model",
+     FluxPipeline.from_pretrained,
+ )
+ text_gen_pipeline = transformers_pipeline(
+     "text-generation", model="google/gemma-2-9b", tokenizer="google/gemma-2-9b"
+ )
+ music_gen = (
+     load_object_from_gcs("music/music_gen")
+     or musicgen.MusicGen.get_pretrained("melody")
+ )
+ # transformers_pipeline takes the task as its first argument, so wrap it to
+ # fit get_model_or_download's loader_func(model_id, **kwargs) contract.
+ meta_llama_pipeline = get_model_or_download(
+     "meta-llama/Meta-Llama-3.1-8B-Instruct",
+     "transformers/meta_llama_model",
+     lambda model_id, **kwargs: transformers_pipeline(
+         "text-generation", model=model_id, **kwargs
+     ),
+ )
+ starcoder_model = AutoModelForCausalLM.from_pretrained("bigcode/starcoder")
+ starcoder_tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder")
+ 
+ base = DiffusionPipeline.from_pretrained(
+     "stabilityai/stable-diffusion-xl-base-1.0",
+     torch_dtype=torch.float16,
+     variant="fp16",
+     use_safetensors=True,
+ )
+ refiner = DiffusionPipeline.from_pretrained(
+     "stabilityai/stable-diffusion-xl-refiner-1.0",
+     text_encoder_2=base.text_encoder_2,
+     vae=base.vae,
+     torch_dtype=torch.float16,
+     use_safetensors=True,
+     variant="fp16",
+ )
+ music_gen_melody = musicgen.MusicGen.get_pretrained("melody")
+ music_gen_melody.set_generation_params(duration=8)
+ music_gen_large = musicgen.MusicGen.get_pretrained("large")
+ music_gen_large.set_generation_params(duration=8)
+ whisper_pipeline = transformers_pipeline(
+     "automatic-speech-recognition",
+     model="openai/whisper-small",
+     chunk_length_s=30,
+ )
+ mistral_instruct_model = AutoModelForCausalLM.from_pretrained(
+     "mistralai/Mistral-Large-Instruct-2407",
+     torch_dtype=torch.bfloat16,
+     device_map="auto",
+ )
+ mistral_instruct_tokenizer = AutoTokenizer.from_pretrained(
+     "mistralai/Mistral-Large-Instruct-2407"
+ )
+ mistral_nemo_model = AutoModelForCausalLM.from_pretrained(
+     "mistralai/Mistral-Nemo-Instruct-2407",
+     torch_dtype=torch.bfloat16,
+     device_map="auto",
+ )
+ mistral_nemo_tokenizer = AutoTokenizer.from_pretrained(
+     "mistralai/Mistral-Nemo-Instruct-2407"
+ )
+ gpt2_xl_tokenizer = GPT2Tokenizer.from_pretrained("gpt2-xl")
+ # GPT2LMHeadModel (not GPT2Model) so generate_gpt2_xl can call .generate().
+ gpt2_xl_model = GPT2LMHeadModel.from_pretrained("gpt2-xl")
+ 
+ llama_3_groq_70b_tool_use_pipeline = transformers_pipeline(
+     "text-generation", model="Groq/Llama-3-Groq-70B-Tool-Use"
+ )
+ phi_3_5_mini_instruct_model = AutoModelForCausalLM.from_pretrained(
+     "microsoft/Phi-3.5-mini-instruct", torch_dtype="auto", trust_remote_code=True
+ )
+ phi_3_5_mini_instruct_tokenizer = AutoTokenizer.from_pretrained(
+     "microsoft/Phi-3.5-mini-instruct"
+ )
+ phi_3_5_mini_instruct_pipeline = transformers_pipeline(
+     "text-generation",
+     model=phi_3_5_mini_instruct_model,
+     tokenizer=phi_3_5_mini_instruct_tokenizer,
+ )
+ meta_llama_3_1_8b_pipeline = transformers_pipeline(
+     "text-generation",
+     model="meta-llama/Meta-Llama-3.1-8B",
+     model_kwargs={"torch_dtype": torch.bfloat16},
+ )
+ meta_llama_3_1_70b_pipeline = transformers_pipeline(
+     "text-generation",
+     model="meta-llama/Meta-Llama-3.1-70B",
+     model_kwargs={"torch_dtype": torch.bfloat16},
+ )
+ # "your/medical_text_summarization_model" is a placeholder id; swap in a real
+ # summarization checkpoint before launching.
+ medical_text_summarization_pipeline = transformers_pipeline(
+     "summarization", model="your/medical_text_summarization_model"
+ )
+ bart_large_cnn_summarization_pipeline = transformers_pipeline(
+     "summarization", model="facebook/bart-large-cnn"
+ )
+ flux_1_dev_pipeline = FluxPipeline.from_pretrained(
+     "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
+ )
+ flux_1_dev_pipeline.enable_model_cpu_offload()
+ gemma_2_9b_pipeline = transformers_pipeline(
+     "text-generation", model="google/gemma-2-9b"
+ )
+ gemma_2_9b_it_pipeline = transformers_pipeline(
+     "text-generation",
+     model="google/gemma-2-9b-it",
+     model_kwargs={"torch_dtype": torch.bfloat16},
+ )
+ gemma_2_2b_pipeline = transformers_pipeline(
+     "text-generation", model="google/gemma-2-2b"
+ )
+ gemma_2_2b_it_pipeline = transformers_pipeline(
+     "text-generation",
+     model="google/gemma-2-2b-it",
+     model_kwargs={"torch_dtype": torch.bfloat16},
+ )
+ gemma_2_27b_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-27b")
+ gemma_2_27b_model = AutoModelForCausalLM.from_pretrained("google/gemma-2-27b")
+ gemma_2_27b_it_pipeline = transformers_pipeline(
+     "text-generation",
+     model="google/gemma-2-27b-it",
+     model_kwargs={"torch_dtype": torch.bfloat16},
+ )
+ text_to_video_ms_1_7b_pipeline = DiffusionPipeline.from_pretrained(
+     "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16"
+ )
+ text_to_video_ms_1_7b_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
+     text_to_video_ms_1_7b_pipeline.scheduler.config
+ )
+ text_to_video_ms_1_7b_pipeline.enable_model_cpu_offload()
+ text_to_video_ms_1_7b_pipeline.enable_vae_slicing()
+ text_to_video_ms_1_7b_short_pipeline = DiffusionPipeline.from_pretrained(
+     "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16"
+ )
+ text_to_video_ms_1_7b_short_pipeline.scheduler = (
+     DPMSolverMultistepScheduler.from_config(
+         text_to_video_ms_1_7b_short_pipeline.scheduler.config
+     )
+ )
+ text_to_video_ms_1_7b_short_pipeline.enable_model_cpu_offload()
+ 
+ # Tool schemas passed to apply_chat_template(tools=...); empty disables tools.
+ tools = []
+ 
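+ # Illustrative only (hypothetical tool): entries follow the JSON-schema
+ # function format accepted by transformers chat templates, e.g.
+ #
+ #     tools = [{
+ #         "type": "function",
+ #         "function": {
+ #             "name": "get_weather",
+ #             "description": "Look up the current weather for a city.",
+ #             "parameters": {
+ #                 "type": "object",
+ #                 "properties": {"city": {"type": "string"}},
+ #                 "required": ["city"],
+ #             },
+ #         },
+ #     }]
+ 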
+ gen_image_tab = gr.Interface(
+     fn=generate_image,
+     inputs=gr.Textbox(label="Prompt:"),
+     outputs=gr.Image(type="pil"),
+     title="Generate Image",
+ )
+ edit_image_tab = gr.Interface(
+     fn=edit_image_with_prompt,
+     inputs=[
+         gr.Image(type="pil", label="Image:"),
+         gr.Textbox(label="Prompt:"),
+         gr.Slider(0.1, 1.0, 0.75, step=0.05, label="Strength:"),
+     ],
+     outputs=gr.Image(type="pil"),
+     title="Edit Image",
+ )
+ generate_song_tab = gr.Interface(
+     fn=generate_song,
+     inputs=[
+         gr.Textbox(label="Prompt:"),
+         gr.Slider(5, 60, 10, step=1, label="Duration (s):"),
+     ],
+     outputs=gr.Audio(type="numpy"),
+     title="Generate Songs",
+ )
+ generate_text_tab = gr.Interface(
+     fn=generate_text,
+     inputs=gr.Textbox(label="Prompt:"),
+     outputs=gr.Textbox(label="Generated Text:"),
+     title="Generate Text",
+ )
+ generate_flux_image_tab = gr.Interface(
+     fn=generate_flux_image,
+     inputs=gr.Textbox(label="Prompt:"),
+     outputs=gr.Image(type="pil"),
+     title="Generate FLUX Images",
+ )
+ generate_code_tab = gr.Interface(
+     fn=generate_code,
+     inputs=gr.Textbox(label="Prompt:"),
+     outputs=gr.Textbox(label="Generated Code:"),
+     title="Generate Code",
+ )
+ model_meta_llama_test_tab = gr.Interface(
+     fn=test_model_meta_llama,
+     inputs=None,
+     outputs=gr.Textbox(label="Model Output:"),
+     title="Test Meta-Llama",
+ )
+ generate_image_sdxl_tab = gr.Interface(
+     fn=generate_image_sdxl,
+     inputs=gr.Textbox(label="Prompt:"),
+     outputs=gr.Image(type="pil"),
+     title="Generate SDXL Image",
+ )
+ generate_musicgen_melody_tab = gr.Interface(
+     fn=generate_musicgen_melody,
+     inputs=gr.Textbox(label="Prompt:"),
+     outputs=gr.Audio(type="numpy"),
+     title="Generate MusicGen Melody",
+ )
+ generate_musicgen_large_tab = gr.Interface(
+     fn=generate_musicgen_large,
+     inputs=gr.Textbox(label="Prompt:"),
+     outputs=gr.Audio(type="numpy"),
+     title="Generate MusicGen Large",
+ )
+ transcribe_audio_tab = gr.Interface(
+     fn=transcribe_audio,
+     inputs=gr.Audio(type="numpy", label="Audio Sample:"),
+     outputs=gr.Textbox(label="Transcribed Text:"),
+     title="Transcribe Audio",
+ )
+ generate_mistral_instruct_tab = gr.Interface(
+     fn=generate_mistral_instruct,
+     inputs=gr.Textbox(label="Prompt:"),
+     outputs=gr.Textbox(label="Mistral Instruct Response:"),
+     title="Generate Mistral Instruct Response",
+ )
+ generate_mistral_nemo_tab = gr.Interface(
+     fn=generate_mistral_nemo,
+     inputs=gr.Textbox(label="Prompt:"),
+     outputs=gr.Textbox(label="Mistral Nemo Response:"),
+     title="Generate Mistral Nemo Response",
+ )
+ generate_gpt2_xl_tab = gr.Interface(
+     fn=generate_gpt2_xl,
+     inputs=gr.Textbox(label="Prompt:"),
+     outputs=gr.Textbox(label="GPT-2 XL Response:"),
+     title="Generate GPT-2 XL Response",
+ )
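+ 
+ 
+ # Helper (not part of the original file): text-generation and summarization
+ # pipelines return a list of dicts, while gr.Textbox needs a plain string;
+ # this unwraps the first result. A hedged convenience, not a transformers API.
+ def _as_text(pipe):
+     def run(prompt):
+         out = pipe(prompt)
+         if isinstance(out, list) and out and isinstance(out[0], dict):
+             first = out[0]
+             return (
+                 first.get("generated_text")
+                 or first.get("summary_text")
+                 or str(first)
+             )
+         return str(out)
+     return run
+ 
+ 
+ # Assumption: answer_question_minicpm is referenced by the tab below but never
+ # defined anywhere in this file. The sketch wires it to openbmb/MiniCPM-V via
+ # that repo's remote-code chat() API; verify the model id and the chat()
+ # signature before relying on it.
+ minicpm_model = AutoModel.from_pretrained(
+     "openbmb/MiniCPM-V", trust_remote_code=True
+ )
+ minicpm_tokenizer = AutoTokenizer.from_pretrained(
+     "openbmb/MiniCPM-V", trust_remote_code=True
+ )
+ 
+ 
+ def answer_question_minicpm(image, question):
+     msgs = [{"role": "user", "content": question}]
+     answer, _, _ = minicpm_model.chat(
+         image=image,
+         msgs=msgs,
+         context=None,
+         tokenizer=minicpm_tokenizer,
+         sampling=True,
+     )
+     return answer
+ 
+ 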
+ answer_question_minicpm_tab = gr.Interface(
+     fn=answer_question_minicpm,
+     inputs=[
+         gr.Image(type="pil", label="Image:"),
+         gr.Textbox(label="Question:"),
+     ],
+     outputs=gr.Textbox(label="MiniCPM Answer:"),
+     title="Answer Question with MiniCPM",
+ )
+ llama_3_groq_70b_tool_use_tab = gr.Interface(
+     fn=_as_text(llama_3_groq_70b_tool_use_pipeline),
+     inputs=[gr.Textbox(label="Prompt:")],
+     outputs=gr.Textbox(label="Llama 3 Groq 70B Tool Use Response:"),
+     title="Llama 3 Groq 70B Tool Use",
+ )
+ phi_3_5_mini_instruct_tab = gr.Interface(
+     fn=_as_text(phi_3_5_mini_instruct_pipeline),
+     inputs=[gr.Textbox(label="Prompt:")],
+     outputs=gr.Textbox(label="Phi 3.5 Mini Instruct Response:"),
+     title="Phi 3.5 Mini Instruct",
+ )
+ meta_llama_3_1_8b_tab = gr.Interface(
+     fn=_as_text(meta_llama_3_1_8b_pipeline),
+     inputs=[gr.Textbox(label="Prompt:")],
+     outputs=gr.Textbox(label="Meta Llama 3.1 8B Response:"),
+     title="Meta Llama 3.1 8B",
+ )
+ meta_llama_3_1_70b_tab = gr.Interface(
+     fn=_as_text(meta_llama_3_1_70b_pipeline),
+     inputs=[gr.Textbox(label="Prompt:")],
+     outputs=gr.Textbox(label="Meta Llama 3.1 70B Response:"),
+     title="Meta Llama 3.1 70B",
+ )
+ medical_text_summarization_tab = gr.Interface(
+     fn=_as_text(medical_text_summarization_pipeline),
+     inputs=[gr.Textbox(label="Medical Document:")],
+     outputs=gr.Textbox(label="Medical Text Summarization:"),
+     title="Medical Text Summarization",
+ )
+ bart_large_cnn_summarization_tab = gr.Interface(
+     fn=_as_text(bart_large_cnn_summarization_pipeline),
+     inputs=[gr.Textbox(label="Article:")],
+     outputs=gr.Textbox(label="Bart Large CNN Summarization:"),
+     title="Bart Large CNN Summarization",
+ )
+ flux_1_dev_tab = gr.Interface(
+     fn=lambda prompt: flux_1_dev_pipeline(prompt).images[0],
+     inputs=[gr.Textbox(label="Prompt:")],
+     outputs=gr.Image(type="pil"),
+     title="FLUX 1 Dev",
+ )
+ gemma_2_9b_tab = gr.Interface(
+     fn=_as_text(gemma_2_9b_pipeline),
+     inputs=[gr.Textbox(label="Prompt:")],
+     outputs=gr.Textbox(label="Gemma 2 9B Response:"),
+     title="Gemma 2 9B",
+ )
+ gemma_2_9b_it_tab = gr.Interface(
+     fn=_as_text(gemma_2_9b_it_pipeline),
+     inputs=[gr.Textbox(label="Prompt:")],
+     outputs=gr.Textbox(label="Gemma 2 9B IT Response:"),
+     title="Gemma 2 9B IT",
+ )
+ gemma_2_2b_tab = gr.Interface(
+     fn=_as_text(gemma_2_2b_pipeline),
+     inputs=[gr.Textbox(label="Prompt:")],
+     outputs=gr.Textbox(label="Gemma 2 2B Response:"),
+     title="Gemma 2 2B",
+ )
+ gemma_2_2b_it_tab = gr.Interface(
+     fn=_as_text(gemma_2_2b_it_pipeline),
+     inputs=[gr.Textbox(label="Prompt:")],
+     outputs=gr.Textbox(label="Gemma 2 2B IT Response:"),
+     title="Gemma 2 2B IT",
+ )
+ 
+ 
+ def generate_gemma_2_27b(prompt):
+     input_ids = gemma_2_27b_tokenizer(prompt, return_tensors="pt")
+     outputs = gemma_2_27b_model.generate(**input_ids, max_new_tokens=32)
+     return gemma_2_27b_tokenizer.decode(outputs[0])
+ 
+ 
+ gemma_2_27b_tab = gr.Interface(
+     fn=generate_gemma_2_27b,
+     inputs=[gr.Textbox(label="Prompt:")],
+     outputs=gr.Textbox(label="Gemma 2 27B Response:"),
+     title="Gemma 2 27B",
+ )
+ gemma_2_27b_it_tab = gr.Interface(
+     fn=_as_text(gemma_2_27b_it_pipeline),
+     inputs=[gr.Textbox(label="Prompt:")],
+     outputs=gr.Textbox(label="Gemma 2 27B IT Response:"),
+     title="Gemma 2 27B IT",
+ )
+ text_to_video_ms_1_7b_tab = gr.Interface(
+     fn=generate_text_to_video_ms_1_7b,
+     inputs=[
+         gr.Textbox(label="Prompt:"),
+         gr.Slider(50, 200, 200, step=1, label="Number of Frames:"),
+     ],
+     outputs=gr.Video(),
+     title="Text to Video MS 1.7B",
+ )
+ text_to_video_ms_1_7b_short_tab = gr.Interface(
+     fn=generate_text_to_video_ms_1_7b_short,
+     inputs=[gr.Textbox(label="Prompt:")],
+     outputs=gr.Video(),
+     title="Text to Video MS 1.7B Short",
+ )
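+ 
+ # The interface list and the tab-name list below are parallel: entry k in the
+ # first renders under title k in the second; keep them in lockstep.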
+ app = gr.TabbedInterface(
+     [
+         gen_image_tab,
+         edit_image_tab,
+         generate_song_tab,
+         generate_text_tab,
+         generate_flux_image_tab,
+         generate_code_tab,
+         model_meta_llama_test_tab,
+         generate_image_sdxl_tab,
+         generate_musicgen_melody_tab,
+         generate_musicgen_large_tab,
+         transcribe_audio_tab,
+         generate_mistral_instruct_tab,
+         generate_mistral_nemo_tab,
+         generate_gpt2_xl_tab,
+         llama_3_groq_70b_tool_use_tab,
+         phi_3_5_mini_instruct_tab,
+         meta_llama_3_1_8b_tab,
+         meta_llama_3_1_70b_tab,
+         medical_text_summarization_tab,
+         bart_large_cnn_summarization_tab,
+         flux_1_dev_tab,
+         gemma_2_9b_tab,
+         gemma_2_9b_it_tab,
+         gemma_2_2b_tab,
+         gemma_2_2b_it_tab,
+         gemma_2_27b_tab,
+         gemma_2_27b_it_tab,
+         text_to_video_ms_1_7b_tab,
+         text_to_video_ms_1_7b_short_tab,
+     ],
+     [
+         "Generate Image",
+         "Edit Image",
+         "Generate Song",
+         "Generate Text",
+         "Generate FLUX Image",
+         "Generate Code",
+         "Test Meta-Llama",
+         "Generate SDXL Image",
+         "Generate MusicGen Melody",
+         "Generate MusicGen Large",
+         "Transcribe Audio",
+         "Generate Mistral Instruct Response",
+         "Generate Mistral Nemo Response",
+         "Generate GPT-2 XL Response",
+         "Llama 3 Groq 70B Tool Use",
+         "Phi 3.5 Mini Instruct",
+         "Meta Llama 3.1 8B",
+         "Meta Llama 3.1 70B",
+         "Medical Text Summarization",
+         "Bart Large CNN Summarization",
+         "FLUX 1 Dev",
+         "Gemma 2 9B",
+         "Gemma 2 9B IT",
+         "Gemma 2 2B",
+         "Gemma 2 2B IT",
+         "Gemma 2 27B",
+         "Gemma 2 27B IT",
+         "Text to Video MS 1.7B",
+         "Text to Video MS 1.7B Short",
+     ],
+ )
+ 
+ app.launch(share=True)