import os

from llm_studio.src.datasets.text_utils import get_tokenizer
from llm_studio.src.plots.text_causal_language_modeling_plots import (
    Plots as TextCausalLanguageModelingPlots,
)
from llm_studio.src.plots.text_causal_language_modeling_plots import (
    create_batch_prediction_df,
)
from llm_studio.src.utils.plot_utils import PlotData


class Plots(TextCausalLanguageModelingPlots):
    @classmethod
    def plot_batch(cls, batch, cfg) -> PlotData:
        # Decode the batch into a per-sample DataFrame using the prompt token ids.
        tokenizer = get_tokenizer(cfg)
        df = create_batch_prediction_df(
            batch, tokenizer, ids_for_tokenized_text="prompt_input_ids"
        )
        # Persist the visualization as parquet and return its path as a DataFrame-encoded plot.
        path = os.path.join(cfg.output_directory, "batch_viz.parquet")
        df.to_parquet(path)
        return PlotData(path, encoding="df")