# main_menu.py
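"""Command-line entry point that dispatches training, inference, and
distillation tasks to the surrounding modules.

Example invocations (illustrative sketches; they assume the sibling modules
train_agent.py, test_agent.py, lightbulb_custom.py, and
distillation_pipeline.py are importable from the working directory, and the
query strings are placeholders):

    python main_menu.py --task train_llm_world --model_name gpt2
    python main_menu.py --task advanced_inference --query "example question"
    python main_menu.py --task distill_domain_specific --query_terms physics optics
"""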
import argparse
import sys

from twisted.internet import reactor, task
from transformers import logging

from train_agent import train_agent
from test_agent import TestAgent, run_test_session
from lightbulb_custom import main as lightbulb_custom_main
from distillation_pipeline import distill_model

# Suppress transformers warnings for cleaner output
logging.set_verbosity_error()

def parse_main_args():
    parser = argparse.ArgumentParser(description="Main Menu for Selecting Tasks")

    # Task selection
    parser.add_argument(
        '--task', type=str, required=True,
        choices=[
            'train_llm_world',
            'train_agent',
            'test_agent',
            'inference_llm',
            'inference_world_model',
            'advanced_inference',
            'distill_full_model',       # Full-model distillation
            'distill_domain_specific',  # Selective, domain-specific distillation
        ],
        help='Choose task to execute: train_llm_world, train_agent, test_agent, '
             'inference_llm, inference_world_model, advanced_inference, '
             'distill_full_model, distill_domain_specific')

    # Common arguments
    parser.add_argument('--model_name', type=str, default='gpt2', help='Pretrained model name for LLM')
    parser.add_argument('--student_model_name', type=str, default='distilgpt2', help='Name of the student model for distillation')
    parser.add_argument('--dataset_name', type=str, default='wikitext', help='Dataset name for training')
    parser.add_argument('--dataset_config', type=str, default='wikitext-2-raw-v1', help='Dataset configuration name')
    parser.add_argument('--batch_size', type=int, default=4, help='Batch size for training')
    parser.add_argument('--num_epochs', type=int, default=3, help='Number of epochs for training')
    parser.add_argument('--max_length', type=int, default=128, help='Maximum sequence length for training')
    parser.add_argument('--temperature', type=float, default=2.0, help='Distillation temperature')
    parser.add_argument('--learning_rate', type=float, default=5e-5, help='Learning rate')

    # Distillation-specific arguments
    parser.add_argument('--save_path', type=str, default='./distilled_model', help='Path to save the distilled model')
    parser.add_argument('--log_dir', type=str, default='./logs', help='Directory for TensorBoard logs')
    parser.add_argument('--checkpoint_dir', type=str, default='./checkpoints', help='Directory to save checkpoints')
    parser.add_argument('--early_stopping_patience', type=int, default=3, help='Early stopping patience')

    # Inference-specific arguments
    parser.add_argument('--query', type=str, default='', help='Query for the test_agent or inference tasks')
    parser.add_argument('--inference_mode', type=str, choices=['without_world_model', 'world_model', 'world_model_tree_of_thought'], help='Inference mode')
    parser.add_argument('--beam_size', type=int, default=5, help='Beam size for beam search during inference')
    parser.add_argument('--n_tokens_predict', type=int, default=3, help='Number of tokens to predict at each step during inference')
    parser.add_argument('--mcts_iterations', type=int, default=10, help='Number of MCTS iterations during inference')
    parser.add_argument('--mcts_exploration_constant', type=float, default=1.414, help='Exploration constant for MCTS during inference')

    # Distillation mode selection
    parser.add_argument('--distill_full_model', action='store_true', help='Distill the full model rather than a domain-specific subset')
    parser.add_argument('--query_terms', type=str, nargs='+', help='Query terms for domain-specific distillation')

    # Load a previously distilled model for inference
    parser.add_argument('--load_model', type=str, help='Path to load the distilled model for inference')

    return parser.parse_args()

def main():
    # Parse arguments for the main function
    args = parse_main_args()

    # Execute tasks based on user input
    if args.task == 'train_llm_world':
        print("Starting LLM and World Model Training...")
        # lightbulb_custom_main() parses its own command line, so delegate by
        # rewriting sys.argv with the appropriate arguments before calling it.
        sys.argv = [
            'lightbulb_custom.py',
            '--mode', 'train',
            '--model_name', args.model_name,
            '--dataset_name', args.dataset_name,
            '--dataset_config', args.dataset_config,
            '--batch_size', str(args.batch_size),
            '--num_epochs', str(args.num_epochs),
            '--max_length', str(args.max_length)
        ]
        lightbulb_custom_main()
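        # Note: sys.argv remains rewritten after the call above. If the
        # original argv were needed later, a save/restore would be safer
        # (illustrative sketch, not part of the current flow):
        #     saved_argv = sys.argv
        #     try:
        #         lightbulb_custom_main()
        #     finally:
        #         sys.argv = saved_argv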
    elif args.task == 'train_agent':
        print("Starting Agent Training...")
        # Schedule train_agent on the Twisted reactor, then stop the reactor
        # once it finishes, successfully or not.
        d = task.deferLater(reactor, 0, train_agent)
        d.addErrback(lambda failure: print(f"An error occurred:\n{failure.getTraceback()}"))
        d.addBoth(lambda _: reactor.stop())
        reactor.run()
    elif args.task == 'test_agent':
        print("Starting Test Agent...")
        test_agent = TestAgent()
        if args.query:
            # Directly process a single query
            result = test_agent.process_query(args.query)
            print("\nAgent's response:")
            print(result)
        else:
            # Run the interactive session
            reactor.callWhenRunning(run_test_session)
            reactor.run()
    elif args.task in ['inference_llm', 'inference_world_model', 'advanced_inference']:
        print("Starting Inference Task...")
        # Prepare the arguments for lightbulb_custom.py based on the selected
        # task: map the main_menu task to lightbulb_custom.py's inference_mode.
        inference_mode_map = {
            'inference_llm': 'without_world_model',
            'inference_world_model': 'world_model',
            'advanced_inference': 'world_model_tree_of_thought'
        }
        selected_inference_mode = inference_mode_map.get(args.task, 'world_model_tree_of_thought')

        # Construct sys.argv for lightbulb_custom.py
        lightbulb_inf_args = [
            'lightbulb_custom.py',
            '--mode', 'inference',
            '--model_name', args.model_name,
            '--query', args.query,
            '--max_length', str(args.max_length),
            '--inference_mode', selected_inference_mode,
            '--beam_size', str(args.beam_size),
            '--n_tokens_predict', str(args.n_tokens_predict),
            '--mcts_iterations', str(args.mcts_iterations),
            '--mcts_exploration_constant', str(args.mcts_exploration_constant)
        ]
        # Include additional arguments if they exist
        if args.load_model:
            lightbulb_inf_args += ['--load_model', args.load_model]

        # Update sys.argv and call the inference main function
        sys.argv = lightbulb_inf_args
        lightbulb_custom_main()
    elif args.task == 'distill_full_model':
        print("Starting Full Model Distillation...")
        distill_model(
            teacher_model_name=args.model_name,
            student_model_name=args.student_model_name,
            dataset_name=args.dataset_name,
            config=args.dataset_config,
            distill_full_model=True,
            query_terms=None,
            num_epochs=args.num_epochs,
            batch_size=args.batch_size,
            max_length=args.max_length,
            learning_rate=args.learning_rate,
            temperature=args.temperature,
            save_path=args.save_path,
            log_dir=args.log_dir,
            checkpoint_dir=args.checkpoint_dir,
            early_stopping_patience=args.early_stopping_patience
        )
    elif args.task == 'distill_domain_specific':
        print("Starting Domain-Specific Distillation...")
        if not args.query_terms:
            print("Error: --query_terms must be provided for domain-specific distillation.")
            sys.exit(1)
        distill_model(
            teacher_model_name=args.model_name,
            student_model_name=args.student_model_name,
            dataset_name=args.dataset_name,
            config=args.dataset_config,
            distill_full_model=False,
            query_terms=args.query_terms,
            num_epochs=args.num_epochs,
            batch_size=args.batch_size,
            max_length=args.max_length,
            learning_rate=args.learning_rate,
            temperature=args.temperature,
            save_path=args.save_path,
            log_dir=args.log_dir,
            checkpoint_dir=args.checkpoint_dir,
            early_stopping_patience=args.early_stopping_patience
        )
    else:
        print(f"Unknown task: {args.task}")
        sys.exit(1)


if __name__ == "__main__":
    main()