Pierre Chapuis
commited on
improve API client
Browse files- pyproject.toml +1 -0
- src/app.py +8 -14
- src/fg.py +192 -60
- typings/pillow_heif/__init__.pyi +2 -0
pyproject.toml
CHANGED
@@ -51,3 +51,4 @@ select = [
|
|
51 |
[tool.pyright]
|
52 |
include = ["src"]
|
53 |
exclude = ["**/__pycache__"]
|
|
|
|
51 |
[tool.pyright]
|
52 |
include = ["src"]
|
53 |
exclude = ["**/__pycache__"]
|
54 |
+
strict = ["src/fg.py"]
|
src/app.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import dataclasses as dc
|
2 |
import io
|
|
|
3 |
from typing import Any
|
4 |
|
5 |
import gradio as gr
|
@@ -24,6 +25,7 @@ with env.prefixed("ERASER_"):
|
|
24 |
CA_BUNDLE: str | None = env.str("CA_BUNDLE", None)
|
25 |
|
26 |
|
|
|
27 |
def _ctx() -> EditorAPIContext:
|
28 |
assert API_USER is not None
|
29 |
assert API_PASSWORD is not None
|
@@ -51,13 +53,7 @@ class ProcessParams:
|
|
51 |
async def _process(ctx: EditorAPIContext, params: ProcessParams) -> Image.Image:
|
52 |
with io.BytesIO() as f:
|
53 |
params.image.save(f, format="JPEG")
|
54 |
-
|
55 |
-
response = await client.post(
|
56 |
-
f"{ctx.uri}/state/upload",
|
57 |
-
files={"file": f},
|
58 |
-
headers=ctx.auth_headers,
|
59 |
-
)
|
60 |
-
response.raise_for_status()
|
61 |
st_input = response.json()["state"]
|
62 |
|
63 |
if params.bbox:
|
@@ -74,13 +70,11 @@ async def _process(ctx: EditorAPIContext, params: ProcessParams) -> Image.Image:
|
|
74 |
st_mask = await ctx.call_skill(f"segment/{segment_input_st}", segment_params)
|
75 |
st_erased = await ctx.call_skill(f"erase/{st_input}/{st_mask}", {"mode": "free"})
|
76 |
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
)
|
83 |
-
response.raise_for_status()
|
84 |
f = io.BytesIO()
|
85 |
f.write(response.content)
|
86 |
f.seek(0)
|
|
|
1 |
import dataclasses as dc
|
2 |
import io
|
3 |
+
from functools import cache
|
4 |
from typing import Any
|
5 |
|
6 |
import gradio as gr
|
|
|
25 |
CA_BUNDLE: str | None = env.str("CA_BUNDLE", None)
|
26 |
|
27 |
|
28 |
+
@cache
|
29 |
def _ctx() -> EditorAPIContext:
|
30 |
assert API_USER is not None
|
31 |
assert API_PASSWORD is not None
|
|
|
53 |
async def _process(ctx: EditorAPIContext, params: ProcessParams) -> Image.Image:
|
54 |
with io.BytesIO() as f:
|
55 |
params.image.save(f, format="JPEG")
|
56 |
+
response = await ctx.request("POST", "state/upload", files={"file": f})
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
st_input = response.json()["state"]
|
58 |
|
59 |
if params.bbox:
|
|
|
70 |
st_mask = await ctx.call_skill(f"segment/{segment_input_st}", segment_params)
|
71 |
st_erased = await ctx.call_skill(f"erase/{st_input}/{st_mask}", {"mode": "free"})
|
72 |
|
73 |
+
response = await ctx.request(
|
74 |
+
"GET",
|
75 |
+
f"state/image/{st_erased}",
|
76 |
+
params={"format": "JPEG", "resolution": "DISPLAY"},
|
77 |
+
)
|
|
|
|
|
78 |
f = io.BytesIO()
|
79 |
f.write(response.content)
|
80 |
f.seek(0)
|
src/fg.py
CHANGED
@@ -1,18 +1,46 @@
|
|
1 |
import asyncio
|
2 |
import dataclasses as dc
|
3 |
import json
|
|
|
4 |
from collections import defaultdict
|
5 |
-
from collections.abc import Awaitable, Callable
|
6 |
-
from typing import Any, Literal
|
7 |
|
8 |
import httpx
|
9 |
import httpx_sse
|
|
|
|
|
|
|
10 |
|
11 |
Priority = Literal["low", "standard", "high"]
|
12 |
|
13 |
|
14 |
-
|
15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
|
18 |
@dc.dataclass(kw_only=True)
|
@@ -23,18 +51,33 @@ class EditorAPIContext:
|
|
23 |
priority: Priority = "standard"
|
24 |
token: str | None = None
|
25 |
verify: bool | str = True
|
26 |
-
|
|
|
|
|
27 |
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
async def __aenter__(self) -> httpx.AsyncClient:
|
31 |
if self._client:
|
|
|
|
|
32 |
return self._client
|
|
|
33 |
self._client = httpx.AsyncClient(verify=self.verify)
|
|
|
34 |
return self._client
|
35 |
|
36 |
async def __aexit__(self, *args: Any) -> None:
|
37 |
-
if self._client:
|
|
|
|
|
|
|
38 |
await self._client.__aexit__(*args)
|
39 |
self._client = None
|
40 |
|
@@ -49,62 +92,153 @@ class EditorAPIContext:
|
|
49 |
f"{self.uri}/auth/login",
|
50 |
json={"username": self.user, "password": self.password},
|
51 |
)
|
52 |
-
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
-
async def sse_loop(self) -> None:
|
56 |
async with self as client:
|
57 |
-
|
58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
sub_token = response.json()["token"]
|
60 |
url = f"{self.uri}/sub/{sub_token}"
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
async with (
|
62 |
httpx.AsyncClient(timeout=None, verify=self.verify) as c,
|
63 |
-
httpx_sse.aconnect_sse(c, "GET", url) as es,
|
64 |
):
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
|
|
79 |
|
80 |
-
async def
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
)
|
86 |
-
|
87 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
|
89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
self,
|
91 |
co: Callable[["EditorAPIContext", Tin], Awaitable[Tout]],
|
92 |
params: Tin,
|
93 |
) -> Tout:
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
r = tg.create_task(outer_co(params))
|
106 |
-
|
107 |
-
return r.result()
|
108 |
|
109 |
def run_one_sync[Tin, Tout](
|
110 |
self,
|
@@ -116,18 +250,16 @@ class EditorAPIContext:
|
|
116 |
except RuntimeError:
|
117 |
loop = asyncio.new_event_loop()
|
118 |
asyncio.set_event_loop(loop)
|
|
|
119 |
|
120 |
-
|
121 |
-
|
122 |
-
|
|
|
|
|
|
|
123 |
params = {"priority": self.priority} | (params or {})
|
124 |
-
|
125 |
-
response = await client.post(
|
126 |
-
f"{self.uri}/skills/{uri}",
|
127 |
-
json=params,
|
128 |
-
headers=self.auth_headers,
|
129 |
-
)
|
130 |
-
response.raise_for_status()
|
131 |
state_id = response.json()["state"]
|
132 |
-
await self.sse_await(state_id)
|
133 |
return state_id
|
|
|
1 |
import asyncio
|
2 |
import dataclasses as dc
|
3 |
import json
|
4 |
+
import logging
|
5 |
from collections import defaultdict
|
6 |
+
from collections.abc import Awaitable, Callable, Mapping
|
7 |
+
from typing import Any, Literal, cast
|
8 |
|
9 |
import httpx
|
10 |
import httpx_sse
|
11 |
+
from httpx._types import QueryParamTypes, RequestFiles
|
12 |
+
|
13 |
+
logger = logging.getLogger(__name__)
|
14 |
|
15 |
Priority = Literal["low", "standard", "high"]
|
16 |
|
17 |
|
18 |
+
class SSELoopStopped(RuntimeError):
|
19 |
+
pass
|
20 |
+
|
21 |
+
|
22 |
+
class Futures[T]:
|
23 |
+
@classmethod
|
24 |
+
def create_future(cls) -> asyncio.Future[T]:
|
25 |
+
return asyncio.get_running_loop().create_future()
|
26 |
+
|
27 |
+
def __init__(self, capacity: int = 256) -> None:
|
28 |
+
self.futures = defaultdict[str, asyncio.Future[T]](self.create_future)
|
29 |
+
self.capacity = capacity
|
30 |
+
|
31 |
+
def cull(self) -> None:
|
32 |
+
while len(self.futures) >= self.capacity:
|
33 |
+
del self.futures[next(iter(self.futures))]
|
34 |
+
|
35 |
+
def __getitem__(self, key: str) -> asyncio.Future[T]:
|
36 |
+
self.cull()
|
37 |
+
return self.futures[key]
|
38 |
+
|
39 |
+
def __delitem__(self, key: str) -> None:
|
40 |
+
try:
|
41 |
+
del self.futures[key]
|
42 |
+
except KeyError:
|
43 |
+
pass
|
44 |
|
45 |
|
46 |
@dc.dataclass(kw_only=True)
|
|
|
51 |
priority: Priority = "standard"
|
52 |
token: str | None = None
|
53 |
verify: bool | str = True
|
54 |
+
default_timeout: float = 60.0
|
55 |
+
logger: logging.Logger = logger
|
56 |
+
max_sse_failures: int = 5
|
57 |
|
58 |
+
_client: httpx.AsyncClient | None = None
|
59 |
+
_client_ctx_depth: int = 0
|
60 |
+
_sse_futures: Futures[dict[str, Any]] = dc.field(default_factory=Futures)
|
61 |
+
_sse_task: asyncio.Task[None] | None = None
|
62 |
+
_sse_failures: int = 0
|
63 |
+
_sse_last_event_id: str = ""
|
64 |
+
_sse_retry_ms: int = 0
|
65 |
|
66 |
async def __aenter__(self) -> httpx.AsyncClient:
|
67 |
if self._client:
|
68 |
+
assert self._client_ctx_depth > 0
|
69 |
+
self._client_ctx_depth += 1
|
70 |
return self._client
|
71 |
+
assert self._client_ctx_depth == 0
|
72 |
self._client = httpx.AsyncClient(verify=self.verify)
|
73 |
+
self._client_ctx_depth = 1
|
74 |
return self._client
|
75 |
|
76 |
async def __aexit__(self, *args: Any) -> None:
|
77 |
+
if (not self._client) or self._client_ctx_depth <= 0:
|
78 |
+
raise RuntimeError("unbalanced __aexit__")
|
79 |
+
self._client_ctx_depth -= 1
|
80 |
+
if self._client_ctx_depth == 0:
|
81 |
await self._client.__aexit__(*args)
|
82 |
self._client = None
|
83 |
|
|
|
92 |
f"{self.uri}/auth/login",
|
93 |
json={"username": self.user, "password": self.password},
|
94 |
)
|
95 |
+
response.raise_for_status()
|
96 |
+
self.logger.debug(f"logged in as {self.user}")
|
97 |
+
self.token = response.json()["token"]
|
98 |
+
|
99 |
+
async def request(
|
100 |
+
self,
|
101 |
+
method: Literal["GET", "POST"],
|
102 |
+
url: str,
|
103 |
+
files: RequestFiles | None = None,
|
104 |
+
params: QueryParamTypes | None = None,
|
105 |
+
json: dict[str, Any] | None = None,
|
106 |
+
headers: Mapping[str, str] | None = None,
|
107 |
+
raise_for_status: bool = True,
|
108 |
+
) -> httpx.Response:
|
109 |
+
async def _q() -> httpx.Response:
|
110 |
+
return await client.request(
|
111 |
+
method,
|
112 |
+
f"{self.uri}/{url}",
|
113 |
+
headers=dict(headers or {}) | self.auth_headers,
|
114 |
+
files=files,
|
115 |
+
params=params,
|
116 |
+
json=json,
|
117 |
+
)
|
118 |
|
|
|
119 |
async with self as client:
|
120 |
+
r = await _q()
|
121 |
+
if r.status_code == 401:
|
122 |
+
self.logger.debug("renewing token")
|
123 |
+
await self.login()
|
124 |
+
r = await _q()
|
125 |
+
|
126 |
+
if raise_for_status:
|
127 |
+
r.raise_for_status()
|
128 |
+
return r
|
129 |
+
|
130 |
+
@classmethod
|
131 |
+
def decode_json(cls, data: str) -> dict[str, Any] | None:
|
132 |
+
try:
|
133 |
+
r = json.loads(data)
|
134 |
+
except json.JSONDecodeError:
|
135 |
+
return None
|
136 |
+
if type(r) is not dict:
|
137 |
+
return None
|
138 |
+
return cast(dict[str, Any], r)
|
139 |
+
|
140 |
+
async def _sse_loop(self) -> None:
|
141 |
+
response = await self.request("POST", "sub-auth")
|
142 |
sub_token = response.json()["token"]
|
143 |
url = f"{self.uri}/sub/{sub_token}"
|
144 |
+
headers = {"Accept": "text/event-stream"}
|
145 |
+
if self._sse_last_event_id:
|
146 |
+
retry_ms = self._sse_retry_ms + 1000 * 2**self._sse_failures
|
147 |
+
self.logger.info(f"resuming SSE from event {self._sse_last_event_id} in {retry_ms} ms")
|
148 |
+
await asyncio.sleep(retry_ms / 1000)
|
149 |
+
headers["Last-Event-ID"] = self._sse_last_event_id
|
150 |
async with (
|
151 |
httpx.AsyncClient(timeout=None, verify=self.verify) as c,
|
152 |
+
httpx_sse.aconnect_sse(c, "GET", url, headers=headers) as es,
|
153 |
):
|
154 |
+
es.response.raise_for_status()
|
155 |
+
self._sse_futures["_sse_loop"].set_result({"status": "ok"})
|
156 |
+
try:
|
157 |
+
async for sse in es.aiter_sse():
|
158 |
+
self._sse_last_event_id = sse.id
|
159 |
+
self._sse_retry_ms = sse.retry or 0
|
160 |
+
jdata = self.decode_json(sse.data)
|
161 |
+
if (jdata is None) or ("state" not in jdata):
|
162 |
+
# Note: when the server restarts we typically get an
|
163 |
+
# empty string here, then the loop exits.
|
164 |
+
self.logger.warning(f"unexpected SSE data: {sse.data}")
|
165 |
+
continue
|
166 |
+
self._sse_futures[jdata["state"]].set_result(jdata)
|
167 |
+
except asyncio.CancelledError:
|
168 |
+
pass
|
169 |
|
170 |
+
async def sse_start(self) -> None:
|
171 |
+
assert self._sse_task is None
|
172 |
+
self._sse_last_event_id = ""
|
173 |
+
self._sse_retry_ms = 0
|
174 |
+
self._sse_task = asyncio.create_task(self._sse_loop())
|
175 |
+
await self.sse_await("_sse_loop")
|
176 |
+
self._sse_failures = 0
|
177 |
+
|
178 |
+
async def sse_recover(self) -> bool:
|
179 |
+
while True:
|
180 |
+
if self._sse_failures > self.max_sse_failures:
|
181 |
+
return False
|
182 |
+
self._sse_task = asyncio.create_task(self._sse_loop())
|
183 |
+
try:
|
184 |
+
await self.sse_await("_sse_loop")
|
185 |
+
return True
|
186 |
+
except SSELoopStopped:
|
187 |
+
pass
|
188 |
+
|
189 |
+
async def sse_stop(self) -> None:
|
190 |
+
assert self._sse_task
|
191 |
+
self._sse_task.cancel()
|
192 |
+
await self._sse_task
|
193 |
+
self._sse_task = None
|
194 |
+
|
195 |
+
async def sse_await(self, state_id: str, timeout: float | None = None) -> None:
|
196 |
+
assert self._sse_task
|
197 |
+
future = self._sse_futures[state_id]
|
198 |
+
|
199 |
+
while True:
|
200 |
+
done, _ = await asyncio.wait(
|
201 |
+
{future, self._sse_task},
|
202 |
+
timeout=timeout or self.default_timeout,
|
203 |
+
return_when=asyncio.FIRST_COMPLETED,
|
204 |
)
|
205 |
+
if not done:
|
206 |
+
raise TimeoutError(f"state {state_id} timed out after {timeout}")
|
207 |
+
if self._sse_task in done:
|
208 |
+
self._sse_failures += 1
|
209 |
+
if state_id != "_sse_loop" and (await self.sse_recover()):
|
210 |
+
self._sse_failures = 0
|
211 |
+
continue
|
212 |
+
exception = self._sse_task.exception()
|
213 |
+
raise SSELoopStopped(f"SSE loop stopped while waiting for state {state_id}") from exception
|
214 |
+
break
|
215 |
|
216 |
+
assert done == {future}
|
217 |
+
|
218 |
+
jdata = future.result()
|
219 |
+
del self._sse_futures[state_id]
|
220 |
+
assert jdata["status"] == "ok", f"state {state_id} is {jdata['status']}"
|
221 |
+
|
222 |
+
async def get_meta(self, state_id: str) -> dict[str, Any]:
|
223 |
+
response = await self.request("GET", f"state/meta/{state_id}")
|
224 |
+
return response.json()
|
225 |
+
|
226 |
+
async def _run_one[Tin, Tout](
|
227 |
self,
|
228 |
co: Callable[["EditorAPIContext", Tin], Awaitable[Tout]],
|
229 |
params: Tin,
|
230 |
) -> Tout:
|
231 |
+
# This wraps the coroutine in the SSE loop.
|
232 |
+
# This is mostly useful if you use synchronous Python,
|
233 |
+
# otherwise you can call the functions directly.
|
234 |
+
if not self.token:
|
235 |
+
await self.login()
|
236 |
+
await self.sse_start()
|
237 |
+
try:
|
238 |
+
r = await co(self, params)
|
239 |
+
return r
|
240 |
+
finally:
|
241 |
+
await self.sse_stop()
|
|
|
|
|
|
|
242 |
|
243 |
def run_one_sync[Tin, Tout](
|
244 |
self,
|
|
|
250 |
except RuntimeError:
|
251 |
loop = asyncio.new_event_loop()
|
252 |
asyncio.set_event_loop(loop)
|
253 |
+
return loop.run_until_complete(self._run_one(co, params))
|
254 |
|
255 |
+
async def call_skill(
|
256 |
+
self,
|
257 |
+
uri: str,
|
258 |
+
params: dict[str, Any] | None,
|
259 |
+
timeout: float | None = None,
|
260 |
+
) -> str:
|
261 |
params = {"priority": self.priority} | (params or {})
|
262 |
+
response = await self.request("POST", f"skills/{uri}", json=params)
|
|
|
|
|
|
|
|
|
|
|
|
|
263 |
state_id = response.json()["state"]
|
264 |
+
await self.sse_await(state_id, timeout=timeout)
|
265 |
return state_id
|
typings/pillow_heif/__init__.pyi
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
def register_heif_opener() -> None: ...
|
2 |
+
def register_avif_opener() -> None: ...
|