File size: 2,018 Bytes
e6bed4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import requests
import streamlit as st
import time 
from transformers import pipeline
import os
from .utils import query

# Hugging Face Inference API token, read from the environment (None when unset).
HF_AUTH_TOKEN = os.getenv('HF_AUTH_TOKEN')
# Request headers for the Inference API. The Authorization header is only added
# when a token is actually configured; otherwise a literal "Bearer None" (or an
# empty "Bearer ") would be sent as the credential.
headers = {"Authorization": f"Bearer {HF_AUTH_TOKEN}"} if HF_AUTH_TOKEN else {}

def write():
	"""Render the "News Title Generation" page.

	Draws sidebar controls for picking a TURNA title-generation model and its
	generation parameters, plus a text area for the input text. When the
	"Generate" button is pressed, the text and parameters are passed to
	``query`` together with the selected model's Inference API URL, and the
	returned value is shown with ``st.success``.
	"""
	st.markdown("# News Title Generation")
	st.sidebar.header("News Title Generation")
	st.write(
		"""Here, you can generate titles for your text in the news domain using the fine-tuned TURNA title generation models. """
	)

	# Sidebar
	# Taken from https://huggingface.co/spaces/flax-community/spanish-gpt2/blob/main/app.py
	st.sidebar.subheader("Configurable parameters")

	model_name = st.sidebar.selectbox(
		"Model Selector",
		options=[
			"turna_title_generation_tr_news",
			"turna_title_generation_mlsum"
		],
		index=0,
	)
	max_new_tokens = st.sidebar.number_input(
		"Maximum length",
		# At least one token is needed; a zero-token generation can never yield a title.
		min_value=1,
		max_value=64,
		value=64,
		help="The maximum length of the sequence to be generated.",
	)

	length_penalty = st.sidebar.number_input(
		"Length penalty",
		value=2.0,
		help=" length_penalty > 0.0 promotes longer sequences, while length_penalty < 0.0 encourages shorter sequences. ",
	)

	no_repeat_ngram_size = st.sidebar.number_input(
		"No Repeat N-Gram Size",
		min_value=0,
		value=3,
		help="If set to int > 0, all ngrams of that size can only occur once.",
	)

	input_text = st.text_area(label='Enter a text: ', height=100, 
			value="Kalp krizi geçirenlerin yaklaşık üçte birinin kısa bir süre önce grip atlattığı düşünülüyor. Peki grip virüsü ne yapıyor da kalp krizine yol açıyor? Karpuz şöyle açıkladı: Grip virüsü kanın yapışkanlığını veya pıhtılaşmasını artırıyor. ")

	# Each selectable model lives under the boun-tabi-LMG organisation on the Hub.
	url = "https://api-inference.huggingface.co/models/boun-tabi-LMG/" + model_name.lower()
	params = {
		"length_penalty": length_penalty,
		"no_repeat_ngram_size": no_repeat_ngram_size,
		"max_new_tokens": max_new_tokens,
	}

	if st.button("Generate"):
		# Don't spend a remote API call on an empty or whitespace-only input.
		if not input_text.strip():
			st.warning("Please enter a text to generate a title for.")
			return
		# The spinner only needs to cover the (slow) remote call.
		with st.spinner('Generating...'):
			output = query(input_text, url, params)
		st.success(output)