import Plotly from 'plotly.js-basic-dist-min';
import Papa from 'papaparse';
import _ from 'lodash';
import { getColor } from './colors.mjs';

/** Display name shown in the language selector -> code used in `taskLists` keys and data paths. */
const languageMap = {
  'Arabic': 'ar',
  'Turkish': 'tr',
  'Swahili': 'sw',
  'Russian': 'ru',
  'Telugu': 'te',
  'Thai': 'th',
  'Chinese': 'zh',
  'French': 'fr',
  'Hindi': 'hi',
};

/** Substring looked up in a raw run name -> label shown in the legend (see `processRunName`). */
const runNameMap = {
  "orion": "Dataset-A",
  "helios": "Dataset-B",
  "lynx": "Dataset-C",
  "aquila": "Dataset-D",
  "commoncrawl": "Dataset-E",
  "baseline": "Baseline",
};

/** Task identifiers per language code; each identifier is also the CSV file-name prefix. */
const taskLists = {
  ar: [
    'acva_ara:_average', 'alfgahafa_mlqa_ara_cf', 'alghafa_arc_ara_cf:easy', 'alghafa_facts_ara_cf',
    'alghafa_meta_dialects_ara_cf', 'alghafa_mmlu_ara_cf:_average', 'alghafa_openbookqa_ara_cf',
    'alghafa_piqa_ara_cf', 'alghafa_race_ara_cf', 'alghafa_rating_sentiment_ara_cf',
    'alghafa_rating_sentiment_no_neutral_ara_cf', 'alghafa_sciqa_ara_cf', 'alghafa_sentiment_ara_cf',
    'arcd_ara', 'belebele_arb_Arab_cf', 'boolq_ara', 'exams_ara_cf:_average', 'mkqa_ara:_average',
    'mlmm_arc_ara_cf:challenge', 'mlmm_hellaswag_ara_cf', 'mlmm_mmlu_ara_cf:_average',
    'mlmm_truthfulqa_ara_cf:mc1', 'mlmm_truthfulqa_ara_cf:mc2', 'mlqa_ara', 'mmlu_ara_cf:_average',
    'soqal_ara_cf', 'toxigen_ara_cf', 'tydiqa_ara', 'xcodah_ara_cf', 'xcopa_ara_cf', 'xcsqa_ara_cf',
    'xnli2.0_ara_cf', 'xnli_ara_cf', 'xquad_ara', 'xstory_cloze_ara_cf',
  ],
  fr: [
    'belebele_fra_Latn_cf', 'community_boolq_fra_cf', 'exams_fra_cf:_average', 'fquadv2_fra',
    'frenchbench_arc_fra_cf:challenge', 'frenchbench_hellaswag_fra_cf', 'meta_mmlu_fra_cf:_average',
    'mintaka_fra', 'mkqa_fra:_average', 'mlmm_arc_fra_cf:challenge', 'mlmm_hellaswag_fra_cf',
    'mlmm_mmlu_fra_cf:_average', 'mlmm_truthfulqa_fra_cf:mc1', 'mlmm_truthfulqa_fra_cf:mc2',
    'pawsx_fra_cf', 'xcodah_fra_cf', 'xcsqa_fra_cf', 'xnli2.0_fra_cf', 'xwinograd_fra_cf',
  ],
  hi: [
    'belebele_hin_Deva_cf', 'community_arc_hin_cf:challenge', 'community_arc_hin_cf:easy',
    'community_boolq_hin', 'community_hellaswag_hin_cf', 'indicnxnli_hin_cf', 'indicqa_hin',
    'indicxcopa_hin_cf', 'meta_mmlu_hin_cf:_average', 'mintaka_hin', 'mlmm_arc_hin_cf:challenge',
    'mlmm_hellaswag_hin_cf', 'mlmm_mmlu_hin_cf:_average', 'mlmm_truthfulqa_hin_cf:mc1',
    'mlmm_truthfulqa_hin_cf:mc2', 'mlqa_hin', 'xcodah_hin_cf', 'xcsqa_hin_cf', 'xnli2.0_hin_cf',
    'xnli_hin_cf', 'xquad_hin', 'xstory_cloze_hin_cf',
  ],
  ru: [
    'belebele_rus_Cyrl_cf', 'chegeka_rus', 'mathlogic_qa_rus_cf', 'mera_openbookqa_rus_cf',
    'mera_worldtree_rus_cf', 'mkqa_rus:_average', 'mlmm_arc_rus_cf:challenge', 'mlmm_hellaswag_rus_cf',
    'mlmm_mmlu_rus_cf:_average', 'mlmm_truthfulqa_rus_cf:mc1', 'mlmm_truthfulqa_rus_cf:mc2',
    'parus_rus_cf', 'rcb_rus_cf', 'rummlu_rus_cf:_average', 'sber_squad_rus', 'tydiqa_rus',
    'xcodah_rus_cf', 'xcsqa_rus_cf', 'xnli2.0_rus_cf', 'xquad_rus', 'xstory_cloze_rus_cf',
    'xwinograd_rus_cf',
  ],
  sw: [
    'afric_mmlu_swa_cf:_average', 'afric_xnli_swa_cf', 'belebele_swh_Latn_cf',
    'community_arc_swa_cf:challenge', 'community_arc_swa_cf:easy', 'community_mmlu_swa_cf',
    'kenswquad_swa', 'm3exams_swa_cf', 'openai_mmlu_swa_cf:_average', 'tydiqa_swa', 'xcodah_swa_cf',
    'xcopa_swa_cf', 'xcsqa_swa_cf', 'xnli2.0_swa_cf', 'xnli_swa_cf', 'xstory_cloze_swa_cf',
  ],
  te: [
    'belebele_tel_Telu_cf', 'community_hellaswag_tel_cf', 'indicnxnli_tel_cf', 'indicqa_tel',
    'indicxcopa_tel_cf', 'mlmm_arc_tel_cf:challenge', 'mlmm_hellaswag_tel_cf',
    'mlmm_mmlu_tel_cf:_average', 'mlmm_truthfulqa_tel_cf:mc1', 'mlmm_truthfulqa_tel_cf:mc2',
    'tydiqa_tel', 'xstory_cloze_tel_cf',
  ],
  th: [
    'belebele_tha_Thai_cf', 'community_hellaswag_tha_cf', 'm3exams_tha_cf', 'meta_mmlu_tha_cf:_average',
    'mkqa_tha:_average', 'thai_exams_tha_cf:_average', 'thai_exams_tha_cf:tgat', 'thaiqa_tha',
    'wsci_tha_cf', 'xcopa_tha_cf', 'xnli2.0_tha_cf', 'xnli_tha_cf', 'xquad_tha',
  ],
  tr: [
    'belebele_tur_Latn_cf', 'community_arc_tur_cf:easy', 'community_hellaswag_tur_cf',
    'community_mmlu_tur_cf:_average', 'community_truthfulqa_tur_cf:mc1',
    'community_truthfulqa_tur_cf:mc2', 'community_xwinograd_tur_cf', 'exams_tur_cf:_average',
    'mkqa_tur:_average', 'tquadv2_tur', 'xcopa_tur_cf', 'xnli2.0_tur_cf', 'xnli_tur_cf', 'xquad_tur',
  ],
  zh: [
    'agieval_zho_cf:_average', 'belebele_zho_Hans_cf', 'c3_zho_cf', 'ceval_zho_cf:_average',
    'chinese_squad_zho', 'cmath_zho_cf', 'cmmlu_zho_cf:_average', 'cmnli_zho_cf', 'cmrc2018_zho',
    'm3exams_zho_cf', 'mkqa_zho:_average', 'mlmm_arc_zho_cf:challenge', 'mlmm_hellaswag_zho_cf',
    'mlmm_mmlu_zho_cf:_average', 'mlmm_truthfulqa_zho_cf:mc1', 'mlmm_truthfulqa_zho_cf:mc2',
    'ocnli_zho_cf', 'pawsx_zho_cf', 'xcodah_zho_cf', 'xcopa_zho_cf', 'xcsqa_zho_cf', 'xnli2.0_zho_cf',
    'xnli_zho_cf', 'xquad_zho', 'xstory_cloze_zho_cf', 'xwinograd_zho_cf',
  ],
};

