Spaces:
Runtime error
Runtime error
from frame import Frame | |
from helper import OBJECT_MAP, get_hypernym_path | |
from numpy.polynomial import polynomial | |
from nltk.corpus import wordnet as wn | |
import json | |
import os | |
class NodeFrame:
    """A frame together with a list of probabilities and derived count statistics.

    Each entry of ``p_list`` is treated as the probability of one independent
    event (presumably one detection of an object in the frame -- confirm
    against the producer of the JSON data). From these the class derives the
    sum of the probabilities and the distribution of "exactly k events
    happened".
    """

    def __init__(self, frame: 'Frame', p_list: list[float]) -> None:
        """Store the inputs and precompute ``p_total`` and ``p_exactly``.

        Args:
            frame: The frame these probabilities belong to.
            p_list: One probability per event.
        """
        self.frame = frame
        self.p_list = p_list
        self.p_total = self.calculate_p_total(p_list)
        self.p_exactly = self.calculate_p_exactly(p_list)

    def calculate_p_total(self, p_list: list[float]) -> float:
        """Return the sum of all probabilities in ``p_list``."""
        return sum(p_list)

    def calculate_p_exactly(self, p_list: list[float]) -> list[float]:
        """Return the distribution of the number of events that occurred.

        Every probability ``p`` is turned into the polynomial ``(1 - p) + p*x``;
        the coefficient of ``x**k`` in the product of all of them is the
        probability that exactly ``k`` of the events occurred.

        Args:
            p_list: One probability per event.

        Returns:
            Coefficients indexed by count; ``[1]`` for an empty ``p_list``.
        """
        result = [1]
        for p in p_list:
            # NOTE(review): numpy's polymul trims trailing zero coefficients,
            # so the result can be shorter than len(p_list) + 1 when some
            # probabilities are exactly 0 -- confirm this is acceptable.
            result = polynomial.polymul(result, [1 - p, p])
        return list(result)

    def p_of(self, amount: int) -> float:
        """Return the probability assigned to exactly ``amount`` events.

        Args:
            amount: The number of events asked about.

        Returns:
            ``0.0`` for a negative ``amount``; the precomputed coefficient when
            ``amount`` is within ``p_exactly``; otherwise the last coefficient
            scaled by ``0.1`` for every step beyond the end of ``p_exactly``.
        """
        if amount < 0:
            # A negative count is impossible. Without this guard Python's
            # negative indexing would return a coefficient from the wrong end
            # of p_exactly (or raise IndexError for large magnitudes).
            return 0.0
        known = len(self.p_exactly)
        if amount < known:
            return self.p_exactly[amount]
        # Beyond the computed range: decay the last known value by a factor
        # of ten per extra event instead of dropping straight to zero.
        return self.p_exactly[-1] * (0.1 ** (amount - known + 1))

    def serialize(self) -> dict:
        """Return a dict with the serialized frame and the raw ``p_list``.

        The derived values are not included; they are recomputed from
        ``p_list`` when a new instance is constructed.
        """
        return {
            'frame': self.frame.serialize(),
            'p_list': self.p_list,
        }
class Node:
    """One vertex of the trie: the frames stored here plus links to children."""

    def __init__(self, node_frames: list[NodeFrame]) -> None:
        """Create a node holding ``node_frames`` and no children.

        Args:
            node_frames: The frames stored at this node. The list is kept by
                reference, not copied.
        """
        # Child nodes, keyed by the next word on the path (see Trie.insert).
        self.children = {}
        self.node_frames = node_frames
