|
import hashlib |
|
import os |
|
from typing import Any, Dict |
|
|
|
import pandas as pd |
|
|
|
from llm_studio.src.datasets.conversation_chain_handler import get_conversation_chains |
|
from llm_studio.src.datasets.text_utils import get_tokenizer |
|
from llm_studio.src.utils.data_utils import read_dataframe_drop_missing_labels |
|
from llm_studio.src.utils.plot_utils import ( |
|
PlotData, |
|
format_for_markdown_visualization, |
|
list_to_markdown_representation, |
|
) |
|
|
|
|
|
class Plots:
    """Plot helpers that persist tabular visualizations as parquet files.

    Each method returns a ``PlotData`` pointing at a parquet file, created
    with ``encoding="df"``.
    """

    @classmethod
    def plot_batch(cls, batch, cfg) -> PlotData:
        """Write a decoded view of ``batch`` to ``batch_viz.parquet``.

        The batch is converted to a table by ``create_batch_prediction_df``
        using the tokenizer obtained from ``get_tokenizer(cfg)``; the file is
        stored in ``cfg.output_directory``.
        """
        tokenizer = get_tokenizer(cfg)
        df = create_batch_prediction_df(batch, tokenizer)
        path = os.path.join(cfg.output_directory, "batch_viz.parquet")
        df.to_parquet(path)
        return PlotData(path, encoding="df")

    @classmethod
    def plot_data(cls, cfg) -> PlotData:
        """
        Plots a sample of the training data in a scrollable table.

        Only a limited number of conversations is shown to avoid rendering
        issues in Wave: for every conversation length from 1 up to 15 prompts,
        at most the first 5 conversations of that length are displayed.
        As the data visualization is instantiated on every page load, we cache
        the data visualization in a parquet file stored next to the training
        dataframe. The cache file name contains a hash of the dataframe path
        and the system/prompt/answer/parent-id column settings.
        """
        max_displayed_length = 15
        max_samples_per_length = 5

        config_id = (
            str(cfg.dataset.train_dataframe)
            + str(cfg.dataset.system_column)
            + str(cfg.dataset.prompt_column)
            + str(cfg.dataset.answer_column)
            + str(cfg.dataset.parent_id_column)
        )
        # md5 is only used to derive a short, stable cache key here.
        config_hash = hashlib.md5(config_id.encode()).hexdigest()
        path = os.path.join(
            os.path.dirname(cfg.dataset.train_dataframe),
            f"__meta_info__{config_hash}_data_viz.parquet",
        )
        if os.path.exists(path):
            return PlotData(path, encoding="df")

        df = read_dataframe_drop_missing_labels(cfg.dataset.train_dataframe, cfg)

        conversations = get_conversation_chains(df, cfg, limit_chained_samples=True)

        # ``default=0`` keeps an empty dataset from raising ValueError; the
        # result is then simply an empty table.
        max_conversation_length = min(
            max(
                (len(conversation["prompts"]) for conversation in conversations),
                default=0,
            ),
            max_displayed_length,
        )

        # Single pass: keep the first few conversations of every length,
        # preserving their original order within each length.
        buckets: Dict[int, Any] = {}
        for conversation in conversations:
            length = len(conversation["prompts"])
            if 1 <= length <= max_conversation_length:
                bucket = buckets.setdefault(length, [])
                if len(bucket) < max_samples_per_length:
                    bucket.append(conversation)

        conversations_to_display = [
            conversation
            for length in sorted(buckets)
            for conversation in buckets[length]
        ]

        # Collect the rows first and build the frame once; growing a
        # DataFrame row by row copies it on every insertion.
        rows = []
        for sample_number, conversation in enumerate(conversations_to_display):
            system = conversation["systems"][0]
            if system != "":
                rows.append([sample_number, "System", system])
            for prompt, answer in zip(
                conversation["prompts"], conversation["answers"]
            ):
                rows.append([sample_number, "Prompt", prompt])
                rows.append([sample_number, "Answer", answer])

        df_transposed = pd.DataFrame(
            rows, columns=["Sample Number", "Field", "Content"]
        )
        df_transposed["Content"] = df_transposed["Content"].apply(
            format_for_markdown_visualization
        )

        df_transposed.to_parquet(path)

        return PlotData(path, encoding="df")

    @classmethod
    def plot_validation_predictions(
        cls, val_outputs: Dict, cfg: Any, val_df: pd.DataFrame, mode: str
    ) -> PlotData:
        """Delegate to the module-level ``plot_validation_predictions``."""
        return plot_validation_predictions(val_outputs, cfg, val_df, mode)