/** Shared trace settings: `width` is a line attribute, `type` and `mode` are trace attributes. */
const LINE_SETTINGS = {
  width: 2.5,
  type: "scatter",
  mode: "lines+markers",
};

/** Base layout; `plotData` merges the per-plot title, axis ranges and width on top of it. */
const DEFAULT_LAYOUT = {
  font: {
    family: "apple-system, Arial, sans-serif",
  },
  title: {
    font: {
      size: 15,
    },
  },
  xaxis: {
    title: {
      text: "Training Tokens (billions)",
      font: {
        size: 14,
      },
    },
    tickfont: {
      size: 12,
    },
    showgrid: false,
    mirror: true,
    ticks: "outside",
    showline: true,
  },
  yaxis: {
    title: {
      font: {
        size: 14,
      },
      standoff: 10,
    },
    showgrid: false,
    mirror: true,
    ticks: "outside",
    showline: true,
    tickfont: {
      size: 12,
    },
  },
  height: 300, // You can adjust this value
  autosize: true,
  legend: {
    orientation: 'h', // Horizontal legend
    yanchor: 'bottom',
    y: 0, // Position at the bottom
    xanchor: 'right',
    x: 1, // Position at the right
    traceorder: 'normal',
    font: { size: 12 },
    tracegroupgap: 0,
    bgcolor: 'rgba(255, 255, 255, 0.8)', // White background, alpha 0.8
  },
  margin: {
    t: 25,
    b: 60,
    l: 60,
    r: 40,
  },
};

