import json

import torch
from transformers import PreTrainedModel, PreTrainedTokenizerFast, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutput
# Define the model configuration
class HelloWorldConfig(PretrainedConfig):
    model_type = "hello-world"

    def __init__(self, vocab_size=2, bos_token_id=0, eos_token_id=1, **kwargs):
        self.vocab_size = vocab_size
        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
# Define the model
class HelloWorldModel(PreTrainedModel):
    config_class = HelloWorldConfig

    def __init__(self, config):
        super().__init__(config)

    def forward(self, input_ids=None, **kwargs):
        batch_size = input_ids.shape[0]
        sequence_length = input_ids.shape[1]
        # Generate logits that always select the "Hello, world!" token
        hello_world_token_id = self.config.vocab_size - 1
        logits = torch.full((batch_size, sequence_length, self.config.vocab_size), float("-inf"))
        logits[:, :, hello_world_token_id] = 0
        return CausalLMOutput(logits=logits)
# Define and save the tokenizer
from tokenizers import Tokenizer
from tokenizers.models import WordLevel

# Two-entry vocabulary matching the config: id 0 is <s>, id 1 is "Hello, world!"
# (the token the model always emits, i.e. vocab_size - 1)
backend_tokenizer = Tokenizer(WordLevel(vocab={"<s>": 0, "Hello, world!": 1}, unk_token="<s>"))
backend_tokenizer.save("tokenizer.json")

tokenizer = PreTrainedTokenizerFast(tokenizer_file="tokenizer.json")

tokenizer_config = {
    "do_lower_case": False,
    "model_max_length": 512,
    "padding_side": "right",
    "special_tokens_map_file": None,
    "tokenizer_file": "tokenizer.json",
    "unk_token": "<unk>",
    "bos_token": "<s>",
    "eos_token": "</s>",
    "vocab_size": 2,
}
with open("tokenizer_config.json", "w") as f:
    json.dump(tokenizer_config, f)
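
# Quick check (not part of the original listing): with the vocabulary built above,
# the phrase "Hello, world!" should map to token id 1.
print(tokenizer.convert_tokens_to_ids("Hello, world!"))  # expected output: 1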
# Initialize model
config = HelloWorldConfig()
model = HelloWorldModel(config)
# Save model using safetensors format (the state dict is empty, since the model defines no parameters)
from safetensors.torch import save_file
save_file(model.state_dict(), "hello_world_model.safetensors")
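
# Optional smoke test (not part of the original listing): run a forward pass on a
# single <s> token and check that the most likely next token decodes to "Hello, world!".
input_ids = torch.tensor([[tokenizer.convert_tokens_to_ids("<s>")]])
outputs = model(input_ids=input_ids)
predicted_id = outputs.logits[0, -1].argmax().item()
print(tokenizer.decode([predicted_id]))  # expected output: "Hello, world!"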