winglian commited on
Commit
501958b
1 Parent(s): c25ba79

create a model card with axolotl badge (#624)

Browse files
Files changed (1) hide show
  1. src/axolotl/train.py +7 -2
src/axolotl/train.py CHANGED
@@ -9,8 +9,7 @@ from pathlib import Path
9
  from typing import Optional
10
 
11
  import torch
12
-
13
- # add src to the pythonpath so we don't need to pip install this
14
  from datasets import Dataset
15
  from optimum.bettertransformer import BetterTransformer
16
 
@@ -103,6 +102,9 @@ def train(
103
  signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
104
  )
105
 
 
 
 
106
  LOG.info("Starting trainer...")
107
  if cfg.group_by_length:
108
  LOG.info("hang tight... sorting dataset for group_by_length")
@@ -138,4 +140,7 @@ def train(
138
 
139
  model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
140
 
 
 
 
141
  return model, tokenizer
 
9
  from typing import Optional
10
 
11
  import torch
12
+ import transformers.modelcard
 
13
  from datasets import Dataset
14
  from optimum.bettertransformer import BetterTransformer
15
 
 
102
  signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
103
  )
104
 
105
+ badge_markdown = """[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)"""
106
+ transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
107
+
108
  LOG.info("Starting trainer...")
109
  if cfg.group_by_length:
110
  LOG.info("hang tight... sorting dataset for group_by_length")
 
140
 
141
  model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
142
 
143
+ if not cfg.hub_model_id:
144
+ trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
145
+
146
  return model, tokenizer