/** Fraction of the extreme score added as padding below/above the plotted scores. */
const Y_AXIS_PADDING = 0.05;

/** `data-task-metrics` key -> label shown to the user and column read from the stats CSV. */
const STAT_FIELDS = [
  { key: 'monotonicity', label: 'Monotonicity', column: 'avg_spearman' },
  { key: 'snr', label: 'Signal-to-Noise', column: 'avg_snr' },
  { key: 'ordering', label: 'Ordering Consistency', column: 'avg_kendall_tau_a' },
  { key: 'randomness', label: 'Non-Randomness', column: 'max_n_std' },
];

/**
 * Initialise a plot applet inside every `.task-signal-plot` element of the document.
 */
export function initPlotApplets() {
  document.querySelectorAll('.task-signal-plot').forEach((container) => {
    initPlotApplet(container);
  });
}

/**
 * Build the controls, plot area and statistics area inside `container` and load the
 * initial plot. Configuration is read from the container's `data-*` attributes:
 * `language`, `task`, `metric`, `groupSeeds`, `showControls`, `taskMetrics` (and `title`,
 * read later by `updatePlot`).
 *
 * @param {HTMLElement} container - Host element of the applet.
 */
function initPlotApplet(container) {
  const defaultLanguage = container.dataset.language || 'Arabic';
  const defaultTask = container.dataset.task || '';
  const defaultMetric = container.dataset.metric || '';
  const groupSeeds = container.dataset.groupSeeds === 'true';
  const showControls = container.dataset.showControls === 'true';
  // Trim so that "monotonicity, snr" is understood the same as "monotonicity,snr".
  const taskMetrics = (container.dataset.taskMetrics || 'monotonicity,snr,ordering,randomness')
    .split(',')
    .map((name) => name.trim());

  const controls = createControls(container, defaultLanguage, defaultTask, defaultMetric, taskMetrics);
  if (!showControls) {
    controls.style.display = 'none';
  }
  container.appendChild(controls);

  const plotContainer = document.createElement('div');
  plotContainer.className = 'plot-container';
  container.appendChild(plotContainer);

  const statsContainer = document.createElement('div');
  statsContainer.className = 'stats-container';
  container.appendChild(statsContainer);

  // Create an initial empty plot
  Plotly.newPlot(plotContainer, []);

  // Keep the plot as wide as its container.
  const resizePlot = () => {
    // After clearPlot() purges the graph div there is nothing to relayout.
    // NOTE(review): relies on Plotly.purge removing the `layout` property of the div.
    if (!plotContainer.layout) {
      return;
    }
    Plotly.relayout(plotContainer, { width: container.offsetWidth });
  };
  window.addEventListener('resize', resizePlot);
  resizePlot();

  // Load the initial data (errors are handled inside updateLanguageTasks).
  updateLanguageTasks(container, defaultTask, defaultMetric, groupSeeds, taskMetrics);
}

/**
 * Create the language / task / metric selectors and wire their change handlers.
 *
 * @param {HTMLElement} container - Applet host element the handlers operate on.
 * @param {string} defaultLanguage - Language display name to preselect.
 * @param {string} defaultTask - Unused here; the task list is filled by `updateLanguageTasks`.
 * @param {string} defaultMetric - Unused here; the metric list is filled by `updateMetrics`.
 * @param {string[]} taskMetrics - Statistic keys to display (see `STAT_FIELDS`).
 * @returns {HTMLDivElement} The `.controls` element (not yet attached to the DOM).
 */
function createControls(container, defaultLanguage, defaultTask, defaultMetric, taskMetrics) {
  const controls = document.createElement('div');
  controls.className = 'controls';

  const languageSelect = createSelect(
    'language',
    Object.keys(languageMap),
    () => updateLanguageTasks(container, '', '', true, taskMetrics),
  );
  languageSelect.value = defaultLanguage;
  const taskSelect = createSelect('task', [], () => updateMetrics(container, '', true, taskMetrics));
  const metricSelect = createSelect('metric', [], () => updatePlot(container, taskMetrics));

  controls.appendChild(createControlGroup('Language:', languageSelect));
  controls.appendChild(createControlGroup('Task:', taskSelect));
  controls.appendChild(createControlGroup('Metric:', metricSelect));
  return controls;
}

/**
 * Create a `<select>` with one option per entry of `options`.
 *
 * NOTE(review): the id is reused by every applet on the page; lookups in this file are
 * scoped to the applet container (`container.querySelector('#…')`), so they still resolve
 * to the right element, but the ids are not unique document-wide.
 *
 * @param {string} id - Element id.
 * @param {string[]} options - Option values (also used as labels).
 * @param {Function} onChangeHandler - Listener for the `change` event.
 * @returns {HTMLSelectElement}
 */
function createSelect(id, options, onChangeHandler) {
  const select = document.createElement('select');
  select.id = id;
  fillSelect(select, options);
  select.addEventListener('change', onChangeHandler);
  return select;
}

