|
import jmespath |
|
import asyncio |
|
import json |
|
from urllib.parse import urlencode |
|
from typing import List, Dict |
|
from httpx import AsyncClient, Response |
|
from loguru import logger as log |
|
import torch |
|
import torch.nn as nn |
|
from transformers import AutoModel |
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
import gradio as gr |
|
|
|
# Browser-like default headers attached to every request sent through `client`.
_BROWSER_HEADERS = {
    "Accept-Language": "en-US,en;q=0.9",
    "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36",
    "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8",
    "Accept-Encoding": "gzip, deflate, br",
    "content-type": "application/json",
}

# Shared asynchronous HTTP client (HTTP/2 enabled) used by the scraping code.
client = AsyncClient(http2=True, headers=_BROWSER_HEADERS)
|
|
|
def parse_comments(response: Response) -> Dict:
    """Extract comment texts and the reported total from a comment-list response.

    Args:
        response: HTTP response whose body is expected to be a JSON object
            holding a ``comments`` list and a ``total`` count.

    Returns:
        A dict with two keys: ``comments``, a list of ``{"text": ...}`` dicts,
        and ``total_comments``, the payload's ``total`` value (0 when it is
        missing or null). Bodies that are not valid JSON, or are not a JSON
        object, produce an empty result instead of raising. Entries without a
        ``text`` value are dropped so callers never receive ``None`` text.
    """
    try:
        data = json.loads(response.text)
    except json.JSONDecodeError:
        log.error(f"Failed to parse JSON response: {response.text}")
        return {"comments": [], "total_comments": 0}

    # Valid JSON is not necessarily an object (e.g. ``null`` or a list);
    # calling ``.get`` on those would raise AttributeError.
    if not isinstance(data, dict):
        log.error(f"Unexpected JSON payload type: {type(data).__name__}")
        return {"comments": [], "total_comments": 0}

    # ``or`` also normalises explicit nulls, which ``.get(key, default)``
    # would pass through unchanged.
    comments_data = data.get("comments") or []
    total_comments = data.get("total") or 0

    if not comments_data:
        log.warning(f"No comments found in response: {response.text}")
        return {"comments": [], "total_comments": total_comments}

    parsed_comments = []
    for comment in comments_data:
        result = jmespath.search("{text: text}", comment)
        # The search yields None for a null entry and {"text": None} for an
        # entry lacking the field; neither is usable downstream.
        if not result or result.get("text") is None:
            continue
        parsed_comments.append(result)
    return {"comments": parsed_comments, "total_comments": total_comments}
|
|
|
async def scrape_comments(post_id: int, comments_count: int = 20, max_comments: int = None) -> List[Dict]:
    """Scrape the comments of a post through the paginated comment-list API.

    The first page is fetched on its own to learn the reported total; the
    remaining pages are then requested concurrently and appended in page
    order.

    Args:
        post_id: ID of the post, sent as the ``aweme_id`` query parameter.
        comments_count: Number of comments requested per page.
        max_comments: Optional upper bound on the number of comments
            returned; ``None`` (or 0) means no limit.

    Returns:
        A list of comment dicts as produced by ``parse_comments``; empty when
        the first page contains no comments.

    Raises:
        ValueError: If ``comments_count`` is not positive.
    """
    if comments_count <= 0:
        # A non-positive page size would otherwise surface later as a
        # ZeroDivisionError or an invalid ``range`` step.
        raise ValueError("comments_count must be a positive integer")

    def form_api_url(cursor: int) -> str:
        """Build the comment-list URL for the page starting at *cursor*."""
        base_url = "https://www.tiktok.com/api/comment/list/?"
        params = {
            "aweme_id": post_id,
            "count": comments_count,
            "cursor": cursor,
        }
        return base_url + urlencode(params)

    log.info(f"Scraping comments from post ID: {post_id}")
    first_page = await client.get(form_api_url(0))
    data = parse_comments(first_page)
    comments_data = data["comments"]
    # Guard against a null total so the comparisons below cannot fail.
    total_comments = data["total_comments"] or 0

    if not comments_data:
        log.warning(f"No comments found for post ID {post_id}")
        return []

    if max_comments and max_comments < total_comments:
        total_comments = max_comments

    # The cursor is used as an offset (assumption carried over from the
    # original pagination). The first page already covered
    # [0, comments_count), so the remaining cursors stop before the total
    # rather than one page past it.
    cursors = range(comments_count, total_comments, comments_count)
    log.info(f"Scraping comments pagination, remaining {len(cursors)} more pages")

    # gather() awaits every request and returns responses in request order,
    # so no task is left pending and the result order is deterministic.
    other_pages = await asyncio.gather(
        *(client.get(form_api_url(cursor=cursor)) for cursor in cursors)
    )
    for response in other_pages:
        comments_data.extend(parse_comments(response)["comments"])

    # Enforce the cap even when a single page already exceeds it.
    if max_comments:
        comments_data = comments_data[:max_comments]

    log.success(f"Scraped {len(comments_data)} comments from post ID {post_id}")
    return comments_data
