# NOTE: the following lines are page-scrape residue (Hugging Face Space
# listing: runtime status, file size, commit hashes, line-number gutter),
# preserved as a comment so the file parses as Python:
#   Spaces: Runtime error | File size: 3,299 Bytes
#   7f1450b 327d5b5 01f2f78 327d5b5 01f2f78 327d5b5 7f1450b 327d5b5 7f1450b 327d5b5 7f1450b 327d5b5
import gradio as gr
import pandas as pd
import re
from pyspark.sql import SparkSession, Window
import pyspark.sql.functions as F
from llama_cpp import Llama
from loguru import logger # Import the logger from loguru
# Function to read the text file and create Spark DataFrame
def create_spark_dataframe(text, min_chapter_length=100):
    """Split a book's text into chapters and return them as a Spark DataFrame.

    Parameters
    ----------
    text : str
        Full book text containing ``CHAPTER ...`` headings.
    min_chapter_length : int, optional
        Fragments shorter than this are discarded (filters out front matter
        and the table of contents). Defaults to 100, the original threshold.

    Returns
    -------
    pyspark.sql.DataFrame
        Two columns: ``text`` (chapter body) and ``chapter`` (1-based index).
    """
    # Split on chapter headings; keep only substantial fragments.
    # Raw string for the regex pattern per convention.
    chapter_list = [
        x for x in re.split(r'CHAPTER .+', text) if len(x) > min_chapter_length
    ]
    # Create Spark DataFrame
    spark = (
        SparkSession.builder
        .appName("Counting word occurrences from a book, under a microscope.")
        .config("spark.driver.memory", "4g")
        .getOrCreate()
    )
    spark.sparkContext.setLogLevel("WARN")
    # Build via pandas so Spark infers the schema from column dtypes.
    df = spark.createDataFrame(
        pd.DataFrame({
            'text': chapter_list,
            'chapter': range(1, len(chapter_list) + 1),
        })
    )
    return df
# Read the "War and Peace" text file and create Spark DataFrame.
# Explicit UTF-8: the text contains non-ASCII characters and the
# platform-default encoding is not reliable across systems.
with open('war_and_peace.txt', 'r', encoding='utf-8') as file:
    text = file.read()
df_chapters = create_spark_dataframe(text)
# Define the Llama models
# NOTE(review): both GGML v3 chat models are loaded eagerly at import
# time; each can consume several GB of RAM. n_ctx=8192 assumes the
# models support an extended context window — TODO confirm for
# llama-2-7b-chat ggmlv3.
MODEL_Q8_0 = Llama(model_path="llama-2-7b-chat.ggmlv3.q8_0.bin", n_ctx=8192, n_batch=512)
MODEL_Q2_K = Llama(model_path="llama-2-7b-chat.ggmlv3.q2_K.bin", n_ctx=8192, n_batch=512)
# Function to summarize a chapter using the selected model
def llama2_summarize(chapter_text, model_version):
    """Summarize a novel chapter in one sentence with a Llama 2 chat model.

    Parameters
    ----------
    chapter_text : str
        The chapter body to summarize.
    model_version : str
        Either ``"q8_0"`` or ``"q2_K"``; selects one of the preloaded models.

    Returns
    -------
    str
        The generated summary, or an error message when ``model_version``
        is not recognized.
    """
    # Choose the model based on the model_version parameter.
    if model_version == "q8_0":
        llm = MODEL_Q8_0
    elif model_version == "q2_K":
        llm = MODEL_Q2_K
    else:
        return "Error: Invalid model_version."
    # Llama-2 chat prompt template ([INST] / <<SYS>> format).
    template = """
[INST] <<SYS>>
You are a helpful, respectful and honest assistant.
Always answer as helpfully as possible, while being safe.
Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content.
Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct.
If you don't know the answer to a question, please don't share false information.
<</SYS>>
{INSERT_PROMPT_HERE} [/INST]
"""
    # Create prompt.
    prompt = 'Summarize the following novel chapter in a single sentence (less than 100 words): ' + chapter_text
    # BUG FIX: replace the full '{INSERT_PROMPT_HERE}' placeholder.
    # The original replaced only 'INSERT_PROMPT_HERE', leaving literal
    # braces wrapped around the user prompt in the text sent to the model.
    prompt = template.replace('{INSERT_PROMPT_HERE}', prompt)
    # Log the input chapter text and model_version.
    logger.info(f"Input chapter text: {chapter_text}")
    logger.info(f"Selected model version: {model_version}")
    # Generate summary using the selected model; low temperature/top_p
    # keeps the output deterministic-ish, max_tokens=-1 lets the model
    # run to its context limit.
    output = llm(prompt, max_tokens=-1, echo=False, temperature=0.2, top_p=0.1)
    summary = output['choices'][0]['text']
    # Log the generated summary.
    logger.info(f"Generated summary: {summary}")
    return summary
# Define the Gradio interface.
# NOTE(review): `gr.inputs.File` and `capture_session=` belong to the
# legacy Gradio 1.x/2.x API — confirm the pinned Gradio version still
# accepts them before upgrading.
# NOTE(review): the first input is a File component while
# `llama2_summarize` expects `chapter_text` as a string; a file upload
# will presumably hand the function a file object, not text — verify
# against Gradio's preprocessing behavior.
iface = gr.Interface(
    fn=llama2_summarize,
    inputs=[
        gr.inputs.File(label="Upload Text File"),
        "text",
    ],  # chapter_text, model_version
    outputs="text",  # Summary text
    live=False,
    capture_session=True,
    title="Llama2 Chapter Summarizer",
    description="Upload the text file or enter the chapter text and model version ('q8_0' or 'q2_K') to get a summarized sentence.",
)
# Script entry point: launch the Gradio app.
# (The trailing "; |" on the original launch line was a scraping
# artifact — a SyntaxError — and has been removed.)
if __name__ == "__main__":
    iface.launch()