/**
 * Append one `<option>` per value to `select`.
 *
 * @param {HTMLSelectElement} select - Target element.
 * @param {string[]} values - Option values.
 * @param {(value: string) => string} [toLabel] - Visible label for a value.
 * @param {boolean} [withTooltip=false] - Also expose the full value as the option's `title`.
 */
function fillSelect(select, values, toLabel = (value) => value, withTooltip = false) {
  values.forEach((value) => {
    const option = document.createElement('option');
    option.value = value;
    option.textContent = toLabel(value);
    if (withTooltip) {
      option.title = value;
    }
    select.appendChild(option);
  });
}

/**
 * Wrap a label and an input element in a `.control-group` div.
 *
 * @param {string} labelText - Text of the label.
 * @param {HTMLElement} inputElement - Control placed after the label.
 * @returns {HTMLDivElement}
 */
function createControlGroup(labelText, inputElement) {
  const group = document.createElement('div');
  group.className = 'control-group';

  const label = document.createElement('label');
  label.textContent = labelText;
  label.className = 'control-label';

  group.appendChild(label);
  group.appendChild(inputElement);
  return group;
}

/**
 * Refill the task selector for the currently selected language, then refresh the
 * metric selector and the plot. Clears the plot when the language has no tasks.
 *
 * @param {HTMLElement} container - Applet host element.
 * @param {string} [defaultTask] - Task to preselect when it exists for the language.
 * @param {string} [defaultMetric] - Metric to preselect, forwarded to `updateMetrics`.
 * @param {boolean} groupSeeds - Forwarded to `updateMetrics`.
 * @param {string[]} taskMetrics - Statistic keys to display.
 * @returns {Promise<void>}
 */
async function updateLanguageTasks(container, defaultTask = '', defaultMetric = '', groupSeeds, taskMetrics) {
  const languageSelect = container.querySelector('#language');
  const taskSelect = container.querySelector('#task');
  const langCode = languageMap[languageSelect.value];

  taskSelect.innerHTML = '';
  try {
    const tasks = await getTasksForLanguage(langCode);
    taskSelect.innerHTML = '';
    if (tasks.length === 0) {
      clearPlot(container);
      return;
    }

    // Long task names are truncated in the list; the full name stays available as a tooltip.
    fillSelect(taskSelect, tasks, (task) => truncateText(task, 25), true);
    if (defaultTask && tasks.includes(defaultTask)) {
      taskSelect.value = defaultTask;
    } else {
      taskSelect.selectedIndex = 0;
    }
    await updateMetrics(container, defaultMetric, groupSeeds, taskMetrics);
  } catch (error) {
    console.error('Error fetching tasks:', error);
    taskSelect.innerHTML = '';
    clearPlot(container);
  }
}

/**
 * Return the task identifiers known for a language code.
 *
 * @param {string} langCode - Key of `taskLists`.
 * @returns {Promise<string[]>} Empty array for an unknown code.
 */
async function getTasksForLanguage(langCode) {
  return taskLists[langCode] || [];
}

/**
 * Refill the metric selector for the currently selected language and task, then redraw
 * the plot. On failure the selector is emptied and the plot cleared.
 *
 * @param {HTMLElement} container - Applet host element.
 * @param {string} [defaultMetric] - Metric to preselect when available.
 * @param {boolean} groupSeeds - Accepted for call-site compatibility; not read here.
 * @param {string[]} taskMetrics - Statistic keys to display.
 * @returns {Promise<void>}
 */
async function updateMetrics(container, defaultMetric = '', groupSeeds, taskMetrics) {
  const language = container.querySelector('#language').value;
  const task = container.querySelector('#task').value;
  const langCode = languageMap[language];
  const metricSelect = container.querySelector('#metric');

  metricSelect.innerHTML = '';
  try {
    const metrics = await getMetricsForTask(langCode, task);
    metricSelect.innerHTML = '';
    fillSelect(metricSelect, metrics);

    if (defaultMetric && metrics.includes(defaultMetric)) {
      metricSelect.value = defaultMetric;
    } else if (metricSelect.options.length > 0) {
      metricSelect.selectedIndex = 0;
    }
    await updatePlot(container, taskMetrics);
  } catch (error) {
    console.error('Error fetching metrics:', error);
    metricSelect.innerHTML = '';
    clearPlot(container);
  }
}

/**
 * Download and parse a CSV file with a header row.
 *
 * @param {string} url - Location of the CSV file.
 * @param {object} [extraConfig] - Additional Papa Parse options (e.g. `dynamicTyping`).
 * @returns {Promise<object>} Papa Parse result; rejects with the error Papa Parse reports.
 */
function loadCsv(url, extraConfig = {}) {
  return new Promise((resolve, reject) => {
    Papa.parse(url, {
      download: true,
      header: true,
      ...extraConfig,
      complete: resolve,
      error: reject,
    });
  });
}

