Pierre Chapuis commited on
Commit
7c9213f
·
unverified ·
1 Parent(s): eea5fcc

improve API client

Browse files
Files changed (4) hide show
  1. pyproject.toml +1 -0
  2. src/app.py +8 -14
  3. src/fg.py +192 -60
  4. typings/pillow_heif/__init__.pyi +2 -0
pyproject.toml CHANGED
@@ -51,3 +51,4 @@ select = [
51
  [tool.pyright]
52
  include = ["src"]
53
  exclude = ["**/__pycache__"]
 
 
51
  [tool.pyright]
52
  include = ["src"]
53
  exclude = ["**/__pycache__"]
54
+ strict = ["src/fg.py"]
src/app.py CHANGED
@@ -1,5 +1,6 @@
1
  import dataclasses as dc
2
  import io
 
3
  from typing import Any
4
 
5
  import gradio as gr
@@ -24,6 +25,7 @@ with env.prefixed("ERASER_"):
24
  CA_BUNDLE: str | None = env.str("CA_BUNDLE", None)
25
 
26
 
 
27
  def _ctx() -> EditorAPIContext:
28
  assert API_USER is not None
29
  assert API_PASSWORD is not None
@@ -51,13 +53,7 @@ class ProcessParams:
51
  async def _process(ctx: EditorAPIContext, params: ProcessParams) -> Image.Image:
52
  with io.BytesIO() as f:
53
  params.image.save(f, format="JPEG")
54
- async with ctx as client:
55
- response = await client.post(
56
- f"{ctx.uri}/state/upload",
57
- files={"file": f},
58
- headers=ctx.auth_headers,
59
- )
60
- response.raise_for_status()
61
  st_input = response.json()["state"]
62
 
63
  if params.bbox:
@@ -74,13 +70,11 @@ async def _process(ctx: EditorAPIContext, params: ProcessParams) -> Image.Image:
74
  st_mask = await ctx.call_skill(f"segment/{segment_input_st}", segment_params)
75
  st_erased = await ctx.call_skill(f"erase/{st_input}/{st_mask}", {"mode": "free"})
76
 
77
- async with ctx as client:
78
- response = await client.get(
79
- f"{ctx.uri}/state/image/{st_erased}",
80
- params={"format": "JPEG", "resolution": "DISPLAY"},
81
- headers=ctx.auth_headers,
82
- )
83
- response.raise_for_status()
84
  f = io.BytesIO()
85
  f.write(response.content)
86
  f.seek(0)
 
1
  import dataclasses as dc
2
  import io
3
+ from functools import cache
4
  from typing import Any
5
 
6
  import gradio as gr
 
25
  CA_BUNDLE: str | None = env.str("CA_BUNDLE", None)
26
 
27
 
28
+ @cache
29
  def _ctx() -> EditorAPIContext:
30
  assert API_USER is not None
31
  assert API_PASSWORD is not None
 
53
  async def _process(ctx: EditorAPIContext, params: ProcessParams) -> Image.Image:
54
  with io.BytesIO() as f:
55
  params.image.save(f, format="JPEG")
56
+ response = await ctx.request("POST", "state/upload", files={"file": f})
 
 
 
 
 
 
57
  st_input = response.json()["state"]
58
 
59
  if params.bbox:
 
70
  st_mask = await ctx.call_skill(f"segment/{segment_input_st}", segment_params)
71
  st_erased = await ctx.call_skill(f"erase/{st_input}/{st_mask}", {"mode": "free"})
72
 
73
+ response = await ctx.request(
74
+ "GET",
75
+ f"state/image/{st_erased}",
76
+ params={"format": "JPEG", "resolution": "DISPLAY"},
77
+ )
 
 
78
  f = io.BytesIO()
79
  f.write(response.content)
80
  f.seek(0)
src/fg.py CHANGED
@@ -1,18 +1,46 @@
1
  import asyncio
2
  import dataclasses as dc
3
  import json
 
4
  from collections import defaultdict
5
- from collections.abc import Awaitable, Callable
6
- from typing import Any, Literal
7
 
8
  import httpx
9
  import httpx_sse
 
 
 
10
 
11
  Priority = Literal["low", "standard", "high"]
12
 
13
 
14
- def _new_future() -> asyncio.Future[Any]:
15
- return asyncio.get_running_loop().create_future()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
 
18
  @dc.dataclass(kw_only=True)
@@ -23,18 +51,33 @@ class EditorAPIContext:
23
  priority: Priority = "standard"
24
  token: str | None = None
25
  verify: bool | str = True
26
- _client: httpx.AsyncClient | None = None
 
 
27
 
28
- sse_futures: dict[str, asyncio.Future[dict[str, Any]]] = dc.field(default_factory=lambda: defaultdict(_new_future))
 
 
 
 
 
 
29
 
30
  async def __aenter__(self) -> httpx.AsyncClient:
31
  if self._client:
 
 
32
  return self._client
 
33
  self._client = httpx.AsyncClient(verify=self.verify)
 
34
  return self._client
35
 
36
  async def __aexit__(self, *args: Any) -> None:
37
- if self._client:
 
 
 
38
  await self._client.__aexit__(*args)
39
  self._client = None
40
 
@@ -49,62 +92,153 @@ class EditorAPIContext:
49
  f"{self.uri}/auth/login",
50
  json={"username": self.user, "password": self.password},
51
  )
