# CLIPictionary / sentence.py
# johko's picture
# Duplicate from YoannLemesle/CLIPictionary
# b222ec5
from model import encoder_text
import torch, clip, random
import numpy as np
# Device the tokenized text is moved to before being encoded (see
# compute_embeddings); fixed to the CPU.
device = torch.device("cpu")
from words import words
########## SENTENCE PART #######################################################
# Letters that turn a trailing article "a" into "an" (used by link_text).
voyelles = list("aeiou")
# Every link phrase of the vocabulary except the first key of `words`.
# NOTE(review): presumably the skipped key is the opening link
# "a drawing of a" that iniSentence always uses first -- confirm in words.py.
links = [*words.keys()][1:]
def link_text(part, nextWord):
    """Return the link phrase of *part*, adjusted to agree with *nextWord*.

    When the link ends with the article "a", the article is dropped for
    plural / article-less words and becomes "an" before a vowel. Any other
    link is returned unchanged.

    Args:
        part: Sentence part; only ``part["link"]`` is read.
        nextWord: The word that will follow the link.

    Returns:
        The link text, without a trailing space.
    """
    link = part["link"]
    # Only links ending with the article "a" ever need adjusting.
    if not link.endswith("a"):
        return link

    # Words written without an article even though they are not "-s" plurals.
    no_article = ("nothing", "hair", "vampire teeth", "something")
    # endswith() rather than nextWord[-1] / nextWord[-2]: an empty or
    # one-letter word must not raise IndexError.
    plural = nextWord.endswith("s") and not nextWord.endswith("ss")
    if plural or nextWord in no_article:
        # Strip the trailing " a".
        return link[:-2]

    # Guard on emptiness so an empty word is never treated as a vowel start.
    starts_with_vowel = bool(nextWord) and nextWord[0] in voyelles
    return (link + "n") if starts_with_vowel else link
def part_text(part):
    """Render a sentence part as "<link> <word>".

    The link is first adjusted to the part's own word by link_text; when the
    resulting link text is empty, only the word is returned.
    """
    adjusted_link = link_text(part, part["word"])
    if len(adjusted_link) > 0:
        return adjusted_link + " " + part["word"]
    return adjusted_link + part["word"]
def compute_embeddings(part, var_dict, prefix, batch_size=64):
    """Fill ``part["classes"]`` and ``part["embeddings"]`` for one sentence part.

    The candidate words are those listed in ``words`` for the part's link,
    minus the target word and the words already found. If there are more
    than ``batch_size - 1`` of them, that many are sampled without
    replacement. The target word is then appended, so it is always the last
    class.

    Args:
        part: Sentence part; ``"link"`` and ``"word"`` are read, ``"classes"``
            and ``"embeddings"`` are overwritten.
        var_dict: Game state; only ``var_dict["found_words"]`` is read.
        prefix: Text of the sentence built so far, prepended to every class.
        batch_size: Maximum number of classes, target included.
    """
    answer = part["word"]
    excluded = set([answer] + var_dict["found_words"])
    candidates = list(set(words[part["link"]]) - excluded)

    max_distractors = batch_size - 1
    if len(candidates) > max_distractors:
        sampled = np.random.choice(candidates, max_distractors, replace=False)
        candidates = sampled.tolist()
    candidates.append(answer)

    def as_class(word):
        # Full candidate sentence: prefix, link adjusted to the word, word.
        link = link_text(part, word)
        separator = " " if len(link) > 0 else ""
        return prefix + link + separator + word

    ### Compute all classes & embeddings for current sentence part
    part["classes"] = [as_class(word) for word in candidates]
    with torch.no_grad():
        tokens = clip.tokenize(part["classes"]).to(device)
        embeddings = encoder_text(tokens)
        # Scale each embedding to unit norm along the last dimension.
        embeddings /= embeddings.norm(dim=-1, keepdim=True)
    part["embeddings"] = embeddings.tolist()