/**
 * List the distinct metric names found in a task's stats CSV.
 *
 * @param {string} langCode - Language code (directory name).
 * @param {string} task - Task identifier (file-name prefix).
 * @returns {Promise<string[]>} Metric names in file order, without blank entries.
 */
async function getMetricsForTask(langCode, task) {
  const results = await loadCsv(`data/nanotron_tasks/${langCode}/${task}_stats.csv`);
  const names = results.data.map((row) => row.metric);
  // Blank CSV lines yield rows without a metric; they must not become empty options.
  return [...new Set(names)].filter((name) => name != null && name !== '');
}

/**
 * Load the data and stats CSVs of the selected task and draw the selected metric.
 * Clears the plot when the selection is incomplete or loading/plotting fails.
 *
 * @param {HTMLElement} container - Applet host element.
 * @param {string[]} taskMetrics - Statistic keys to display.
 * @returns {Promise<void>} Resolves once the plot is drawn or cleared; never rejects
 *   for load or plotting errors (they are logged instead).
 */
async function updatePlot(container, taskMetrics) {
  const language = container.querySelector('#language').value;
  const task = container.querySelector('#task').value;
  const metric = container.querySelector('#metric').value;
  const title = container.dataset.title;
  const langCode = languageMap[language];

  if (!langCode || !task || !metric) {
    clearPlot(container);
    return;
  }

  const dataUrl = `data/nanotron_tasks/${langCode}/${task}_data.csv`;
  const statsUrl = `data/nanotron_tasks/${langCode}/${task}_stats.csv`;

  try {
    const [dataResult, statsResult] = await Promise.all([
      loadCsv(dataUrl, { dynamicTyping: true }),
      loadCsv(statsUrl, { dynamicTyping: true }),
    ]);
    plotData(container, dataResult.data, statsResult.data, metric, title, taskMetrics);
  } catch (error) {
    console.error('Error parsing CSV:', error);
    clearPlot(container);
  }
}

/**
 * Compute a padded y-axis range covering every finite score of the traces.
 *
 * @param {Array<{y: Array}>} traces - Plotly traces.
 * @returns {number[]|null} `[low, high]`, or `null` when no finite score exists.
 */
function computeScoreRange(traces) {
  const scores = traces.flatMap((trace) => trace.y).filter((value) => Number.isFinite(value));
  if (scores.length === 0) {
    return null;
  }
  const lowest = _.min(scores);
  const highest = _.max(scores);
  // Pad away from the data on both sides, which also works for negative scores.
  return [
    lowest - Math.abs(lowest) * Y_AXIS_PADDING,
    highest + Math.abs(highest) * Y_AXIS_PADDING,
  ];
}

/**
 * Turn the raw rows into traces (sort -> group by run -> interpolate -> smooth), draw
 * them and show the statistics of the metric.
 *
 * @param {HTMLElement} container - Applet host element.
 * @param {object[]} data - Rows of the task data CSV (`runname`, `seed`, `tokens`, metric columns).
 * @param {object[]} stats - Rows of the task stats CSV.
 * @param {string} metric - Metric column to plot.
 * @param {string} [title] - Plot title; empty when not provided.
 * @param {string[]} taskMetrics - Statistic keys to display.
 */
function plotData(container, data, stats, metric, title, taskMetrics) {
  const groupSeeds = container.dataset.groupSeeds === 'true';
  const sortedData = sortDataByTokens(data);
  const groupedData = groupDataByRunname(sortedData, groupSeeds, metric);
  const interpolatedData = interpolateData(groupedData, metric);
  const smoothedData = smoothData(interpolatedData, metric);
  const traces = createTraces(smoothedData, metric);

  const yaxis = { title: { text: 'Score' } };
  const scoreRange = computeScoreRange(traces);
  if (scoreRange) {
    // Without a usable range the axis is left to Plotly's autorange.
    yaxis.range = scoreRange;
  }

  const plotContainer = container.querySelector('.plot-container');
  const layout = _.merge({}, DEFAULT_LAYOUT, {
    title: { text: title || '' },
    xaxis: {
      title: { text: 'Training Tokens (billions)' },
      tickvals: [0, 5, 10, 15, 20, 25],
      ticktext: ['0', '5B', '10B', '15B', '20B', '25B'],
      tickangle: 45,
      range: [0, 30], // Set the range to start from 0 and end at 30B
    },
    yaxis,
    width: container.offsetWidth,
  });

  Plotly.newPlot(plotContainer, traces, layout, { responsive: true });

  // Display statistics
  displayStatistics(container, stats, metric, taskMetrics);
}

/**
 * Format a statistic with two decimals, or `N/A` when it is not a finite number.
 *
 * @param {*} value - Raw cell value from the stats CSV.
 * @returns {string}
 */
function formatStat(value) {
  return typeof value === 'number' && Number.isFinite(value) ? value.toFixed(2) : 'N/A';
}