class Trie:
    """Prefix tree that stores NodeFrame entries under paths of words.

    Entries are inserted under a list of words (in ``load_from_dir`` this is
    the result of ``get_hypernym_path``). Searching a path returns every entry
    stored at that path or anywhere below it.
    """

    def __init__(self) -> None:
        """Create a trie whose root holds no frames and has no children."""
        self.root = Node([])

    def insert(self, node_frame: NodeFrame, path: list[str]) -> None:
        """Store ``node_frame`` at the node addressed by ``path``.

        Missing intermediate nodes are created. An empty ``path`` stores the
        entry on the root.

        Args:
            node_frame: The entry to store.
            path: Words leading from the root to the target node.
        """
        node = self.root
        for word in path:
            if word not in node.children:
                node.children[word] = Node([])
            node = node.children[word]
        node.node_frames.append(node_frame)

    def search(self, path: list[str]) -> list[NodeFrame]:
        """Return all entries stored at ``path`` or in its subtree.

        Args:
            path: Words leading from the root to the node to search from.

        Returns:
            The collected entries, or an empty list when ``path`` does not
            exist in the trie.
        """
        node = self.root
        for word in path:
            if word not in node.children:
                return []
            node = node.children[word]
        return self.search_all_children(node)

    def search_all_children(self, node: Node) -> list[NodeFrame]:
        """Collect the entries of ``node`` and all of its descendants.

        Args:
            node: The subtree root to collect from.

        Returns:
            A new list with the node's own entries first, followed by those of
            each child subtree in the children's iteration order.
        """
        result = list(node.node_frames)
        for child in node.children.values():
            result.extend(self.search_all_children(child))
        return result

    def load_from_dir(self, dir: str) -> None:
        """Insert the contents of every ``.json`` file found under ``dir``.

        Each file is read as a mapping of frame name to a mapping of object
        name to probability list. The file name without its ``.json`` suffix
        is used as the video name. Every (frame, object) pair becomes one
        NodeFrame, inserted under the object's hypernym path.

        Args:
            dir: Directory that is walked recursively.
        """
        for path, _, files in os.walk(dir):
            for file in files:
                if not file.endswith('.json'):
                    continue
                # Close the handle deterministically; many files may be read.
                with open(os.path.join(path, file), encoding='utf-8') as handle:
                    data = json.load(handle)
                video = file[:-5]  # strip the '.json' suffix
                for frame_name, frame_data in data.items():
                    for object_name, p_list in frame_data.items():
                        hypernym_path = get_hypernym_path(object_name)
                        frame = Frame(video=video, frame_name=frame_name)
                        self.insert(NodeFrame(frame, p_list), hypernym_path)

    def save_to_cache(self, cache_path: str) -> None:
        """Write the serialized trie to ``cache_path`` as JSON.

        Args:
            cache_path: Destination file; it is overwritten if it exists.
        """
        # The with-block guarantees the data is flushed and the file closed.
        with open(cache_path, 'w', encoding='utf-8') as handle:
            json.dump(self.serialize(), handle)

    def load_from_cache(self, cache_path: str) -> None:
        """Read a JSON file written by ``save_to_cache`` and insert its entries.

        Entries are added to whatever the trie already contains.

        Args:
            cache_path: File to read.
        """
        with open(cache_path, encoding='utf-8') as handle:
            self.deserialize(json.load(handle))

    def serialize(self) -> dict:
        """Return the trie as a flat, JSON-compatible dict.

        Returns:
            A mapping from the ``'/'``-joined path of every node that holds
            entries to the list of those entries in serialized form. Entries
            on the root are stored under the empty string.
        """
        output = {}

        def dfs(node: Node, path: list[str]) -> None:
            if node.node_frames:
                output['/'.join(path)] = [
                    node_frame.serialize() for node_frame in node.node_frames
                ]
            for word, child in node.children.items():
                dfs(child, path + [word])

        dfs(self.root, [])
        return output

    def deserialize(self, input):
        """Insert every entry of a dict produced by ``serialize``.

        Each frame is rebuilt from its serialized ``'id'`` and each NodeFrame
        from that frame plus the stored ``'p_list'``.

        Args:
            input: Mapping of ``'/'``-joined path to a list of serialized
                NodeFrame dicts.
        """
        for path, node_frames in input.items():
            # serialize() stores root entries under '', but ''.split('/')
            # yields [''], which would re-insert them under a child named ''.
            # Map the empty key back to the empty path instead.
            words = path.split('/') if path else []
            for node_frame in node_frames:
                frame = Frame(id=node_frame['frame']['id'])
                self.insert(NodeFrame(frame, node_frame['p_list']), words)