import re

import pandas as pd
from tqdm.auto import tqdm
from transformers import AutoTokenizer, pipeline

model_checkpoint = "Pclanglais/French-TV-transcript-NER"
token_classifier = pipeline(
    "token-classification", model=model_checkpoint, aggregation_strategy="simple"
)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)


def split_text(text, max_tokens=500):
    # Split the text by newline characters and greedily pack the parts into
    # chunks of at most max_tokens tokens.
    parts = text.split("\n")
    chunks = []
    current_chunk = ""
    for part in parts:
        # Tentatively add the part to the current chunk.
        if current_chunk:
            temp_chunk = current_chunk + "\n" + part
        else:
            temp_chunk = part
        # Tokenize the temporary chunk to check its length.
        num_tokens = len(tokenizer.tokenize(temp_chunk))
        if num_tokens <= max_tokens:
            current_chunk = temp_chunk
        else:
            if current_chunk:
                chunks.append(current_chunk)
            current_chunk = part
    if current_chunk:
        chunks.append(current_chunk)

    # If no newlines were found and the text still exceeds max_tokens, split
    # at whitespace near the midpoint until every piece fits.
    if len(chunks) == 1 and len(tokenizer.tokenize(chunks[0])) > max_tokens:
        long_text = chunks[0]
        chunks = []
        while len(tokenizer.tokenize(long_text)) > max_tokens:
            split_point = len(long_text) // 2
            # Advance to the next whitespace character so words stay intact.
            while split_point < len(long_text) and not long_text[split_point].isspace():
                split_point += 1
            # Ensure split_point does not go out of range.
            if split_point >= len(long_text):
                split_point = len(long_text) - 1
            chunks.append(long_text[:split_point].strip())
            long_text = long_text[split_point:].strip()
        if long_text:
            chunks.append(long_text)
    return chunks


complete_data = pd.read_parquet("[file with transcripts]")
print(complete_data)

list_prompt = []
list_file = []
list_id = []
text_id = 1
for index, row in complete_data.iterrows():
    prompt, current_file = str(row["corrected_text"]), row["identifier"]
    # Tokenize the prompt and split it if it exceeds 500 tokens. Splitting is
    # done before the pilcrow substitution below, so that split_text can still
    # see the newline boundaries.
    if len(tokenizer.tokenize(prompt)) > 500:
        chunks = split_text(prompt, max_tokens=500)
    else:
        chunks = [prompt]
    for chunk in chunks:
        # Encode newlines as a visible " ¶ " marker so the model receives
        # paragraph boundaries as ordinary tokens.
        list_prompt.append(re.sub("\n", " ¶ ", chunk))
        list_file.append(current_file)
        list_id.append(text_id)
    text_id += 1

# Run the classifier over all chunks; the pipeline yields one result list per
# input prompt.
batch_size = 4
full_classification = []
for out in tqdm(token_classifier(list_prompt, batch_size=batch_size), total=len(list_prompt)):
    full_classification.append(out)

classified_list = []
for id_row, classification in enumerate(full_classification):
    if not classification:
        # No entities were found in this chunk.
        continue
    df = pd.DataFrame(classification)
    df["identifier"] = list_file[id_row]
    df["text_id"] = list_id[id_row]
    # Decode the pilcrow marker back into a newline.
    df["word"] = df["word"].str.replace(" ¶ ", "\n", regex=False)
    print(df)
    classified_list.append(df)

classified_list = pd.concat(classified_list)

# Display the DataFrame
print(classified_list)

classified_list.to_csv("result_transcripts.tsv", sep="\t")
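
# Optional sanity check (a sketch, one possible way to inspect the results):
# read the exported TSV back and count entity labels per transcript. The
# "entity_group" column is part of the token-classification pipeline's output
# when aggregation_strategy="simple" is used.
results = pd.read_csv("result_transcripts.tsv", sep="\t")
print(results.groupby("identifier")["entity_group"].value_counts())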