52
- response.raise_for_status()
53
- self.token = response.json()["token"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
- async def sse_loop(self) -> None:
56
  async with self as client:
57
- response = await client.post(f"{self.uri}/sub-auth", headers=self.auth_headers)
58
- response.raise_for_status()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  sub_token = response.json()["token"]
60
  url = f"{self.uri}/sub/{sub_token}"
 
 
 
 
 
 
61
  async with (
62
  httpx.AsyncClient(timeout=None, verify=self.verify) as c,
63
- httpx_sse.aconnect_sse(c, "GET", url) as es,
64
  ):
65
- future = self.sse_futures["_sse_loop"]
66
- future.set_result({"status": "ok"})
67
- async for sse in es.aiter_sse():
68
- jdata = json.loads(sse.data)
69
- future = self.sse_futures[jdata["state"]]
70
- future.set_result(jdata)
71
-
72
- async def sse_await(self, state_id: str, timeout: float = 60.0) -> None:
73
- future = self.sse_futures[state_id]
74
- jdata = await asyncio.wait_for(future, timeout=timeout)
75
- if jdata["status"] != "ok":
76
- print("ERROR", jdata)
77
- assert jdata["status"] == "ok"
78
- del self.sse_futures[state_id]
 
79
 
80
- async def get_meta(self, state_id: str) -> dict[str, Any]:
81
- async with self as client:
82
- response = await client.get(
83
- f"{self.uri}/state/meta/{state_id}",
84
- headers=self.auth_headers,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  )
86
- response.raise_for_status()
87
- return response.json()
 
 
 
 
 
 
 
 
88
 
89
- async def run_one[Tin, Tout](
 
 
 
 
 
 
 
 
 
 
90
  self,
91
  co: Callable[["EditorAPIContext", Tin], Awaitable[Tout]],
92
  params: Tin,
93
  ) -> Tout:
94
- await self.login()
95
- async with asyncio.TaskGroup() as tg:
96
- sse_task = tg.create_task(self.sse_loop())
97
-
98
- async def outer_co(params: Tin) -> Tout:
99
- # _sse_loop is a fake event to wait until the SSE loop is properly setup.
100
- await self.sse_await("_sse_loop")
101
- r = await co(self, params)
102
- sse_task.cancel()
103
- return r
104
-
105
- r = tg.create_task(outer_co(params))
106
-
107
- return r.result()
108
 
109
  def run_one_sync[Tin, Tout](
110
  self,
@@ -116,18 +250,16 @@ class EditorAPIContext:
116
  except RuntimeError:
117
  loop = asyncio.new_event_loop()
118
  asyncio.set_event_loop(loop)
 
119
 
120
- return loop.run_until_complete(self.run_one(co, params))
121
-
122
- async def call_skill(self, uri: str, params: dict[str, Any] | None) -> str:
 
 
 
123
  params = {"priority": self.priority} | (params or {})
124
- async with self as client:
125
- response = await client.post(
126
- f"{self.uri}/skills/{uri}",
127
- json=params,
128
- headers=self.auth_headers,
129
- )
130
- response.raise_for_status()
131
  state_id = response.json()["state"]
132
- await self.sse_await(state_id)
133
  return state_id
 
1
  import asyncio
2
  import dataclasses as dc
3
  import json
4
+ import logging
5
  from collections import defaultdict
6
+ from collections.abc import Awaitable, Callable, Mapping
7
+ from typing import Any, Literal, cast
8
 
9
  import httpx
10
  import httpx_sse
11
+ from httpx._types import QueryParamTypes, RequestFiles
12
+
13
+ logger = logging.getLogger(__name__)
14
 
15
  Priority = Literal["low", "standard", "high"]
16
 
17
 
18
+ class SSELoopStopped(RuntimeError):
19
+ pass
20
+
21
+
22
+ class Futures[T]:
23
+ @classmethod
24
+ def create_future(cls) -> asyncio.Future[T]:
25
+ return asyncio.get_running_loop().create_future()
26
+
27
+ def __init__(self, capacity: int = 256) -> None:
28
+ self.futures = defaultdict[str, asyncio.Future[T]](self.create_future)
29
+ self.capacity = capacity
30
+
31
+ def cull(self) -> None:
32
+ while len(self.futures) >= self.capacity:
33
+ del self.futures[next(iter(self.futures))]
34
+
35
+ def __getitem__(self, key: str) -> asyncio.Future[T]:
36
+ self.cull()
37
+ return self.futures[key]
38
+
39
+ def __delitem__(self, key: str) -> None:
40
+ try:
41
+ del self.futures[key]
42
+ except KeyError:
43
+ pass
44
 
45
 
46
  @dc.dataclass(kw_only=True)
 
51
  priority: Priority = "standard"
52
  token: str | None = None
53
  verify: bool | str = True
54
+ default_timeout: float = 60.0
55
+ logger: logging.Logger = logger
56
+ max_sse_failures: int = 5
57
 
58
+ _client: httpx.AsyncClient | None = None
59
+ _client_ctx_depth: int = 0
60
+ _sse_futures: Futures[dict[str, Any]] = dc.field(default_factory=Futures)
61
+ _sse_task: asyncio.Task[None] | None = None
62
+ _sse_failures: int = 0
63
+ _sse_last_event_id: str = ""
64
+ _sse_retry_ms: int = 0
65
 
66
  async def __aenter__(self) -> httpx.AsyncClient:
67
  if self._client:
68
+ assert self._client_ctx_depth > 0
69
+ self._client_ctx_depth += 1
70
  return self._client
71
+ assert self._client_ctx_depth == 0
72
  self._client = httpx.AsyncClient(verify=self.verify)
73
+ self._client_ctx_depth = 1
74
  return self._client
75
 
76
  async def __aexit__(self, *args: Any) -> None:
77
+ if (not self._client) or self._client_ctx_depth <= 0:
78
+ raise RuntimeError("unbalanced __aexit__")
79
+ self._client_ctx_depth -= 1
80
+ if self._client_ctx_depth == 0:
81
  await self._client.__aexit__(*args)
82
  self._client = None
83
 
 
92
  f"{self.uri}/auth/login",
93
  json={"username": self.user, "password": self.password},
94
  )
95
+ response.raise_for_status()
96
+ self.logger.debug(f"logged in as {self.user}")
97
+ self.token = response.json()["token"]
98
+
99
+ async def request(
100
+ self,
101
+ method: Literal["GET", "POST"],
102
+ url: str,
103
+ files: RequestFiles | None = None,
104
+ params: QueryParamTypes | None = None,
105
+ json: dict[str, Any] | None = None,
106
+ headers: Mapping[str, str] | None = None,
107
+ raise_for_status: bool = True,
108
+ ) -> httpx.Response:
109
+ async def _q() -> httpx.Response:
110
+ return await client.request(
111
+ method,
112
+ f"{self.uri}/{url}",
113
+ headers=dict(headers or {}) | self.auth_headers,
114
+ files=files,
115
+ params=params,
116
+ json=json,
117
+ )
118
 
 
119
  async with self as client:
120
+ r = await _q()
121
+ if r.status_code == 401:
122
+ self.logger.debug("renewing token")
123
+ await self.login()
124
+ r = await _q()
125
+
126
+ if raise_for_status:
127
+ r.raise_for_status()
128
+ return r
129
+
130
+ @classmethod
131
+ def decode_json(cls, data: str) -> dict[str, Any] | None:
132
+ try:
133
+ r = json.loads(data)
134
+ except json.JSONDecodeError:
135
+ return None
136
+ if type(r) is not dict:
137
+ return None
138
+ return cast(dict[str, Any], r)
139
+
140
+ async def _sse_loop(self) -> None:
141
+ response = await self.request("POST", "sub-auth")
142
  sub_token = response.json()["token"]
143
  url = f"{self.uri}/sub/{sub_token}"
144
+ headers = {"Accept": "text/event-stream"}
145
+ if self._sse_last_event_id:
146
+ retry_ms = self._sse_retry_ms + 1000 * 2**self._sse_failures
147
+ self.logger.info(f"resuming SSE from event {self._sse_last_event_id} in {retry_ms} ms")
148
+ await asyncio.sleep(retry_ms / 1000)
149
+ headers["Last-Event-ID"] = self._sse_last_event_id
150
  async with (
151
  httpx.AsyncClient(timeout=None, verify=self.verify) as c,
152
+ httpx_sse.aconnect_sse(c, "GET", url, headers=headers) as es,
153
  ):
154
+ es.response.raise_for_status()
155
+ self._sse_futures["_sse_loop"].set_result({"status": "ok"})
156
+ try:
157
+ async for sse in es.aiter_sse():
158
+ self._sse_last_event_id = sse.id
159
+ self._sse_retry_ms = sse.retry or 0
160
+ jdata = self.decode_json(sse.data)
161
+ if (jdata is None) or ("state" not in jdata):
162
+ # Note: when the server restarts we typically get an
163
+ # empty string here, then the loop exits.
164
+ self.logger.warning(f"unexpected SSE data: {sse.data}")
165
+ continue
166
+ self._sse_futures[jdata["state"]].set_result(jdata)
167
+ except asyncio.CancelledError:
168
+ pass
169
 
170
+ async def sse_start(self) -> None:
171
+ assert self._sse_task is None
172
+ self._sse_last_event_id = ""
173
+ self._sse_retry_ms = 0
174
+ self._sse_task = asyncio.create_task(self._sse_loop())
175
+ await self.sse_await("_sse_loop")
176
+ self._sse_failures = 0
177
+
178
+ async def sse_recover(self) -> bool:
179
+ while True:
180
+ if self._sse_failures > self.max_sse_failures:
181
+ return False
182
+ self._sse_task = asyncio.create_task(self._sse_loop())
183
+ try:
184
+ await self.sse_await("_sse_loop")
185
+ return True
186
+ except SSELoopStopped:
187
+ pass
188
+
189
+ async def sse_stop(self) -> None:
190
+ assert self._sse_task
191
+ self._sse_task.cancel()
192
+ await self._sse_task
193
+ self._sse_task = None
194
+
195
+ async def sse_await(self, state_id: str, timeout: float | None = None) -> None:
196
+ assert self._sse_task
197
+ future = self._sse_futures[state_id]
198
+
199
+ while True:
200
+ done, _ = await asyncio.wait(
201
+ {future, self._sse_task},
202
+ timeout=timeout or self.default_timeout,
203
+ return_when=asyncio.FIRST_COMPLETED,
204
  )
205
+ if not done:
206
+ raise TimeoutError(f"state {state_id} timed out after {timeout}")
207
+ if self._sse_task in done:
208
+ self._sse_failures += 1
209
+ if state_id != "_sse_loop" and (await self.sse_recover()):
210
+ self._sse_failures = 0
211
+ continue
212
+ exception = self._sse_task.exception()
213
+ raise SSELoopStopped(f"SSE loop stopped while waiting for state {state_id}") from exception
214
+ break
215
 
216
+ assert done == {future}
217
+
218
+ jdata = future.result()
219
+ del self._sse_futures[state_id]
220
+ assert jdata["status"] == "ok", f"state {state_id} is {jdata['status']}"
221
+
222
+ async def get_meta(self, state_id: str) -> dict[str, Any]:
223
+ response = await self.request("GET", f"state/meta/{state_id}")
224
+ return response.json()
225
+
226
+ async def _run_one[Tin, Tout](
227
  self,
228
  co: Callable[["EditorAPIContext", Tin], Awaitable[Tout]],
229
  params: Tin,
230
  ) -> Tout:
231
+ # This wraps the coroutine in the SSE loop.
232
+ # This is mostly useful if you use synchronous Python,
233
+ # otherwise you can call the functions directly.
234
+ if not self.token:
235
+ await self.login()
236
+ await self.sse_start()
237
+ try:
238
+ r = await co(self, params)
239
+ return r
240
+ finally:
241
+ await self.sse_stop()
 
 
 
242
 
243
  def run_one_sync[Tin, Tout](
244
  self,
 
250
  except RuntimeError:
251
  loop = asyncio.new_event_loop()
252
  asyncio.set_event_loop(loop)
253
+ return loop.run_until_complete(self._run_one(co, params))
254
 
255
+ async def call_skill(
256
+ self,
257
+ uri: str,
258
+ params: dict[str, Any] | None,
259
+ timeout: float | None = None,
260
+ ) -> str:
261
  params = {"priority": self.priority} | (params or {})
262
+ response = await self.request("POST", f"skills/{uri}", json=params)
 
 
 
 
 
 
263
  state_id = response.json()["state"]
264
+ await self.sse_await(state_id, timeout=timeout)
265
  return state_id
typings/pillow_heif/__init__.pyi ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ def register_heif_opener() -> None: ...
2
+ def register_avif_opener() -> None: ...