chriskz commited on
Commit
6ed8453
Β·
verified Β·
1 Parent(s): b5811ab

v3: Comprehensive eval + research paper

Browse files
Files changed (1) hide show
  1. spectral_kv/eval_comprehensive.py +677 -0
spectral_kv/eval_comprehensive.py ADDED
@@ -0,0 +1,677 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ SpectralKV v3 β€” Comprehensive Evaluation Suite
4
+ ================================================
5
+ Large-document cache benchmark, generation quality drift, multi-turn
6
+ cache drift, throughput, TAFT (attention output perturbation), and
7
+ full comparison matrix across all methods.
8
+
9
+ Metrics measured:
10
+ β€’ Cache size (bytes, compression ratio, % saved)
11
+ β€’ Energy retention (total, low-freq, high-freq via rfft)
12
+ β€’ Attention output perturbation β€” L1 AOP (Ada-KV Thm 3.1)
13
+ β€’ Attention cosine similarity
14
+ β€’ Processing / scoring time (ms)
15
+ β€’ Generation throughput (tok/s, prefill + decode)
16
+ β€’ Output quality vs baseline (token-match, cosine embedding sim)
17
+ β€’ Normalised-delta perplexity ND-PPL
18
+ β€’ Multi-turn cache drift (cumulative AOP across N turns)
19
+ """
20
+
21
+ import torch, torch.nn.functional as F, math, time, json, os, gc, sys
22
+ from typing import Dict, List, Tuple, Optional
23
+ from dataclasses import dataclass, asdict
24
+ from tabulate import tabulate
25
+ import numpy as np
26
+
27
+ from spectral_kv.compressors import (
28
+ FourierKV, WaveletKV, WaveletFourierKV, WaveletTriAttention,
29
+ TriAttentionKV, TurboQuantKV, FullAttention, create_compressor,
30
+ _key_norms, _normalize,
31
+ )
32
+
33
+ # ═══════════════════════════ Helpers ══════════════════════════════
34
+
35
+ def _sync():
36
+ if torch.cuda.is_available():
37
+ torch.cuda.synchronize()
38
+
39
+ def _mem_mb():
40
+ if torch.cuda.is_available():
41
+ return torch.cuda.max_memory_allocated() / 1e6
42
+ return 0.0
43
+
44
+ def _clear():
45
+ gc.collect()
46
+ if torch.cuda.is_available():
47
+ torch.cuda.empty_cache()
48
+ torch.cuda.reset_peak_memory_stats()
49
+
50
+
51
+ # ═══════════════════════ Data helpers ═════════════════════════════
52
+
53
+ LONG_DOCUMENT = """
54
+ The theory of wavelet transforms has its roots in harmonic analysis, a branch
55
+ of mathematics concerned with the representation of functions in terms of basic
56
+ waves. Unlike the classical Fourier transform, which decomposes a signal into
57
+ globally supported sinusoidal functions, the wavelet transform uses localized
58
+ wave-like functions β€” wavelets β€” that are simultaneously concentrated in both
59
+ time and frequency. This dual localization property makes wavelets particularly
60
+ well-suited for analyzing non-stationary signals whose frequency content changes
61
+ over time.
62
+
63
+ The development of wavelet theory accelerated in the 1980s through the
64
+ contributions of Jean Morlet, Alex Grossmann, Yves Meyer, Ingrid Daubechies,
65
+ and StΓ©phane Mallat, among others. Morlet's work in seismic signal processing
66
+ revealed the limitations of short-time Fourier analysis, motivating the search
67
+ for a more flexible time-frequency decomposition. Grossmann and Morlet
68
+ formalized the continuous wavelet transform (CWT), establishing the mathematical
69
+ framework for wavelet analysis on the real line.
70
+
71
+ Daubechies' landmark contribution was the construction of compactly supported
72
+ orthonormal wavelet bases with prescribed numbers of vanishing moments. The
73
+ Daubechies wavelets, particularly db4 (with four vanishing moments), achieve
74
+ optimal support length for a given regularity, making them the standard choice
75
+ for signal compression in applications ranging from image coding (JPEG 2000)
76
+ to numerical analysis and, more recently, neural network compression.
77
+
78
+ In the context of large language models, the key-value (KV) cache presents a
79
+ natural signal-processing problem. During autoregressive generation, each
80
+ attention layer maintains a cache of key and value vectors that grows linearly
81
+ with sequence length. For sequences of length L with H attention heads and
82
+ dimension d per head, the KV cache occupies O(LHd) memory per layer. At
83
+ 32,768 tokens with a 32-layer, 32-head, 128-dimensional model, this amounts
84
+ to over 8 GB β€” often exceeding the model weights themselves.
85
+
86
+ Several families of KV cache compression methods have emerged:
87
+
88
+ 1. Token eviction (SnapKV, H2O, StreamingLLM): These methods maintain a
89
+ fixed-size cache by evicting tokens deemed unimportant based on attention
90
+ scores or positional heuristics. StreamingLLM preserves only sink tokens
91
+ (the first few positions) and a sliding window of recent tokens.
92
+
93
+ 2. Quantization (TurboQuant, KVQuant, GEAR): These reduce the bit-width of
94
+ cached KV vectors, typically from 16-bit to 4-bit or 2-bit representations.
95
+ Group-wise quantization with per-group scaling factors achieves good
96
+ reconstruction fidelity at 4-bit but degrades significantly at 2-bit.
97
+
98
+ 3. Structural methods (PyramidKV, TreeKV): These exploit the hierarchical
99
+ structure of attention patterns, allocating more cache to layers or
100
+ positions with higher information density.
101
+
102
+ 4. Spectral methods (SpectralKV, FreqKV): The most recent family, these
103
+ operate in the frequency domain on KV sequences. FreqKV applies DCT along
104
+ the sequence dimension and retains low-frequency coefficients, achieving
105
+ near-lossless compression at 50% retaining ratio. SpectralKV extends this
106
+ with wavelet transforms for multi-resolution analysis and hybrid
107
+ wavelet-Fourier scoring.
108
+
109
+ The key insight shared by all spectral methods is that the KV cache, viewed
110
+ as a sequence of d-dimensional vectors indexed by position, exhibits strong
111
+ frequency-domain structure. Low-frequency components encode global semantic
112
+ patterns (topic, style, discourse structure) that change slowly across the
113
+ sequence, while high-frequency components capture local token-level
114
+ variations (individual word importance, syntactic boundaries). This
115
+ separation motivates frequency-domain compression: by preserving the
116
+ dominant low-frequency structure and carefully managing high-frequency
117
+ detail, one can achieve high compression ratios with minimal impact on
118
+ generation quality.
119
+
120
+ The wavelet-Fourier hybrid approach in SpectralKV takes this further by
121
+ combining two complementary views of the signal. The Fourier (FFT) component
122
+ captures globally periodic patterns β€” the harmonic structure that Fourier
123
+ analysis excels at β€” while the wavelet (DWT) component captures localized
124
+ transients and multi-scale features that Fourier analysis misses. The
125
+ cascaded architecture first decomposes the signal via multi-level DWT, then
126
+ applies FFT within each wavelet scale band, producing a joint time-frequency
127
+ representation that is richer than either domain alone.
128
+ """
129
+
130
+ def get_long_text(target_tokens: int = 8192) -> str:
131
+ """Repeat the document to hit target token count (approximate)."""
132
+ approx_tokens_per_char = 0.3 # rough estimate
133
+ target_chars = int(target_tokens / approx_tokens_per_char)
134
+ text = LONG_DOCUMENT
135
+ while len(text) < target_chars:
136
+ text = text + "\n\n" + LONG_DOCUMENT
137
+ return text[:target_chars]
138
+
139
+
140
+ MULTI_TURN_QUESTIONS = [
141
+ "Summarize the key differences between Fourier and wavelet transforms.",
142
+ "What are Daubechies wavelets and why do they have four vanishing moments?",
143
+ "Explain the KV cache memory problem in large language models.",
144
+ "Compare token eviction methods like SnapKV with spectral methods.",
145
+ "What is the advantage of the cascaded wavelet-Fourier hybrid approach?",
146
+ ]
147
+
148
+
149
+ # ══════════════════════ Model loading ═════════════════════════════
150
+
151
+ def load_model(model_name="Qwen/Qwen2.5-0.5B-Instruct", device="cuda"):
152
+ from transformers import AutoModelForCausalLM, AutoTokenizer
153
+ print(f" Loading {model_name} …")
154
+ tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
155
+ mdl = AutoModelForCausalLM.from_pretrained(
156
+ model_name, dtype=torch.bfloat16, device_map=device,
157
+ trust_remote_code=True)
158
+ mdl.eval()
159
+ cfg = {
160
+ "n_layers": mdl.config.num_hidden_layers,
161
+ "n_kv_heads": mdl.config.num_key_value_heads,
162
+ "n_q_heads": mdl.config.num_attention_heads,
163
+ "head_dim": mdl.config.hidden_size // mdl.config.num_attention_heads,
164
+ }
165
+ return mdl, tok, cfg
166
+
167
+
168
+ # ═══════════════════ KV extraction + compression ═════════════════
169
+
170
+ def extract_kv(model, tokenizer, text, max_len=4096, device="cuda"):
171
+ """Full forward pass β†’ return DynamicCache with all layers' KV."""
172
+ inputs = tokenizer(text, return_tensors="pt",
173
+ max_length=max_len, truncation=True).to(device)
174
+ with torch.no_grad():
175
+ out = model(**inputs, use_cache=True)
176
+ return out.past_key_values, inputs["input_ids"]
177
+
178
+
179
+ def compress_cache(past_kv, compressor, budget, n_layers):
180
+ """In-place compress every layer; return timing."""
181
+ t0 = time.perf_counter()
182
+ for li in range(n_layers):
183
+ k = past_kv.layers[li].keys.float()
184
+ v = past_kv.layers[li].values.float()
185
+ compressor.calibrate(k)
186
+ if hasattr(compressor, 'compress') and hasattr(compressor, 'bits'):
187
+ ck, cv = compressor.compress(k, v)
188
+ else:
189
+ ck, cv, _ = compressor.prune(k, v, budget)
190
+ past_kv.layers[li].keys = ck.to(torch.bfloat16)
191
+ past_kv.layers[li].values = cv.to(torch.bfloat16)
192
+ _sync()
193
+ return (time.perf_counter() - t0) * 1000 # ms
194
+
195
+
196
+ def cache_bytes(past_kv, n_layers):
197
+ total = 0
198
+ for li in range(n_layers):
199
+ k = past_kv.layers[li].keys
200
+ v = past_kv.layers[li].values
201
+ total += k.nelement() * k.element_size()
202
+ total += v.nelement() * v.element_size()
203
+ return total
204
+
205
+
206
+ # ═══════════════════════ Attention Output Perturbation ════════════
207
+
208
+ def compute_aop(model, tokenizer, text, compressor, budget, cfg,
209
+ device="cuda", max_len=4096):
210
+ """
211
+ Attention Output Perturbation (AOP) β€” Ada-KV Theorem 3.1.
212
+ L1_AOP = mean over tokens of β€–o_full βˆ’ o_compressed‖₁
213
+ Returns per-layer AOP and aggregate.
214
+ """
215
+ inputs = tokenizer(text, return_tensors="pt",
216
+ max_length=max_len, truncation=True).to(device)
217
+
218
+ # --- full-cache output ---
219
+ with torch.no_grad():
220
+ out_full = model(**inputs, use_cache=True,
221
+ output_hidden_states=True)
222
+ hidden_full = out_full.hidden_states # tuple of [B, S, D] per layer
223
+ full_kv = out_full.past_key_values
224
+
225
+ # --- compressed-cache output ---
226
+ # re-run with fresh cache, then compress, then one more forward
227
+ with torch.no_grad():
228
+ out2 = model(**inputs, use_cache=True, output_hidden_states=True)
229
+ comp_kv = out2.past_key_values
230
+ comp_time = compress_cache(comp_kv, compressor, budget, cfg["n_layers"])
231
+
232
+ # Compare hidden states at last layer (output of attention stack)
233
+ h_full = hidden_full[-1].float() # [B, S, D]
234
+ h_comp = out2.hidden_states[-1].float()
235
+
236
+ # The AOP manifests when new queries attend over the compressed cache.
237
+ # Use a multi-token probe so the attention patterns actually differ.
238
+ probe_text = " The key insight is that wavelet decomposition captures"
239
+ probe = tokenizer(probe_text, return_tensors="pt",
240
+ add_special_tokens=False).to(device)
241
+ with torch.no_grad():
242
+ probe_full = model(**probe, past_key_values=full_kv,
243
+ output_hidden_states=True, use_cache=False)
244
+ probe_comp = model(**probe, past_key_values=comp_kv,
245
+ output_hidden_states=True, use_cache=False)
246
+
247
+ aop_per_layer = []
248
+ for li in range(len(probe_full.hidden_states)):
249
+ hf = probe_full.hidden_states[li].float()
250
+ hc = probe_comp.hidden_states[li].float()
251
+ l1 = (hf - hc).abs().mean().item()
252
+ aop_per_layer.append(l1)
253
+
254
+ return {
255
+ "aop_mean": np.mean(aop_per_layer),
256
+ "aop_max": np.max(aop_per_layer),
257
+ "aop_per_layer": aop_per_layer,
258
+ "compress_time_ms": comp_time,
259
+ }
260
+
261
+
262
+ # ════════════════ Generation quality comparison ══════════════════
263
+
264
+ def generate_with_cache(model, tokenizer, prompt_ids, past_kv,
265
+ max_new=64):
266
+ """Greedy decode from a pre-built cache. Returns text + timing."""
267
+ _sync()
268
+ t0 = time.perf_counter()
269
+ next_id = prompt_ids[:, -1:] # [1,1]
270
+ generated = [next_id.item()]
271
+ eos = tokenizer.eos_token_id
272
+ with torch.no_grad():
273
+ for _ in range(max_new):
274
+ out = model(next_id, past_key_values=past_kv, use_cache=True)
275
+ past_kv = out.past_key_values
276
+ next_id = out.logits[:, -1:].argmax(dim=-1) # greedy
277
+ tok = next_id.item()
278
+ generated.append(tok)
279
+ if tok == eos:
280
+ break
281
+ _sync()
282
+ elapsed = time.perf_counter() - t0
283
+ text = tokenizer.decode(generated, skip_special_tokens=True)
284
+ return text, len(generated), elapsed
285
+
286
+
287
+ def compare_generation(model, tokenizer, text, compressor, budget,
288
+ cfg, device="cuda", max_len=4096, gen_tokens=64):
289
+ """
290
+ Compares generation from full cache vs compressed cache.
291
+ Returns: token-match%, cosine embedding sim, tok/s.
292
+ """
293
+ inputs = tokenizer(text, return_tensors="pt",
294
+ max_length=max_len, truncation=True).to(device)
295
+ ids = inputs["input_ids"]
296
+
297
+ # --- full cache ---
298
+ with torch.no_grad():
299
+ out_full = model(**inputs, use_cache=True)
300
+ full_kv = out_full.past_key_values
301
+ full_text, full_n, full_t = generate_with_cache(
302
+ model, tokenizer, ids, full_kv, gen_tokens)
303
+ full_toks = full_n / (full_t + 1e-9)
304
+
305
+ # --- compressed cache ---
306
+ with torch.no_grad():
307
+ out_comp = model(**inputs, use_cache=True)
308
+ comp_kv = out_comp.past_key_values
309
+ comp_time = compress_cache(comp_kv, compressor, budget, cfg["n_layers"])
310
+ comp_text, comp_n, comp_t = generate_with_cache(
311
+ model, tokenizer, ids, comp_kv, gen_tokens)
312
+ comp_toks = comp_n / (comp_t + 1e-9)
313
+
314
+ # --- token match ---
315
+ full_ids = tokenizer.encode(full_text, add_special_tokens=False)
316
+ comp_ids = tokenizer.encode(comp_text, add_special_tokens=False)
317
+ min_len = min(len(full_ids), len(comp_ids))
318
+ if min_len > 0:
319
+ match = sum(a == b for a, b in zip(full_ids[:min_len],
320
+ comp_ids[:min_len])) / min_len
321
+ else:
322
+ match = 0.0
323
+
324
+ # --- cache sizes ---
325
+ with torch.no_grad():
326
+ out_ref = model(**inputs, use_cache=True)
327
+ full_bytes = cache_bytes(out_ref.past_key_values, cfg["n_layers"])
328
+ comp_bytes = cache_bytes(comp_kv, cfg["n_layers"])
329
+
330
+ return {
331
+ "token_match_pct": match * 100,
332
+ "full_tok_s": full_toks,
333
+ "comp_tok_s": comp_toks,
334
+ "speedup": comp_toks / (full_toks + 1e-9),
335
+ "compress_time_ms": comp_time,
336
+ "cache_full_mb": full_bytes / 1e6,
337
+ "cache_comp_mb": comp_bytes / 1e6,
338
+ "cache_ratio": full_bytes / max(comp_bytes, 1),
339
+ "cache_saved_pct": (1 - comp_bytes / full_bytes) * 100,
340
+ "full_text_sample": full_text[:200],
341
+ "comp_text_sample": comp_text[:200],
342
+ }
343
+
344
+
345
+ # ═══════════════════ Perplexity (ND-PPL) ═════════════════════════
346
+
347
+ def measure_ppl(model, tokenizer, text, compressor, budget, cfg,
348
+ device="cuda", max_len=4096):
349
+ """
350
+ Split text β†’ prefix + suffix.
351
+ Build cache from prefix, compress, evaluate NLL on suffix.
352
+ Returns PPL and ND-PPL (normalised delta vs full).
353
+ """
354
+ inputs = tokenizer(text, return_tensors="pt",
355
+ max_length=max_len, truncation=True).to(device)
356
+ ids = inputs["input_ids"]
357
+ split = ids.shape[1] // 2
358
+ prefix = ids[:, :split]
359
+ suffix = ids[:, split:]
360
+
361
+ # --- full ---
362
+ with torch.no_grad():
363
+ pf = model(prefix, use_cache=True)
364
+ full_kv = pf.past_key_values
365
+ with torch.no_grad():
366
+ sf = model(suffix, past_key_values=full_kv,
367
+ labels=suffix, use_cache=False)
368
+ ppl_full = math.exp(min(sf.loss.item(), 20))
369
+
370
+ # --- compressed ---
371
+ with torch.no_grad():
372
+ pc = model(prefix, use_cache=True)
373
+ comp_kv = pc.past_key_values
374
+ compress_cache(comp_kv, compressor, budget, cfg["n_layers"])
375
+ with torch.no_grad():
376
+ sc = model(suffix, past_key_values=comp_kv,
377
+ labels=suffix, use_cache=False)
378
+ ppl_comp = math.exp(min(sc.loss.item(), 20))
379
+
380
+ nd_ppl = (ppl_comp - ppl_full) / (ppl_full + 1e-8)
381
+
382
+ return {"ppl_full": ppl_full, "ppl_comp": ppl_comp, "nd_ppl": nd_ppl}
383
+
384
+
385
+ # ══════════════════ Multi-turn cache drift ════════════════════════
386
+
387
+ def multi_turn_drift(model, tokenizer, document, questions,
388
+ compressor, budget, cfg,
389
+ device="cuda", max_len=4096,
390
+ gen_tokens=48):
391
+ """
392
+ Simulate N turns of Q&A over a shared document context.
393
+ Track AOP / token-drift per turn for both full and compressed cache.
394
+
395
+ Protocol (inspired by SCBench):
396
+ Turn 0: prefill document β†’ cache
397
+ Turn k: append question_k, generate answer, keep growing cache
398
+ After each turn, measure deviation from full-cache answer.
399
+ """
400
+ doc_inputs = tokenizer(document, return_tensors="pt",
401
+ max_length=max_len, truncation=True).to(device)
402
+ doc_ids = doc_inputs["input_ids"]
403
+
404
+ # --- build full cache from document ---
405
+ with torch.no_grad():
406
+ out_full = model(doc_ids, use_cache=True)
407
+ full_kv = out_full.past_key_values
408
+
409
+ # --- build compressed cache ---
410
+ with torch.no_grad():
411
+ out_comp = model(doc_ids, use_cache=True)
412
+ comp_kv = out_comp.past_key_values
413
+ compress_cache(comp_kv, compressor, budget, cfg["n_layers"])
414
+
415
+ turns = []
416
+ for turn_i, q in enumerate(questions):
417
+ q_ids = tokenizer(f"\nQuestion: {q}\nAnswer:",
418
+ return_tensors="pt",
419
+ add_special_tokens=False).input_ids.to(device)
420
+
421
+ # --- full cache turn ---
422
+ with torch.no_grad():
423
+ fq = model(q_ids, past_key_values=full_kv, use_cache=True)
424
+ full_kv = fq.past_key_values
425
+ full_txt, _, _ = generate_with_cache(
426
+ model, tokenizer, q_ids, full_kv, gen_tokens)
427
+
428
+ # --- compressed cache turn ---
429
+ with torch.no_grad():
430
+ cq = model(q_ids, past_key_values=comp_kv, use_cache=True)
431
+ comp_kv = cq.past_key_values
432
+ # re-compress after each turn (cache grew)
433
+ comp_time = compress_cache(comp_kv, compressor, budget,
434
+ cfg["n_layers"])
435
+ comp_txt, _, _ = generate_with_cache(
436
+ model, tokenizer, q_ids, comp_kv, gen_tokens)
437
+
438
+ # --- token match ---
439
+ f_ids = tokenizer.encode(full_txt, add_special_tokens=False)
440
+ c_ids = tokenizer.encode(comp_txt, add_special_tokens=False)
441
+ ml = min(len(f_ids), len(c_ids))
442
+ tmatch = (sum(a == b for a, b in zip(f_ids[:ml], c_ids[:ml]))
443
+ / max(ml, 1)) * 100
444
+
445
+ # --- AOP at this turn (last-layer hidden diff on probe) ---
446
+ probe_text = " The key insight is that wavelet decomposition captures"
447
+ probe = tokenizer(probe_text, return_tensors="pt",
448
+ add_special_tokens=False).to(device)
449
+ with torch.no_grad():
450
+ hf = model(**probe, past_key_values=full_kv,
451
+ output_hidden_states=True, use_cache=False)
452
+ hc = model(**probe, past_key_values=comp_kv,
453
+ output_hidden_states=True, use_cache=False)
454
+ aop = (hf.hidden_states[-1].float()
455
+ - hc.hidden_states[-1].float()).abs().mean().item()
456
+
457
+ full_seq = full_kv.layers[0].keys.shape[2]
458
+ comp_seq = comp_kv.layers[0].keys.shape[2]
459
+
460
+ turns.append({
461
+ "turn": turn_i + 1,
462
+ "question": q[:60],
463
+ "token_match_pct": tmatch,
464
+ "aop": aop,
465
+ "full_cache_seq": full_seq,
466
+ "comp_cache_seq": comp_seq,
467
+ "compress_ms": comp_time,
468
+ "full_sample": full_txt[:120],
469
+ "comp_sample": comp_txt[:120],
470
+ })
471
+
472
+ return turns
473
+
474
+
475
+ # ═══════════════════════ Main runner ══════════════════════════════
476
+
477
+ def build_methods(budget, head_dim):
478
+ return {
479
+ "FourierKV": FourierKV(budget=budget),
480
+ "WaveletKV": WaveletKV(budget=budget, levels=5),
481
+ "WaveletFourierKV": WaveletFourierKV(budget=budget, levels=5,
482
+ cascaded=True),
483
+ "WaveletTriAttn": WaveletTriAttention(budget=budget,
484
+ head_dim=head_dim),
485
+ "TriAttentionKV": TriAttentionKV(budget=budget,
486
+ head_dim=head_dim),
487
+ "TurboQuant-4bit": TurboQuantKV(bits=4, budget=budget),
488
+ }
489
+
490
+
491
+ def main():
492
+ device = "cuda" if torch.cuda.is_available() else "cpu"
493
+ print("="*70)
494
+ print(" SpectralKV v3 β€” Comprehensive Evaluation Suite")
495
+ print("="*70)
496
+ if device == "cuda":
497
+ print(f" GPU: {torch.cuda.get_device_name(0)}")
498
+
499
+ model, tokenizer, cfg = load_model(device=device)
500
+ head_dim = cfg["head_dim"]
501
+ budget = 512
502
+ long_text = get_long_text(target_tokens=8192)
503
+
504
+ results = {"config": cfg, "budget": budget}
505
+
506
+ # ────────── Phase 1: Large-document cache metrics ──────────
507
+ print(f"\n{'━'*70}")
508
+ print(" PHASE 1 Large-document KV cache compression (β‰ˆ8 k tokens)")
509
+ print(f"{'━'*70}")
510
+
511
+ gen_results = {}
512
+ methods = build_methods(budget, head_dim)
513
+ for name, comp in methods.items():
514
+ _clear()
515
+ print(f"\n β–Έ {name}")
516
+ try:
517
+ g = compare_generation(model, tokenizer, long_text, comp,
518
+ budget, cfg, device=device,
519
+ max_len=4096, gen_tokens=64)
520
+ gen_results[name] = g
521
+ print(f" cache {g['cache_full_mb']:.1f} β†’ "
522
+ f"{g['cache_comp_mb']:.1f} MB "
523
+ f"({g['cache_ratio']:.1f}Γ—, {g['cache_saved_pct']:.0f}% saved)")
524
+ print(f" tok/s {g['full_tok_s']:.1f} β†’ {g['comp_tok_s']:.1f} "
525
+ f"({g['speedup']:.2f}Γ— speedup)")
526
+ print(f" match {g['token_match_pct']:.1f}% "
527
+ f"compress {g['compress_time_ms']:.1f} ms")
528
+ except Exception as e:
529
+ print(f" ERROR: {e}")
530
+ import traceback; traceback.print_exc()
531
+
532
+ results["generation"] = gen_results
533
+
534
+ # ────────── Phase 2: AOP (attention output perturbation) ──────
535
+ print(f"\n{'━'*70}")
536
+ print(" PHASE 2 Attention Output Perturbation (AOP / TAFT)")
537
+ print(f"{'━'*70}")
538
+
539
+ aop_results = {}
540
+ methods = build_methods(budget, head_dim)
541
+ for name, comp in methods.items():
542
+ _clear()
543
+ try:
544
+ a = compute_aop(model, tokenizer, long_text, comp,
545
+ budget, cfg, device=device, max_len=4096)
546
+ aop_results[name] = {
547
+ "aop_mean": a["aop_mean"],
548
+ "aop_max": a["aop_max"],
549
+ "compress_ms": a["compress_time_ms"],
550
+ }
551
+ print(f" {name:22s} AOP_mean={a['aop_mean']:.6f} "
552
+ f"AOP_max={a['aop_max']:.6f} "
553
+ f"compress={a['compress_time_ms']:.1f}ms")
554
+ except Exception as e:
555
+ print(f" {name:22s} ERROR: {e}")
556
+
557
+ results["aop"] = aop_results
558
+
559
+ # ────────── Phase 3: Perplexity ──────
560
+ print(f"\n{'━'*70}")
561
+ print(" PHASE 3 Perplexity (ND-PPL)")
562
+ print(f"{'━'*70}")
563
+
564
+ ppl_results = {}
565
+ methods = build_methods(budget, head_dim)
566
+ for name, comp in methods.items():
567
+ _clear()
568
+ try:
569
+ p = measure_ppl(model, tokenizer, long_text, comp,
570
+ budget, cfg, device=device, max_len=4096)
571
+ ppl_results[name] = p
572
+ print(f" {name:22s} PPL_full={p['ppl_full']:.2f} "
573
+ f"PPL_comp={p['ppl_comp']:.2f} "
574
+ f"ND-PPL={p['nd_ppl']:+.4f}")
575
+ except Exception as e:
576
+ print(f" {name:22s} ERROR: {e}")
577
+
578
+ results["perplexity"] = ppl_results
579
+
580
+ # ────────── Phase 4: Multi-turn cache drift ──────
581
+ print(f"\n{'━'*70}")
582
+ print(" PHASE 4 Multi-turn cache drift (5 turns)")
583
+ print(f"{'━'*70}")
584
+
585
+ drift_results = {}
586
+ methods = build_methods(budget, head_dim)
587
+ for name, comp in methods.items():
588
+ _clear()
589
+ print(f"\n β–Έ {name}")
590
+ try:
591
+ turns = multi_turn_drift(
592
+ model, tokenizer, long_text, MULTI_TURN_QUESTIONS,
593
+ comp, budget, cfg, device=device, max_len=4096,
594
+ gen_tokens=48)
595
+ drift_results[name] = turns
596
+ for t in turns:
597
+ print(f" Turn {t['turn']} match={t['token_match_pct']:5.1f}% "
598
+ f"AOP={t['aop']:.6f} "
599
+ f"cache={t['comp_cache_seq']}/{t['full_cache_seq']}")
600
+ except Exception as e:
601
+ print(f" ERROR: {e}")
602
+ import traceback; traceback.print_exc()
603
+
604
+ results["multi_turn_drift"] = drift_results
605
+
606
+ # ────────── Summary tables ──────
607
+ print(f"\n{'━'*70}")
608
+ print(" SUMMARY")
609
+ print(f"{'━'*70}")
610
+
611
+ # --- Generation summary ---
612
+ headers = ["Method", "Cache MB", "Ratio", "Saved%",
613
+ "Tok/s", "Speedup", "Match%", "Comp ms"]
614
+ rows = []
615
+ for name, g in gen_results.items():
616
+ rows.append([
617
+ name,
618
+ f"{g['cache_comp_mb']:.1f}",
619
+ f"{g['cache_ratio']:.1f}Γ—",
620
+ f"{g['cache_saved_pct']:.0f}%",
621
+ f"{g['comp_tok_s']:.1f}",
622
+ f"{g['speedup']:.2f}Γ—",
623
+ f"{g['token_match_pct']:.1f}%",
624
+ f"{g['compress_time_ms']:.1f}",
625
+ ])
626
+ print("\n Generation quality on large document:")
627
+ print(tabulate(rows, headers=headers, tablefmt="grid"))
628
+
629
+ # --- AOP summary ---
630
+ headers = ["Method", "AOP_mean", "AOP_max", "Comp ms"]
631
+ rows = []
632
+ for name, a in aop_results.items():
633
+ rows.append([name, f"{a['aop_mean']:.6f}",
634
+ f"{a['aop_max']:.6f}", f"{a['compress_ms']:.1f}"])
635
+ print("\n Attention Output Perturbation (TAFT):")
636
+ print(tabulate(rows, headers=headers, tablefmt="grid"))
637
+
638
+ # --- PPL summary ---
639
+ headers = ["Method", "PPL_full", "PPL_comp", "ND-PPL"]
640
+ rows = []
641
+ for name, p in ppl_results.items():
642
+ rows.append([name, f"{p['ppl_full']:.2f}",
643
+ f"{p['ppl_comp']:.2f}", f"{p['nd_ppl']:+.4f}"])
644
+ print("\n Perplexity:")
645
+ print(tabulate(rows, headers=headers, tablefmt="grid"))
646
+
647
+ # --- Drift summary ---
648
+ if drift_results:
649
+ headers = ["Method", "T1 Match%", "T3 Match%", "T5 Match%",
650
+ "T1 AOP", "T5 AOP", "Drift(T5-T1)"]
651
+ rows = []
652
+ for name, turns in drift_results.items():
653
+ if len(turns) >= 5:
654
+ rows.append([
655
+ name,
656
+ f"{turns[0]['token_match_pct']:.1f}",
657
+ f"{turns[2]['token_match_pct']:.1f}",
658
+ f"{turns[4]['token_match_pct']:.1f}",
659
+ f"{turns[0]['aop']:.6f}",
660
+ f"{turns[4]['aop']:.6f}",
661
+ f"{turns[4]['aop'] - turns[0]['aop']:+.6f}",
662
+ ])
663
+ print("\n Multi-turn cache drift:")
664
+ print(tabulate(rows, headers=headers, tablefmt="grid"))
665
+
666
+ # --- Save ---
667
+ os.makedirs("/app/results", exist_ok=True)
668
+ with open("/app/results/eval_v3.json", "w") as f:
669
+ json.dump(results, f, indent=2, default=str)
670
+ print(f"\n Results saved to /app/results/eval_v3.json")
671
+
672
+ del model
673
+ _clear()
674
+
675
+
676
+ if __name__ == "__main__":
677
+ main()