Cerebras GPT 13B Underperforms Scaling Law Predictions?

#11 · opened by RylanSchaeffer

I'm trying to study the scaling behavior of language models during pretraining, and I thought Cerebras-GPT would be perfect because it was designed for exactly this! However, I've discovered that the 13B model seems to underperform the power-law scaling one would expect:

[image.png: scaling plot, with the 13B model underperforming the expected power-law trend]
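
For context, the expectation I'm testing is the usual parametric power-law fit of loss against parameter count, extrapolated from the smaller Cerebras-GPT checkpoints to 13B. A sketch of that fit is below; the loss values in it are made-up placeholders, not my measurements:

    import numpy as np
    from scipy.optimize import curve_fit

    # Smaller Cerebras-GPT model sizes used to fit the power law.
    # The loss values below are placeholders, not actual measurements.
    params = np.array([111e6, 256e6, 590e6, 1.3e9, 2.7e9, 6.7e9])
    losses = np.array([3.50, 3.25, 3.05, 2.85, 2.70, 2.55])

    def power_law(n, a, alpha, c):
        # L(N) = a * N^(-alpha) + c
        return a * n ** (-alpha) + c

    (a, alpha, c), _ = curve_fit(
        power_law, params, losses, p0=[10.0, 0.1, 1.5], maxfev=10_000
    )

    # Extrapolate the fit to 13B and compare against the measured 13B loss.
    predicted_13b = power_law(13e9, a, alpha, c)
    print(f"a={a:.3g}, alpha={alpha:.3g}, c={c:.3g}, predicted 13B loss={predicted_13b:.3f}")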

To produce this data, I'm looping over the models and datasets and using "standard" transformers logic:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # huggingface_path, sequences, and max_context_length come from the outer loop
    # over models and datasets.
    tokenizer = AutoTokenizer.from_pretrained(
        huggingface_path, trust_remote_code=True
    )
    model = AutoModelForCausalLM.from_pretrained(
        huggingface_path,
        device_map="auto",          # model parallelism across available GPUs
        trust_remote_code=True,
        torch_dtype=torch.float32,  # evaluate in full fp32
    )
    model.eval()

    log_probs_dict = {
        "log_probs": [],
        "Problem Idx": [],
        "Seq Idx": [],
    }
    for sample_idx, sequence in enumerate(sequences[:1000]):
        encoded_sequence = tokenizer(
            sequence,
            return_tensors="pt",
            add_special_tokens=False,
            truncation=True,
            max_length=max_context_length,
        ).to(model.device)

        with torch.no_grad():
            input_ids = encoded_sequence.input_ids
            # Targets: drop the first position; each token is predicted from the ones before it.
            labels = input_ids.clone()[:, 1:]

            output = model(**encoded_sequence)

            # Drop the last position's logits, since there is no next token to predict.
            logits = output.logits[:, :-1, :]
            # Log-softmax over the vocabulary dimension.
            log_probs = torch.log_softmax(logits, dim=-1)
            # Gather the log prob of each actual next token. Shape: (sequence_length - 1,)
            token_log_probs = torch.gather(
                log_probs, 2, labels.unsqueeze(-1)
            ).squeeze()

            # Accumulate per-token results for this sequence.
            num_tokens = token_log_probs.numel()
            log_probs_dict["log_probs"].extend(token_log_probs.reshape(-1).cpu().tolist())
            log_probs_dict["Problem Idx"].extend([sample_idx] * num_tokens)
            log_probs_dict["Seq Idx"].extend(range(num_tokens))
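
For completeness, each model then gets reduced to a single number for the plot. A minimal version of that reduction is a mean negative log-likelihood over all scored tokens; my actual aggregation may weight sequences slightly differently:

    import numpy as np

    # One point per model: mean negative log-likelihood (nats/token) over every
    # scored token from the sequences above. (Sketch; not the exact plotting code.)
    mean_nll = -float(np.mean(log_probs_dict["log_probs"]))
    print(f"{huggingface_path}: mean NLL = {mean_nll:.4f} nats/token")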

Why might 13B be underperforming? Is there something obvious I'm missing? Disclaimer: I am using model parallelism (device_map="auto").
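
In case it helps anyone spot a bug in my extraction: one sanity check is to compare the manual gather above against the loss the model computes internally when given labels. A minimal sketch, reusing model, tokenizer, sequences, and max_context_length from the snippet above; the two numbers should agree closely:

    import torch
    import torch.nn.functional as F

    # Sanity check: HF causal-LM models shift labels internally, so the loss returned
    # when passing labels=input_ids should match a manual shifted cross-entropy.
    with torch.no_grad():
        enc = tokenizer(
            sequences[0],
            return_tensors="pt",
            add_special_tokens=False,
            truncation=True,
            max_length=max_context_length,
        ).to(model.device)
        out = model(**enc, labels=enc.input_ids)
        internal_nll = out.loss.item()  # mean cross-entropy over shifted positions

        logits = out.logits[:, :-1, :]
        targets = enc.input_ids[:, 1:].to(logits.device)
        manual_nll = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)), targets.reshape(-1)
        ).item()

    print(f"internal: {internal_nll:.6f}  manual: {manual_nll:.6f}")  # should match closely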