|
|
|
class SentimentClassifier(nn.Module):
    """Pretrained "vinai/phobert-base" encoder with a dropout + linear head."""

    def __init__(self, n_classes):
        """Create the encoder and an ``n_classes``-wide linear output layer.

        Args:
            n_classes: Number of output classes (width of the final layer).
        """
        super().__init__()
        # Attribute names (bert / drop / fc) determine the state-dict keys,
        # so they must stay as they are for saved checkpoints to load.
        self.bert = AutoModel.from_pretrained("vinai/phobert-base")
        self.drop = nn.Dropout(p=0.3)
        hidden_size = self.bert.config.hidden_size
        self.fc = nn.Linear(hidden_size, n_classes)
        nn.init.normal_(self.fc.weight, std=0.02)
        # NOTE(review): only the mean is specified here, so the std stays at
        # normal_'s default - confirm a random bias init is really intended.
        nn.init.normal_(self.fc.bias, mean=0)

    def forward(self, input_ids, attention_mask):
        """Return the raw class scores for a batch of encoded inputs.

        Args:
            input_ids: Token-id tensor forwarded to the encoder.
            attention_mask: Attention-mask tensor forwarded to the encoder.

        Returns:
            The output of the linear head applied to the encoder's second
            output after dropout.
        """
        # With return_dict=False the encoder returns a tuple; only its second
        # element (presumably the pooled sentence representation) is used.
        _, pooled = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=False,
        )
        return self.fc(self.drop(pooled))
|
|
|
def infer(text, tokenizer, max_len=120):
    """Classify a single text and return the name of the predicted class.

    Args:
        text: The text to classify.
        tokenizer: Tokenizer exposing ``encode_plus``; used to encode *text*.
        max_len: Length the encoded input is truncated / padded to.

    Returns:
        The entry of the module-level ``class_names`` list selected by the
        highest-scoring model output.
    """
    encoded_review = tokenizer.encode_plus(
        text,
        max_length=max_len,
        truncation=True,
        add_special_tokens=True,
        padding='max_length',
        return_attention_mask=True,
        return_token_type_ids=False,
        return_tensors='pt',
    )

    input_ids = encoded_review['input_ids'].to(device)
    attention_mask = encoded_review['attention_mask'].to(device)

    # Dropout must be disabled for inference, otherwise the same text can get
    # different labels on different calls. Setting eval mode here keeps this
    # function correct regardless of how the model was prepared.
    model.eval()
    # No gradients are needed for prediction; skip building the autograd graph.
    with torch.no_grad():
        output = model(input_ids, attention_mask)

    # A single text is encoded, so the batch holds exactly one prediction;
    # convert it to a plain int before indexing the label list.
    y_pred = torch.argmax(output, dim=1)
    return class_names[y_pred.item()]
|
|
|
async def predict_comments(video_id):
    """Scrape a video's comments, classify each one and summarise the labels.

    Args:
        video_id: ID of the video; converted with ``int()`` before scraping.

    Returns:
        A dict with ``total_comments`` (number of classified comments) and
        ``label_percentages``, which holds the percentage of 'CLEAN',
        'OFFENSIVE' and 'HATE' comments plus the offending texts under
        'CMT OFFENSIVE' and 'CMT HATE'. All percentages are 0.0 when no
        comment could be classified.

    Raises:
        ValueError: If *video_id* cannot be converted to an integer.
    """
    comments = await scrape_comments(
        post_id=int(video_id),
        max_comments=2000,
        comments_count=20,
    )

    label_counts = [0, 0, 0]  # CLEAN, OFFENSIVE, HATE
    comment_off = []
    comment_hate = []
    for comment in comments:
        # Skip entries without text instead of crashing in the tokenizer.
        text = comment.get('text') if comment else None
        if text is None:
            continue

        label = infer(text, tokenizer)
        if label == 'CLEAN':
            label_counts[0] += 1
        elif label == 'OFFENSIVE':
            label_counts[1] += 1
            comment_off.append(text)
        else:
            label_counts[2] += 1
            comment_hate.append(text)

    total_comments = sum(label_counts)
    # Avoid ZeroDivisionError when nothing was scraped or classified.
    if total_comments:
        label_percentages = [count / total_comments * 100 for count in label_counts]
    else:
        label_percentages = [0.0, 0.0, 0.0]

    return {
        'total_comments': total_comments,
        'label_percentages': {
            'CLEAN': label_percentages[0],
            'OFFENSIVE': label_percentages[1],
            'HATE': label_percentages[2],
            'CMT OFFENSIVE': comment_off,
            'CMT HATE': comment_hate,
        },
    }
|
|
|
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model = SentimentClassifier(n_classes=3)
model.to(device)
# map_location remaps the checkpoint's tensors onto the selected device, so a
# checkpoint saved on a GPU still loads on a CPU-only host.
model.load_state_dict(torch.load('phobert_fold1.pth', map_location=device))
# The model is only used for prediction: eval mode disables dropout so the
# outputs are deterministic.
model.eval()

tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base")

# Label names indexed by predicted class id.
# NOTE(review): the order presumably mirrors the label ids used when the
# checkpoint was trained - confirm before changing it.
class_names = ['CLEAN', 'OFFENSIVE', 'HATE']
|
|
|
|
|
# Web UI: a single text input (the video ID) mapped to a JSON output produced
# by predict_comments.
iface = gr.Interface(fn=predict_comments, inputs="text", outputs="json")

# Start serving the interface as soon as the module is executed.
iface.launch()