File size: 1,804 Bytes
bfc0ec6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
"""Utils for routers."""

import traceback
from typing import Callable, Iterable, Optional

from fastapi import HTTPException, Request, Response
from fastapi.exceptions import RequestValidationError
from fastapi.routing import APIRoute

from .auth import UserInfo
from .concepts.db_concept import DISK_CONCEPT_DB, DISK_CONCEPT_MODEL_DB
from .schema import Item, RichData
from .signals.concept_scorer import ConceptSignal


class RouteErrorHandler(APIRoute):
  """Custom APIRoute that turns unexpected route errors into HTTP 500 responses.

  Exceptions FastAPI already knows how to render (`HTTPException` and
  `RequestValidationError`) are re-raised untouched, so the app's exception handlers
  still produce their intended status codes (e.g. FastAPI's default 422 for request
  validation errors). Any other exception is printed together with its traceback and
  re-raised as an `HTTPException` with status 500 whose `detail` is that traceback.
  """

  def get_route_handler(self) -> Callable:
    """Return the parent route handler wrapped with the error translation described above."""
    original_route_handler = super().get_route_handler()

    async def custom_route_handler(request: Request) -> Response:
      try:
        return await original_route_handler(request)
      except (HTTPException, RequestValidationError):
        # These already map to meaningful responses; a bare `raise` preserves the original
        # traceback instead of appending this frame to it.
        raise
      except Exception as ex:
        trace = traceback.format_exc()
        print('Route error:', request.url)
        print(ex)
        print(trace)

        # Wrap the error into a pretty 500 exception.
        # NOTE(review): the full server traceback is sent to the client as `detail`; consider
        # hiding it outside development.
        raise HTTPException(status_code=500, detail=trace) from ex

    return custom_route_handler


def server_compute_concept(signal: ConceptSignal, examples: Iterable[RichData],
                           user: Optional[UserInfo]) -> list[Optional[Item]]:
  """Score `examples` with a concept signal on behalf of a REST endpoint.

  Args:
    signal: Concept signal; its `namespace`, `concept_name` and `embedding` select the
      concept and model to use.
    examples: Inputs to score. Falsy entries (e.g. None) are replaced by ''.
    user: Requesting user, forwarded to the concept and concept-model databases.

  Returns:
    A list holding whatever `signal.compute` yields for the examples.

  Raises:
    HTTPException: 404 when the concept lookup returns a falsy value.
  """
  # TODO(nsthorat): Move this to the setup() method in the concept_scorer.
  namespace, concept_name = signal.namespace, signal.concept_name
  if not DISK_CONCEPT_DB.get(namespace, concept_name, user):
    raise HTTPException(
      status_code=404, detail=f'Concept "{namespace}/{concept_name}" was not found')

  # Sync the concept model for this embedding before scoring (create=True presumably
  # creates it when missing — confirm against DISK_CONCEPT_MODEL_DB.sync).
  DISK_CONCEPT_MODEL_DB.sync(namespace, concept_name, signal.embedding, user=user, create=True)

  inputs = [ex if ex else '' for ex in examples]
  scores = signal.compute(inputs)
  return list(scores)