########## SENTENCE ############################################################
def iniSentence(var_dict, input="", first_game=False):
    """Reset the game state and build a new target sentence.

    Args:
        var_dict: Game state, mutated in place. ``"difficulty"`` is read;
            ``"found_words"``, ``"parts"``, ``"step"`` and
            ``"target_sentence"`` are (re)written, then setState is applied.
        input: Not read by this function.
        first_game: When true, build the fixed sentence
            "a drawing of a cat with a face" instead of a random one.

    Returns:
        The target sentence.
    """
    var_dict["found_words"] = []
    var_dict["parts"] = []
    var_dict["step"] = 0
    # Difficulty 1 gives a two-part random sentence, anything else one part.
    part_count = 2 if var_dict["difficulty"] == 1 else 1

    def add_part(link, word, sentence):
        # Register the part, embed its candidate classes against the
        # sentence so far, and return the sentence extended by this part.
        part = {"link": link, "word": word, "classes": [], "embeddings": []}
        var_dict["parts"].append(part)
        compute_embeddings(part, var_dict, sentence)
        return sentence + part_text(part) + " "

    sentence = ""
    if first_game:
        for link, word in (("a drawing of a", "cat"), ("with a", "face")):
            sentence = add_part(link, word, sentence)
    else:
        ##### Generating Random Sentence
        opening = "a drawing of a"
        sentence = add_part(opening, np.random.choice(words[opening]), sentence)
        for _ in range(part_count - 1):
            link = np.random.choice(links)
            # The first word listed for the link is never drawn as a target.
            sentence = add_part(link, np.random.choice(words[link][1:]), sentence)

    # Target sentence is the accumulated text without the last space.
    var_dict["target_sentence"] = sentence[:-1]
    setState(var_dict)
    return var_dict["target_sentence"]
def prevState(var_dict):
    """Revert the game to the last recorded step, or to step 0 if none.

    Pops the most recent entry of ``var_dict["prev_steps"]``, flags the state
    as reverted and recomputes the derived fields through setState.
    """
    history = var_dict["prev_steps"]
    var_dict["step"] = history.pop() if history else 0
    var_dict["revertedState"] = True
    setState(var_dict)
def setState(var_dict):
    """Recompute the fields derived from ``var_dict["step"]``.

    ``"found_words"`` is cut down to its first ``step`` entries and
    ``"guessed_sentence"`` is rebuilt from the first ``step`` parts, each
    part's text being followed by a single space.
    """
    reached = var_dict["step"]
    var_dict["found_words"] = var_dict["found_words"][:reached]
    var_dict["guessed_sentence"] = ""
    pieces = [part_text(var_dict["parts"][idx]) + " " for idx in range(reached)]
    var_dict["guessed_sentence"] = "".join(pieces)
def updateState(var_dict, preds):
    """Advance the game by one step if the current part has been guessed.

    Args:
        var_dict: Game state, mutated in place (``"prev_steps"``,
            ``"revertedState"``, ``"step"``, ``"found_words"``, ``"win"`` and
            the fields recomputed by setState).
        preds: Predicted class strings; only the first three are examined.
            NOTE(review): a guess counts only when its index is greater than
            the index of a prediction containing "nothing", so a higher index
            is presumably a better rank -- confirm against the caller.

    Returns:
        1 if the sentence is (or already was) complete, 0 if the current
        part was guessed but parts remain, -1 if it was not guessed.
    """
    # Record the step for prevState, unless we just came back from a revert.
    if var_dict["revertedState"]:
        var_dict["revertedState"] = False
    else:
        var_dict["prev_steps"].append(var_dict["step"])

    # Already won: nothing left to guess. Return before indexing "parts",
    # because after a win step == len(parts) and the lookup would raise.
    if var_dict["win"]:
        return 1

    ### Check if the current part has been guessed
    part = var_dict["parts"][var_dict["step"]]
    target_class = part["classes"][-1]  # the target is always the last class
    top = preds[:3]  # tolerate fewer than three predictions
    idx_of_nothing = next(
        (i for i, pred in enumerate(top) if "nothing" in pred), -1
    )
    idx_of_guess = next(
        (i for i, pred in enumerate(top) if pred == target_class), -1
    )

    if idx_of_guess <= idx_of_nothing:
        return -1

    var_dict["step"] += 1
    var_dict["found_words"].append(part["word"])
    var_dict["win"] = var_dict["step"] == len(var_dict["parts"])
    setState(var_dict)
    return 1 if var_dict["win"] else 0