# Spaces: Running on Zero
import logging

import spaces  # imported before torch, as Hugging Face ZeroGPU recommends
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
# Constants
# Upper bound on tokens generated for each answer (passed as max_new_tokens).
MAX_RESULT_TOKENS = 10

# Load the implicit CoT model and its tokenizer from the same Hub checkpoint.
implicit_cot_model_name = 'yuntian-deng/implicit-cot-math-mistral7b'
tokenizer = AutoTokenizer.from_pretrained(implicit_cot_model_name)
implicit_cot_model = AutoModelForCausalLM.from_pretrained(
    implicit_cot_model_name,
    torch_dtype=torch.bfloat16,  # keep weights in bfloat16
)
# Place on the GPU when one is visible, otherwise the CPU, and switch to inference mode.
implicit_cot_model.to('cuda' if torch.cuda.is_available() else 'cpu').eval()
@spaces.GPU
def predict_answer(question):
    """Answer a math word problem with the implicit-CoT model (no reasoning steps emitted).

    The question's whitespace is collapsed to single spaces and the prompt is
    terminated with ``tokenizer.eos_token`` (the prompt format this checkpoint
    appears to expect -- confirm against the training code). Decoding is
    greedy (``do_sample=False``) and limited to ``MAX_RESULT_TOKENS`` new tokens.

    Args:
        question: Problem text from the Gradio "Question" textbox.

    Returns:
        The decoded continuation with special tokens stripped; ``''`` when the
        question is empty or whitespace-only; or the exception message as a
        string if tokenization/generation fails (the full traceback is logged),
        so the UI shows the error instead of crashing.
    """
    try:
        # split()/join() collapses newlines and runs of spaces; the result
        # already has no leading/trailing whitespace.
        normalized = ' '.join((question or '').split())
        if not normalized:
            # Nothing to ask: don't run the 7B model on a bare EOS prompt.
            return ''
        input_text = normalized + ' ' + tokenizer.eos_token
        print(input_text)

        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Moved on every call: under ZeroGPU a GPU is attached only while a
        # @spaces.GPU-decorated function is running.
        implicit_cot_model.to(device)
        inputs = tokenizer(input_text, return_tensors='pt').to(device)
        input_ids = inputs['input_ids']
        outputs = implicit_cot_model.generate(
            input_ids=input_ids,
            # Explicit mask and pad id: without them generate() warns that it
            # cannot infer the attention mask (presumably pad == eos here).
            attention_mask=inputs['attention_mask'],
            pad_token_id=tokenizer.eos_token_id,
            max_new_tokens=MAX_RESULT_TOKENS,
            do_sample=False,
        )
        # Decode only the newly generated tokens, not the echoed prompt.
        prediction = tokenizer.decode(
            outputs[0, input_ids.shape[-1]:], skip_special_tokens=True
        )
    except Exception as e:
        # UI boundary: keep the traceback in the Space logs and show only the
        # message to the user (same visible behavior as before).
        logging.exception('predict_answer failed')
        prediction = f'{e}'
    return prediction
def _build_demo():
    """Assemble the Gradio Interface around predict_answer.

    Returns:
        A gr.Interface with one "Question" textbox pre-filled with a sample
        problem, one "Implicit CoT Prediction" output textbox, no clear
        button, a "Get Answer!" submit button, and concurrency_limit=1.
    """
    sample_question = (
        "Asumi's bookshelf has 120 books. She has 10 books on history, "
        "twice that many books on literature, and the rest are science books. "
        "How many science books does Asumi have?"
    )
    question_box = gr.Textbox(label='Question', value=sample_question)
    answer_box = gr.Textbox(label='Implicit CoT Prediction')

    # Markdown shown below the interface: one bullet per reference link,
    # framed by a leading and trailing newline.
    reference_links = (
        '- [Paper 1: Implicit Chain of Thought Reasoning via Knowledge Distillation](https://arxiv.org/pdf/2311.01460)',
        '- [Paper 2: From Explicit CoT to Implicit CoT: Learning to Internalize CoT Step by Step](https://arxiv.org/pdf/2405.14838)',
        '- [Code Repository](https://github.com/da03/Internalize_CoT_Step_by_Step)',
        '- [Tweet Announcement](https://twitter.com/yuntiandeng/status/1795854740879774036)',
    )
    article_markdown = '\n' + '\n'.join(reference_links) + '\n'

    return gr.Interface(
        fn=predict_answer,
        inputs=[question_box],
        outputs=[answer_box],
        title='Solving Grade School Math Problems without Intermediate Reasoning Steps',
        description=(
            "This demo showcases Mistral-7B's ability to solve grade school math "
            "problems without producing intermediate steps, using our stepwise "
            "internalization approach linked below."
        ),
        article=article_markdown,
        clear_btn=None,
        submit_btn="Get Answer!",
        live=False,
        concurrency_limit=1,
    )


demo = _build_demo()
# Allow at most 5 pending requests in the queue, then start serving.
demo.queue(max_size=5)
demo.launch()