/**
 * Show the requested statistics of `metric` in the applet's `.stats-container`.
 *
 * @param {HTMLElement} container - Applet host element.
 * @param {object[]} stats - Rows of the task stats CSV.
 * @param {string} metric - Metric whose row is displayed.
 * @param {string[]} taskMetrics - Keys of `STAT_FIELDS` to include.
 */
function displayStatistics(container, stats, metric, taskMetrics) {
  const statsContainer = container.querySelector('.stats-container');
  const metricStats = stats.find((stat) => stat && stat.metric === metric);

  if (!metricStats) {
    statsContainer.textContent = 'No statistics available for this metric.';
    return;
  }

  statsContainer.textContent = STAT_FIELDS
    .filter(({ key }) => taskMetrics.includes(key))
    .map(({ label, column }) => `${label}: ${formatStat(metricStats[column])}`)
    .join(' ');
}

/**
 * Thin out a list of token counts to roughly ten evenly strided tick values.
 *
 * @param {number[]} tokens - Token counts (duplicates allowed).
 * @returns {number[]} Sorted unique values, strided when there are more than ten.
 */
function getReducedTickValues(tokens) {
  const uniqueTokens = [...new Set(tokens)].sort((a, b) => a - b);
  const tokenCount = uniqueTokens.length;
  const targetTickCount = 10; // Adjust this value to increase/decrease the number of ticks

  if (tokenCount <= targetTickCount) {
    return uniqueTokens;
  }

  const stride = Math.ceil(tokenCount / targetTickCount);
  return uniqueTokens.filter((token, index) => index % stride === 0);
}

/**
 * Format a number with a K / M / B suffix and one decimal.
 *
 * @param {number} value - Value to format.
 * @returns {string} e.g. `1500` -> `"1.5K"`; values below 1000 are returned unchanged.
 */
function formatTickLabel(value) {
  if (value >= 1e9) {
    return `${(value / 1e9).toFixed(1)}B`;
  }
  if (value >= 1e6) {
    return `${(value / 1e6).toFixed(1)}M`;
  }
  if (value >= 1e3) {
    return `${(value / 1e3).toFixed(1)}K`;
  }
  return value.toString();
}

/**
 * Compute summary statistics of `metric` over the non-baseline runs.
 *
 * @param {Object<string, object[]>} data - Rows per run name; a run whose name contains
 *   "baseline" (case-insensitive) is used as the reference for SNR / max_n_std.
 * @param {string} metric - Metric column.
 * @returns {{avg_spearman: number, avg_kendall_tau_a: number, avg_snr: number, max_n_std: number}}
 *   All zeros when there is no non-baseline run.
 */
function computeStatistics(data, metric) {
  const stats = {
    avg_spearman: 0,
    avg_kendall_tau_a: 0,
    avg_snr: 0,
    max_n_std: 0,
  };

  const runNames = Object.keys(data);
  const baselineRun = runNames.find((key) => key.toLowerCase().includes('baseline'));
  const nonBaselineRuns = runNames.filter((key) => key !== baselineRun);
  const numRuns = nonBaselineRuns.length;
  if (numRuns === 0) {
    // Nothing to average; avoids dividing by zero below.
    return stats;
  }

  const baselineMean = baselineRun
    ? _.mean(data[baselineRun].map((row) => row[metric]))
    : undefined;

  nonBaselineRuns.forEach((run) => {
    const runData = data[run];
    const tokens = runData.map((row) => row.tokens);
    const scores = runData.map((row) => row[metric]);

    // Spearman correlation
    stats.avg_spearman += spearmanCorrelation(tokens, scores);

    // Kendall Tau-a
    // NOTE(review): both arguments are prefixes of the same score list, so only their
    // common leading part is compared — confirm this is the intended ordering measure.
    const lastHalf = Math.floor(runData.length / 2);
    const kendallTauValues = [];
    for (let i = lastHalf; i < runData.length - 1; i++) {
      kendallTauValues.push(kendallTauA(scores.slice(0, i + 1), scores.slice(0, i + 2)));
    }
    stats.avg_kendall_tau_a += _.mean(kendallTauValues);

    // SNR and max_n_std
    // NOTE(review): a run with zero standard deviation yields a non-finite value here.
    if (baselineRun) {
      const stdDev = standardDeviation(scores);
      stats.avg_snr += _.mean(scores) / stdDev;
      stats.max_n_std = Math.max(stats.max_n_std, (_.max(scores) - baselineMean) / stdDev);
    }
  });

  // Average the statistics
  stats.avg_spearman /= numRuns;
  stats.avg_kendall_tau_a /= numRuns;
  stats.avg_snr /= numRuns;

  return stats;
}

/**
 * Spearman rank correlation: Pearson correlation of the (tie-averaged) ranks.
 *
 * @param {number[]} x - First sample.
 * @param {number[]} y - Second sample, same length as `x`.
 * @returns {number} Value in [-1, 1]; `NaN` when the correlation is undefined (fewer
 *   than two points, different lengths, or one sample is constant).
 */
