sayakpaul HF Staff commited on
Commit
01b885e
·
verified ·
1 Parent(s): 9ae5cf2

Update nano_banana.py

Browse files
Files changed (1) hide show
  1. nano_banana.py +18 -14
nano_banana.py CHANGED
@@ -7,18 +7,16 @@ from diffusers.modular_pipelines import (
7
  OutputParam,
8
  )
9
  from PIL import Image
10
- import google.generativeai as genai
11
- import os
12
 
 
13
 
14
  class NanoBanana(ModularPipelineBlocks):
15
  def __init__(self, model_id="gemini-2.5-flash-image-preview"):
16
  super().__init__()
17
- api_key = os.getenv("GEMINI_API_KEY")
18
- if api_key is None:
19
- raise ValueError("Must provide an API key for Gemini through the `GEMINI_API_KEY` env variable.")
20
- genai.configure(api_key=api_key)
21
- self.model = genai.GenerativeModel(model_name=model_id)
22
 
23
  @property
24
  def expected_components(self):
@@ -50,28 +48,34 @@ class NanoBanana(ModularPipelineBlocks):
50
  return [
51
  OutputParam(
52
  "output_image",
53
- type_hint=str,
54
  description="Output image",
55
  ),
56
  OutputParam(
57
  "old_image",
58
- type_hint=str,
59
  description="Old image (if) provided by the user",
60
  )
61
  ]
62
 
63
-
64
  def __call__(self, components, state: PipelineState) -> PipelineState:
65
  block_state = self.get_block_state(state)
66
 
67
  old_image = block_state.image
68
- prompt = block_state.state.prompt
69
  contents = [prompt]
70
  if old_image is not None:
71
- contents.expand(old_image)
72
 
73
- output = self.model.generate_content(contents=contents)
74
- block_state.output_image = output
 
 
 
 
 
 
 
75
 
76
  if old_image is not None:
77
  block_state.old_image = old_image
 
7
  OutputParam,
8
  )
9
  from PIL import Image
10
+ from google import genai
11
+ from io import BytesIO
12
 
13
+ client = genai.Client()
14
 
15
  class NanoBanana(ModularPipelineBlocks):
16
  def __init__(self, model_id="gemini-2.5-flash-image-preview"):
17
  super().__init__()
18
+ # Cannot initialize the client because it throws a pickling error.
19
+ self.model_id = model_id
 
 
 
20
 
21
  @property
22
  def expected_components(self):
 
48
  return [
49
  OutputParam(
50
  "output_image",
51
+ type_hint=Image.Image,
52
  description="Output image",
53
  ),
54
  OutputParam(
55
  "old_image",
56
+ type_hint=Image.Image,
57
  description="Old image (if) provided by the user",
58
  )
59
  ]
60
 
 
61
  def __call__(self, components, state: PipelineState) -> PipelineState:
62
  block_state = self.get_block_state(state)
63
 
64
  old_image = block_state.image
65
+ prompt = block_state.prompt
66
  contents = [prompt]
67
  if old_image is not None:
68
+ contents.append(old_image)
69
 
70
+ response = client.models.generate_content(
71
+ model=self.model_id, contents=contents
72
+
73
+ )
74
+ for part in response.candidates[0].content.parts:
75
+ if part.text is not None:
76
+ continue
77
+ elif part.inline_data is not None:
78
+ block_state.output_image = Image.open(BytesIO(part.inline_data.data))
79
 
80
  if old_image is not None:
81
  block_state.old_image = old_image