# Source: H2OTest / llm_studio/src/plots/text_dpo_modeling_plots.py
# Uploaded by elineve ("Upload 301 files", commit 07423df)
import hashlib
import os
from typing import Any, Dict, List
import pandas as pd
from llm_studio.src.datasets.conversation_chain_handler import get_conversation_chains
from llm_studio.src.datasets.text_utils import get_tokenizer
from llm_studio.src.plots.text_causal_language_modeling_plots import (
create_batch_prediction_df,
plot_validation_predictions,
)
from llm_studio.src.utils.data_utils import read_dataframe_drop_missing_labels
from llm_studio.src.utils.plot_utils import PlotData, format_for_markdown_visualization
from llm_studio.src.utils.utils import PatchedAttribute
class Plots:
    """Plotting utilities for DPO modeling (chosen vs. rejected answers).

    All plots are cached as parquet files and returned as ``PlotData``
    with ``encoding="df"`` so the UI renders them as scrollable tables.
    """

    @classmethod
    def plot_batch(cls, batch, cfg) -> PlotData:
        """Visualize one training batch.

        Decodes the *chosen* token ids/labels of the batch into a dataframe,
        writes it to ``<output_directory>/batch_viz.parquet`` and returns a
        ``PlotData`` handle to that file.
        """
        tokenizer = get_tokenizer(cfg)
        df = create_batch_prediction_df(
            batch,
            tokenizer,
            ids_for_tokenized_text="chosen_input_ids",
            labels_column="chosen_labels",
        )
        path = os.path.join(cfg.output_directory, "batch_viz.parquet")
        df.to_parquet(path)
        return PlotData(path, encoding="df")

    @classmethod
    def plot_data(cls, cfg) -> PlotData:
        """
        Plots the data in a scrollable table.

        We limit the number of rows to max 600 to avoid rendering issues in Wave.
        As the data visualization is instantiated on every page load, we cache the
        data visualization in a parquet file keyed by a hash of the relevant
        dataset settings.
        """
        # Cache key: any change to these settings must invalidate the cache.
        config_id = (
            str(cfg.dataset.train_dataframe)
            + str(cfg.dataset.system_column)
            + str(cfg.dataset.prompt_column)
            + str(cfg.dataset.answer_column)
            + str(cfg.dataset.rejected_answer_column)
            + str(cfg.dataset.parent_id_column)
        )
        config_hash = hashlib.md5(config_id.encode()).hexdigest()
        path = os.path.join(
            os.path.dirname(cfg.dataset.train_dataframe),
            f"__meta_info__{config_hash}_data_viz.parquet",
        )
        if os.path.exists(path):
            return PlotData(path, encoding="df")

        df = read_dataframe_drop_missing_labels(cfg.dataset.train_dataframe, cfg)
        conversations_chosen = get_conversation_chains(
            df, cfg, limit_chained_samples=True
        )
        # Re-chain the same dataframe with the rejected-answer column swapped in
        # so chosen/rejected conversations align pairwise by index.
        with PatchedAttribute(
            cfg.dataset, "answer_column", cfg.dataset.rejected_answer_column
        ):
            conversations_rejected = get_conversation_chains(
                df, cfg, limit_chained_samples=True
            )

        # Limit to max 15 prompt-conversation-answer rounds.
        # default=0 guards against an empty dataset (max() would raise).
        max_conversation_length = min(
            max(
                (len(conversation["prompts"]) for conversation in conversations_chosen),
                default=0,
            ),
            15,
        )

        # Pick at most 5 conversation pairs per conversation length.
        conversations_to_display: List = []
        for conversation_length in range(1, max_conversation_length + 1):
            conversations_to_display += [
                (conversation_chosen, conversation_rejected)
                for conversation_chosen, conversation_rejected in zip(
                    conversations_chosen, conversations_rejected
                )
                if len(conversation_chosen["prompts"]) == conversation_length
            ][:5]

        # Convert into a scrollable table by transposing the dataframe.
        # Collect rows first and build the dataframe once — repeated
        # ``df.loc[i] = ...`` appends are quadratic in pandas.
        rows: List[Dict[str, Any]] = []
        for sample_number, (conversation_chosen, conversation_rejected) in enumerate(
            conversations_to_display
        ):
            if conversation_chosen["systems"][0] != "":
                rows.append(
                    {
                        "Sample Number": sample_number,
                        "Field": "System",
                        "Content": conversation_chosen["systems"][0],
                    }
                )
            for prompt, answer_chosen, answer_rejected in zip(
                conversation_chosen["prompts"],
                conversation_chosen["answers"],
                conversation_rejected["answers"],
            ):
                rows.append(
                    {"Sample Number": sample_number, "Field": "Prompt", "Content": prompt}
                )
                # Identical answers are shown once; otherwise show both variants.
                if answer_chosen == answer_rejected:
                    rows.append(
                        {
                            "Sample Number": sample_number,
                            "Field": "Answer",
                            "Content": answer_chosen,
                        }
                    )
                else:
                    rows.append(
                        {
                            "Sample Number": sample_number,
                            "Field": "Answer Chosen",
                            "Content": answer_chosen,
                        }
                    )
                    rows.append(
                        {
                            "Sample Number": sample_number,
                            "Field": "Answer Rejected",
                            "Content": answer_rejected,
                        }
                    )

        df_transposed = pd.DataFrame(
            rows, columns=["Sample Number", "Field", "Content"]
        )
        df_transposed["Content"] = df_transposed["Content"].apply(
            format_for_markdown_visualization
        )
        df_transposed.to_parquet(path)
        return PlotData(path, encoding="df")

    @classmethod
    def plot_validation_predictions(
        cls, val_outputs: Dict, cfg: Any, val_df: pd.DataFrame, mode: str
    ) -> PlotData:
        """Delegate to the shared causal-LM validation-prediction plot."""
        return plot_validation_predictions(val_outputs, cfg, val_df, mode)