function spearmanCorrelation(x, y) {
  const n = x.length;
  if (n < 2 || y.length !== n) {
    return NaN;
  }

  const rankX = rankData(x);
  const rankY = rankData(y);
  // Tie-averaged ranks of n values always have mean (n + 1) / 2.
  const meanRank = (n + 1) / 2;

  let covariance = 0;
  let varianceX = 0;
  let varianceY = 0;
  for (let i = 0; i < n; i++) {
    const dx = rankX[i] - meanRank;
    const dy = rankY[i] - meanRank;
    covariance += dx * dy;
    varianceX += dx * dx;
    varianceY += dy * dy;
  }

  const denominator = Math.sqrt(varianceX * varianceY);
  return denominator === 0 ? NaN : covariance / denominator;
}

/**
 * Rank values from 1 (smallest) upwards; tied values share the average of their ranks.
 *
 * @param {number[]} data - Values to rank.
 * @returns {number[]} Rank of each value, in the order of `data`.
 */
function rankData(data) {
  const ordered = data
    .map((value, index) => ({ value, index }))
    .sort((a, b) => a.value - b.value);
  const ranks = new Array(data.length);

  let start = 0;
  while (start < ordered.length) {
    // Find the run [start, end] of equal values.
    let end = start;
    while (end + 1 < ordered.length && ordered[end + 1].value === ordered[start].value) {
      end++;
    }
    const averageRank = (start + end) / 2 + 1;
    for (let k = start; k <= end; k++) {
      ranks[ordered[k].index] = averageRank;
    }
    start = end + 1;
  }
  return ranks;
}

/**
 * Kendall tau-a: (concordant - discordant) pairs divided by the number of pairs.
 * Only the first `x.length` entries of `y` are considered.
 *
 * @param {number[]} x - First sample.
 * @param {number[]} y - Second sample, at least as long as `x`.
 * @returns {number} Value in [-1, 1]; `NaN` when `x` has fewer than two entries.
 */
function kendallTauA(x, y) {
  const n = x.length;
  const pairCount = (n * (n - 1)) / 2;
  if (pairCount === 0) {
    return NaN;
  }

  let concordant = 0;
  let discordant = 0;
  for (let i = 0; i < n; i++) {
    for (let j = i + 1; j < n; j++) {
      const direction = Math.sign(x[j] - x[i]) * Math.sign(y[j] - y[i]);
      if (direction > 0) {
        concordant++;
      } else if (direction < 0) {
        discordant++;
      }
    }
  }
  return (concordant - discordant) / pairCount;
}

/**
 * Population standard deviation (divides by the number of values).
 *
 * @param {number[]} values - Sample.
 * @returns {number} `NaN` for an empty sample.
 */
function standardDeviation(values) {
  const mean = _.mean(values);
  const squareDiffs = values.map((value) => (value - mean) ** 2);
  return Math.sqrt(_.mean(squareDiffs));
}

/**
 * Give every run a row at every token count seen in any run. Missing rows are linearly
 * interpolated between the neighbouring rows of the same run; before the first or after
 * the last row of a run, the nearest row is repeated.
 *
 * @param {Object<string, object[]>} data - Rows per run.
 * @param {string} metric - Metric column to interpolate.
 * @returns {Object<string, object[]>} Rows per run, one per token count, ascending.
 */
function interpolateData(data, metric) {
  // The union of token counts is the same for every run, so compute it once.
  const allTokens = _.uniq(
    _.flatMap(Object.values(data), (rows) => rows.map((row) => row.tokens)),
  ).sort((a, b) => a - b);

  return _.mapValues(data, (rows) => {
    const sortedRows = _.sortBy(rows, 'tokens');

    return allTokens.map((token) => {
      const exactMatch = _.find(sortedRows, { tokens: token });
      if (exactMatch) {
        return exactMatch;
      }

      const lowerRow = _.findLast(sortedRows, (row) => row.tokens < token);
      const upperRow = _.find(sortedRows, (row) => row.tokens > token);
      if (!lowerRow) {
        return { ...upperRow, tokens: token };
      }
      if (!upperRow) {
        return { ...lowerRow, tokens: token };
      }

      const ratio = (token - lowerRow.tokens) / (upperRow.tokens - lowerRow.tokens);
      const interpolatedMetric = lowerRow[metric] + (upperRow[metric] - lowerRow[metric]) * ratio;
      return { ...lowerRow, tokens: token, [metric]: interpolatedMetric };
    });
  });
}

/**
 * Smooth each run with a trailing moving average over the current row and up to
 * `windowSize - 1` preceding rows.
 *
 * @param {Object<string, object[]>} data - Rows per run.
 * @param {string} metric - Metric column to smooth.
 * @param {number} [windowSize=3] - Maximum number of rows averaged.
 * @returns {Object<string, object[]>} Copies of the rows with the smoothed metric.
 */