|
|
|
|
|
def plot_validation_predictions(
    val_outputs: Dict, cfg: Any, val_df: pd.DataFrame, mode: str
) -> PlotData:
    """Write a table of validation inputs, targets and predictions to parquet.

    Args:
        val_outputs: Validation outputs. The keys read here are
            ``"predicted_text"``, ``"metrics"`` and ``"explanations"``; each
            of them is optional.
        cfg: Experiment configuration. Reads ``cfg.dataset.prompt_column``,
            ``cfg.dataset.answer_column``,
            ``cfg.dataset.limit_chained_samples``, ``cfg.prediction.metric``
            and ``cfg.output_directory``.
        val_df: Validation dataframe passed to ``get_conversation_chains``.
        mode: Prefix of the output file name (``f"{mode}_viz.parquet"``).

    Returns:
        ``PlotData`` referencing the written parquet file.

    If the table has more than 900 rows it is reduced to 900: with metrics
    available, the 300 lowest-scoring rows, the 300 highest-scoring rows and
    a random sample of 300 rows in between are kept; otherwise 900 rows are
    sampled at random (fixed seed).
    """
    max_rows = 900
    rows_per_section = 300

    conversations = get_conversation_chains(
        val_df, cfg, limit_chained_samples=cfg.dataset.limit_chained_samples
    )
    prompt_column = cfg.dataset.prompt_column
    prompt_column_name = (
        prompt_column if len(prompt_column) > 1 else prompt_column[0]
    )

    target_texts = [conversation["answers"][-1] for conversation in conversations]

    input_texts = []
    for conversation in conversations:
        # The final answer is the prediction target, so it is blanked in the
        # displayed input. Build a new list instead of mutating the
        # conversation's own answers.
        answers = list(conversation["answers"][:-1]) + [""]
        turns = [
            f" **{prompt_column_name}:** "
            f"{prompt}\n\n"
            f"**{cfg.dataset.answer_column}:** "
            f"{answer}\n\n"
            for prompt, answer in zip(conversation["prompts"], answers)
        ]
        input_texts.append(conversation["systems"][0] + "".join(turns))

    if "predicted_text" in val_outputs:
        predicted_texts = val_outputs["predicted_text"]
    else:
        predicted_texts = [
            "No predictions are generated for the selected metric"
        ] * len(target_texts)

    input_text_column_name = (
        "Input Text (tokenization max length setting "
        "may truncate the input text during training/inference)"
    )
    df = pd.DataFrame(
        {
            input_text_column_name: input_texts,
            "Target Text": target_texts,
            "Predicted Text": predicted_texts,
        }
    )
    for column in (input_text_column_name, "Target Text", "Predicted Text"):
        df[column] = df[column].apply(format_for_markdown_visualization)

    metric_column_name = None
    if val_outputs.get("metrics") is not None:
        metric_column_name = f"Metric ({cfg.prediction.metric})"
        df[metric_column_name] = val_outputs["metrics"]
        df[metric_column_name] = df[metric_column_name].round(decimals=3)

    # Attach explanations BEFORE any subsampling: they are given per original
    # row, so assigning them to an already reduced frame would fail with a
    # length mismatch (or pair them with the wrong rows).
    if val_outputs.get("explanations") is not None:
        df["Explanation"] = val_outputs["explanations"]

    if len(df) > max_rows:
        if metric_column_name is not None:
            df = df.sort_values(by=metric_column_name)
            df = pd.concat(
                [
                    df.iloc[:rows_per_section],
                    df.iloc[rows_per_section:-rows_per_section].sample(
                        n=rows_per_section, random_state=42
                    ),
                    df.iloc[-rows_per_section:],
                ]
            ).reset_index(drop=True)
        else:
            df = df.sample(n=max_rows, random_state=42).reset_index(drop=True)

    path = os.path.join(cfg.output_directory, f"{mode}_viz.parquet")
    df.to_parquet(path)
    return PlotData(data=path, encoding="df")
|
|
|
|
|
def create_batch_prediction_df(
    batch,
    tokenizer,
    ids_for_tokenized_text="input_ids",
    labels_column="labels",
    max_rows=2000,
):
    """Build a long-format table describing the samples of a batch.

    Args:
        batch: Mapping of tensors. ``batch["prompt_input_ids"]`` and
            ``batch[ids_for_tokenized_text]`` are required;
            ``batch[labels_column]`` is optional.
        tokenizer: Object providing ``decode``, ``convert_ids_to_tokens`` and
            ``pad_token``.
        ids_for_tokenized_text: Key of the token ids shown as tokenized text.
        labels_column: Key of the labels; entries equal to ``-100`` are
            treated as masked out.
        max_rows: Maximum number of samples taken from the batch
            (``None`` disables the limit).

    Returns:
        DataFrame with the columns ``"Sample Number"``, ``"Field"`` and
        ``"Content"``. Every sample contributes a ``"Prompt Text"`` row, an
        ``"Answer Text"`` row (only when ``labels_column`` is in the batch)
        and a ``"Tokenized Text"`` row, in that order.
    """

    def _to_numpy(tensor):
        # Truncate before decoding so no work is spent on dropped samples.
        return tensor.detach().cpu().numpy()[:max_rows]

    prompt_ids = _to_numpy(batch["prompt_input_ids"])
    token_ids = _to_numpy(batch[ids_for_tokenized_text])
    has_answer = labels_column in batch
    # Without labels the token ids themselves serve as mask source.
    label_ids = _to_numpy(batch[labels_column]) if has_answer else token_ids

    prompt_texts = [
        format_for_markdown_visualization(
            tokenizer.decode(input_ids, skip_special_tokens=True)
        )
        for input_ids in prompt_ids
    ]

    answer_texts = []
    if has_answer:
        # NOTE(review): unlike the prompt text, the answer text is not passed
        # through format_for_markdown_visualization - confirm this is intended.
        answer_texts = [
            tokenizer.decode(
                [label for label in labels if label != -100],
                skip_special_tokens=True,
            )
            for labels in label_ids
        ]

    tokenized_texts = [
        list_to_markdown_representation(
            tokenizer.convert_ids_to_tokens(input_ids),
            [label != -100 for label in labels],
            pad_token=tokenizer.pad_token,
            num_chars=100,
        )
        for input_ids, labels in zip(token_ids, label_ids)
    ]

    # Collect all rows and build the frame once instead of enlarging a
    # DataFrame row by row.
    rows = []
    for sample_number, prompt_text in enumerate(prompt_texts):
        rows.append([sample_number, "Prompt Text", prompt_text])
        if has_answer:
            rows.append([sample_number, "Answer Text", answer_texts[sample_number]])
        rows.append(
            [sample_number, "Tokenized Text", tokenized_texts[sample_number]]
        )

    return pd.DataFrame(rows, columns=["Sample Number", "Field", "Content"])
|
|