|
import hashlib |
|
import os |
|
from typing import Any, Dict, List |
|
|
|
import pandas as pd |
|
|
|
from llm_studio.src.datasets.conversation_chain_handler import get_conversation_chains |
|
from llm_studio.src.datasets.text_utils import get_tokenizer |
|
from llm_studio.src.plots.text_causal_language_modeling_plots import ( |
|
create_batch_prediction_df, |
|
plot_validation_predictions, |
|
) |
|
from llm_studio.src.utils.data_utils import read_dataframe_drop_missing_labels |
|
from llm_studio.src.utils.plot_utils import PlotData, format_for_markdown_visualization |
|
from llm_studio.src.utils.utils import PatchedAttribute |
|
|
|
|
|
class Plots:
    """Plotting hooks for datasets that pair a chosen and a rejected answer.

    Batches are decoded from the ``chosen_input_ids`` / ``chosen_labels``
    fields. The data view shows each prompt together with its chosen answer
    and its rejected answer.
    """

    # Longest conversation (number of prompts) shown in the data view.
    _MAX_CONVERSATION_LENGTH = 15
    # Number of example conversations shown for each conversation length.
    _SAMPLES_PER_LENGTH = 5

    @classmethod
    def plot_batch(cls, batch, cfg) -> PlotData:
        """Decode a batch and store it as a parquet table for visualization.

        Args:
            batch: Batch forwarded to ``create_batch_prediction_df``. It is
                passed with the ``chosen_input_ids`` / ``chosen_labels`` field
                names.
            cfg: Experiment config. The table is written to
                ``cfg.output_directory``.

        Returns:
            ``PlotData`` pointing at ``batch_viz.parquet`` with ``encoding="df"``.
        """
        tokenizer = get_tokenizer(cfg)
        df = create_batch_prediction_df(
            batch,
            tokenizer,
            ids_for_tokenized_text="chosen_input_ids",
            labels_column="chosen_labels",
        )
        path = os.path.join(cfg.output_directory, "batch_viz.parquet")
        df.to_parquet(path)
        return PlotData(path, encoding="df")

    @classmethod
    def plot_data(cls, cfg) -> PlotData:
        """Render sample training conversations as a scrollable table.

        Up to ``_SAMPLES_PER_LENGTH`` conversations are shown for every
        conversation length from 1 to ``_MAX_CONVERSATION_LENGTH`` prompts,
        which keeps the table small enough to render in Wave. A conversation
        contributes:

        * an optional ``System`` row;
        * a ``Prompt`` row per turn;
        * per turn, either one ``Answer`` row (chosen and rejected answers are
          identical) or an ``Answer Chosen`` / ``Answer Rejected`` pair.

        The visualization is instantiated on every page load, so the table is
        cached as a parquet file next to the training dataframe. The cache key
        is a hash of the dataframe path and the relevant column settings.

        An empty dataset produces an empty table instead of raising.

        Args:
            cfg: Experiment config. ``cfg.dataset`` settings select the data
                and columns; ``cfg`` is also forwarded to the data helpers.

        Returns:
            ``PlotData`` pointing at the cached parquet file (``encoding="df"``).
        """
        config_id = (
            str(cfg.dataset.train_dataframe)
            + str(cfg.dataset.system_column)
            + str(cfg.dataset.prompt_column)
            + str(cfg.dataset.answer_column)
            + str(cfg.dataset.rejected_answer_column)
            + str(cfg.dataset.parent_id_column)
        )
        # NOTE: the key covers config values only, not file contents, so
        # overwriting the dataframe file in place reuses a stale cache.
        config_hash = hashlib.md5(config_id.encode()).hexdigest()
        path = os.path.join(
            os.path.dirname(cfg.dataset.train_dataframe),
            f"__meta_info__{config_hash}_data_viz.parquet",
        )
        if os.path.exists(path):
            return PlotData(path, encoding="df")

        df = read_dataframe_drop_missing_labels(cfg.dataset.train_dataframe, cfg)

        conversations_chosen = get_conversation_chains(
            df, cfg, limit_chained_samples=True
        )
        # Build the same chains again, reading answers from the rejected-answer
        # column. The two lists are zipped below, which assumes they line up
        # conversation by conversation.
        with PatchedAttribute(
            cfg.dataset, "answer_column", cfg.dataset.rejected_answer_column
        ):
            conversations_rejected = get_conversation_chains(
                df, cfg, limit_chained_samples=True
            )

        conversations_to_display = cls._select_conversations_to_display(
            conversations_chosen, conversations_rejected
        )

        # Collect rows first and build the frame once; row-by-row .loc
        # enlargement re-allocates the DataFrame on every insert.
        rows: List[List[Any]] = []
        for sample_number, (chosen, rejected) in enumerate(
            conversations_to_display
        ):
            rows.extend(cls._conversation_rows(sample_number, chosen, rejected))

        df_transposed = pd.DataFrame(
            rows, columns=["Sample Number", "Field", "Content"]
        )
        df_transposed["Content"] = df_transposed["Content"].apply(
            format_for_markdown_visualization
        )
        df_transposed.to_parquet(path)
        return PlotData(path, encoding="df")

    @classmethod
    def _select_conversations_to_display(
        cls,
        conversations_chosen: List[Dict],
        conversations_rejected: List[Dict],
    ) -> List:
        """Pick example (chosen, rejected) pairs grouped by conversation length.

        Keeps the first ``_SAMPLES_PER_LENGTH`` pairs, in input order, for each
        chosen-conversation length between 1 and ``_MAX_CONVERSATION_LENGTH``.
        The pairs are returned ordered by ascending length. Longer
        conversations are skipped, and an empty input yields an empty list.
        """
        by_length: Dict[int, List] = {
            length: [] for length in range(1, cls._MAX_CONVERSATION_LENGTH + 1)
        }
        for chosen, rejected in zip(conversations_chosen, conversations_rejected):
            bucket = by_length.get(len(chosen["prompts"]))
            if bucket is not None and len(bucket) < cls._SAMPLES_PER_LENGTH:
                bucket.append((chosen, rejected))
        return [pair for length in sorted(by_length) for pair in by_length[length]]

    @staticmethod
    def _conversation_rows(
        sample_number: int, chosen: Dict, rejected: Dict
    ) -> List[List[Any]]:
        """Convert one (chosen, rejected) pair into [sample, field, content] rows."""
        rows: List[List[Any]] = []
        systems = chosen["systems"]
        # Only show the system prompt when one is present and non-empty.
        if len(systems) > 0 and systems[0] != "":
            rows.append([sample_number, "System", systems[0]])
        for prompt, answer_chosen, answer_rejected in zip(
            chosen["prompts"], chosen["answers"], rejected["answers"]
        ):
            rows.append([sample_number, "Prompt", prompt])
            if answer_chosen == answer_rejected:
                # Identical answers (e.g. earlier turns of a chain) are shown once.
                rows.append([sample_number, "Answer", answer_chosen])
            else:
                rows.append([sample_number, "Answer Chosen", answer_chosen])
                rows.append([sample_number, "Answer Rejected", answer_rejected])
        return rows

    @classmethod
    def plot_validation_predictions(
        cls, val_outputs: Dict, cfg: Any, val_df: pd.DataFrame, mode: str
    ) -> PlotData:
        """Delegate to the module-level ``plot_validation_predictions`` helper."""
        # Inside a method body the bare name resolves to the imported
        # module-level function, not this classmethod, so this is not recursive.
        return plot_validation_predictions(val_outputs, cfg, val_df, mode)
|
|