function smoothData(data, metric, windowSize = 3) {
  return _.mapValues(data, (rows) => rows.map((row, index) => {
    const window = rows.slice(Math.max(0, index - windowSize + 1), index + 1);
    return { ...row, [metric]: _.meanBy(window, (entry) => entry[metric]) };
  }));
}

/**
 * Return the rows sorted by ascending `tokens`.
 *
 * @param {object[]} data - Rows.
 * @returns {object[]} New sorted array.
 */
function sortDataByTokens(data) {
  return _.sortBy(data, 'tokens');
}

/**
 * Group rows by display run name.
 *
 * @param {object[]} data - Rows with `runname`, `seed`, `tokens` and metric columns.
 * @param {boolean} groupSeeds - When true, rows of one run sharing a token count are
 *   collapsed into one row holding the mean metric; otherwise every
 *   `<run name>_<seed>` combination is its own group.
 * @param {string} metric - Metric column to average.
 * @returns {Object<string, object[]>} Rows per group.
 */
function groupDataByRunname(data, groupSeeds, metric) {
  // Remove null or undefined runs
  const validRows = data.filter((row) => row.runname != null && row.runname !== 'null_undefined');

  if (!groupSeeds) {
    return _.groupBy(validRows, (row) => `${processRunName(row.runname)}_${row.seed}`);
  }

  const grouped = _.groupBy(validRows, (row) => processRunName(row.runname));
  return _.mapValues(grouped, (rows) => {
    const stepGroups = _.groupBy(rows, 'tokens');
    return _.map(stepGroups, (stepRows) => {
      // NOTE(review): a missing or non-numeric metric counts as 0 in the mean.
      const meanMetric = _.meanBy(stepRows, (row) => parseFloat(row[metric]) || 0);
      return { ...stepRows[0], [metric]: meanMetric };
    });
  });
}

/**
 * Map a raw run name to its display label.
 *
 * @param {*} runname - Raw run name (may be a non-string after dynamic typing).
 * @returns {*} Label of the first `runNameMap` key contained in the name, or `runname`
 *   itself when no key matches.
 */
function processRunName(runname) {
  const rawName = String(runname);
  for (const [key, value] of Object.entries(runNameMap)) {
    if (rawName.includes(key)) {
      return value;
    }
  }
  return runname;
}

/**
 * Build one Plotly trace per run, sorted alphabetically with baseline runs last.
 *
 * @param {Object<string, object[]>} groupedData - Rows per run.
 * @param {string} metric - Metric column plotted on the y axis.
 * @returns {object[]} Plotly scatter traces.
 */
function createTraces(groupedData, metric) {
  const colorsMapping = new Map();
  // Display labels are capitalised ("Baseline"), so the check must ignore case.
  const isBaseline = (name) => name.toLowerCase().includes('baseline');

  const sortedRunnames = Object.keys(groupedData).sort((a, b) => {
    const baselineOrder = Number(isBaseline(a)) - Number(isBaseline(b));
    return baselineOrder !== 0 ? baselineOrder : a.localeCompare(b);
  });

  return sortedRunnames.map((runname, index) => {
    const color = getColorForTrace(runname, colorsMapping, index);
    const rows = groupedData[runname];
    return {
      x: rows.map((row) => row.tokens),
      y: rows.map((row) => row[metric]),
      name: runname,
      type: LINE_SETTINGS.type,
      mode: LINE_SETTINGS.mode,
      line: {
        color,
        shape: 'spline',
        width: LINE_SETTINGS.width,
      },
      marker: {
        color,
        size: 6,
      },
    };
  });
}

/**
 * Return the colour of a trace, assigning `getColor(index)` on first use.
 *
 * @param {string} traceName - Trace name used as cache key.
 * @param {Map<string, *>} colorsMapping - Cache of already assigned colours.
 * @param {number} index - Palette index used when the trace has no colour yet.
 * @returns {*} Colour returned by `getColor`.
 */
function getColorForTrace(traceName, colorsMapping, index) {
  const reusedColor = colorsMapping.get(traceName);
  if (reusedColor) {
    return reusedColor;
  }

  const color = getColor(index);
  colorsMapping.set(traceName, color);
  return color;
}

/**
 * Remove the plot of an applet and the statistics that belonged to it.
 *
 * @param {HTMLElement} container - Applet host element.
 */
function clearPlot(container) {
  const plotContainer = container.querySelector('.plot-container');
  if (plotContainer) {
    Plotly.purge(plotContainer);
  }
  // Statistics of a plot that is no longer shown would be misleading.
  const statsContainer = container.querySelector('.stats-container');
  if (statsContainer) {
    statsContainer.textContent = '';
  }
}

/**
 * Shorten `text` to at most `maxLength` characters, ending with `..` when truncated.
 *
 * @param {string} text - Text to shorten.
 * @param {number} maxLength - Maximum length of the result.
 * @returns {string}
 */
function truncateText(text, maxLength) {
  if (text.length <= maxLength) {
    return text;
  }
  return `${text.slice(0, Math.max(0, maxLength - 2))}..`;
}