# main_menu.py

import argparse
import sys

from train_agent import train_agent
from test_agent import TestAgent, run_test_session
from twisted.internet import reactor, task
from lightbulb_custom import main as lightbulb_custom_main
from distillation_pipeline import distill_model  # Import the distillation function
from transformers import logging

# Suppress transformers warnings for cleaner output
logging.set_verbosity_error()


def parse_main_args():
    parser = argparse.ArgumentParser(description="Main Menu for Selecting Tasks")

    # Task selection
    parser.add_argument('--task', type=str, choices=[
        'train_llm_world',
        'train_agent',
        'test_agent',
        'inference_llm',
        'inference_world_model',
        'advanced_inference',
        'distill_full_model',       # Full model distillation
        'distill_domain_specific'   # Selective (domain-specific) distillation
    ], required=True,
        help='Task to execute: train_llm_world, train_agent, test_agent, '
             'inference_llm, inference_world_model, advanced_inference, '
             'distill_full_model, distill_domain_specific')

    # Common arguments
    parser.add_argument('--model_name', type=str, default='gpt2',
                        help='Pretrained model name for the LLM')
    parser.add_argument('--student_model_name', type=str, default='distilgpt2',
                        help='Name of the student model for distillation')
    parser.add_argument('--dataset_name', type=str, default='wikitext',
                        help='Dataset name for training')
    parser.add_argument('--dataset_config', type=str, default='wikitext-2-raw-v1',
                        help='Dataset configuration name')
    parser.add_argument('--batch_size', type=int, default=4,
                        help='Batch size for training')
    parser.add_argument('--num_epochs', type=int, default=3,
                        help='Number of epochs for training')
    parser.add_argument('--max_length', type=int, default=128,
                        help='Maximum sequence length for training')
    parser.add_argument('--temperature', type=float, default=2.0,
                        help='Distillation temperature')
    parser.add_argument('--learning_rate', type=float, default=5e-5,
                        help='Learning rate')

    # Distillation-specific arguments
    parser.add_argument('--save_path', type=str, default='./distilled_model',
                        help='Path to save the distilled model')
    parser.add_argument('--log_dir', type=str, default='./logs',
                        help='Directory for TensorBoard logs')
    parser.add_argument('--checkpoint_dir', type=str, default='./checkpoints',
                        help='Directory to save checkpoints')
    parser.add_argument('--early_stopping_patience', type=int, default=3,
                        help='Early stopping patience')

    # Inference-specific arguments
    parser.add_argument('--query', type=str, default='',
                        help='Query for the test_agent or inference tasks')
    parser.add_argument('--inference_mode', type=str,
                        choices=['without_world_model', 'world_model',
                                 'world_model_tree_of_thought'],
                        help='Inference mode')
    parser.add_argument('--beam_size', type=int, default=5,
                        help='Beam size for beam search during inference')
    parser.add_argument('--n_tokens_predict', type=int, default=3,
                        help='Number of tokens to predict at each step during inference')
    parser.add_argument('--mcts_iterations', type=int, default=10,
                        help='Number of MCTS iterations during inference')
    parser.add_argument('--mcts_exploration_constant', type=float, default=1.414,
                        help='Exploration constant for MCTS during inference')

    # Distillation mode selection
    parser.add_argument('--distill_full_model', action='store_true',
                        help='Whether to distill the full model or not')
    parser.add_argument('--query_terms', type=str, nargs='+',
                        help='Query terms for domain-specific distillation')

    # Load model for inference
    parser.add_argument('--load_model', type=str,
                        help='Path to load the distilled model for inference')

    return parser.parse_args()
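
# A minimal usage sketch (the exact flags below are illustrative): invoking
#   python main_menu.py --task test_agent --query "What is distillation?"
# makes parse_main_args() return a namespace with args.task == 'test_agent',
# args.query == 'What is distillation?', and the documented defaults for the
# remaining options (e.g. args.model_name == 'gpt2', args.batch_size == 4).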


def main():
    # Parse arguments for the main function
    args = parse_main_args()

    # Execute tasks based on user input
    if args.task == 'train_llm_world':
        print("Starting LLM and World Model Training...")
        # Directly call the world model main function with appropriate arguments
        sys.argv = [
            'lightbulb_custom.py',
            '--mode', 'train',
            '--model_name', args.model_name,
            '--dataset_name', args.dataset_name,
            '--dataset_config', args.dataset_config,
            '--batch_size', str(args.batch_size),
            '--num_epochs', str(args.num_epochs),
            '--max_length', str(args.max_length)
        ]
        lightbulb_custom_main()

    elif args.task == 'train_agent':
        print("Starting Agent Training...")
        # Call the train_agent function from train_agent.py using the Twisted reactor
        d = task.deferLater(reactor, 0, train_agent)
        d.addErrback(lambda failure: print(f"An error occurred: {failure}"))
        d.addBoth(lambda _: reactor.stop())
        reactor.run()

    elif args.task == 'test_agent':
        print("Starting Test Agent...")
        test_agent = TestAgent()
        if args.query:
            # Directly process a single query
            result = test_agent.process_query(args.query)
            print("\nAgent's response:")
            print(result)
        else:
            # Run the interactive session
            reactor.callWhenRunning(run_test_session)
            reactor.run()

    elif args.task in ['inference_llm', 'inference_world_model', 'advanced_inference']:
        print("Starting Inference Task...")
        # Prepare the arguments for lightbulb_custom.py based on the selected task:
        # map the main_menu task to lightbulb_custom.py's inference_mode
        inference_mode_map = {
            'inference_llm': 'without_world_model',
            'inference_world_model': 'world_model',
            'advanced_inference': 'world_model_tree_of_thought'
        }
        selected_inference_mode = inference_mode_map.get(
            args.task, 'world_model_tree_of_thought')

        # Construct sys.argv for lightbulb_custom.py
        lightbulb_inf_args = [
            'lightbulb_custom.py',
            '--mode', 'inference',
            '--model_name', args.model_name,
            '--query', args.query,
            '--max_length', str(args.max_length),
            '--inference_mode', selected_inference_mode,
            '--beam_size', str(args.beam_size),
            '--n_tokens_predict', str(args.n_tokens_predict),
            '--mcts_iterations', str(args.mcts_iterations),
            '--mcts_exploration_constant', str(args.mcts_exploration_constant)
        ]

        # Include additional arguments if they exist
        if args.load_model:
            lightbulb_inf_args += ['--load_model', args.load_model]

        # Update sys.argv and call the inference main function
        sys.argv = lightbulb_inf_args
        lightbulb_custom_main()

    elif args.task == 'distill_full_model':
        print("Starting Full Model Distillation...")
        distill_model(
            teacher_model_name=args.model_name,
            student_model_name=args.student_model_name,
            dataset_name=args.dataset_name,
            config=args.dataset_config,
            distill_full_model=True,
            query_terms=None,
            num_epochs=args.num_epochs,
            batch_size=args.batch_size,
            max_length=args.max_length,
            learning_rate=args.learning_rate,
            temperature=args.temperature,
            save_path=args.save_path,
            log_dir=args.log_dir,
            checkpoint_dir=args.checkpoint_dir,
            early_stopping_patience=args.early_stopping_patience
        )

    elif args.task == 'distill_domain_specific':
        print("Starting Domain-Specific Distillation...")
        if not args.query_terms:
            print("Error: --query_terms must be provided for domain-specific distillation.")
            sys.exit(1)
        distill_model(
            teacher_model_name=args.model_name,
            student_model_name=args.student_model_name,
            dataset_name=args.dataset_name,
            config=args.dataset_config,
            distill_full_model=False,
            query_terms=args.query_terms,
            num_epochs=args.num_epochs,
            batch_size=args.batch_size,
            max_length=args.max_length,
            learning_rate=args.learning_rate,
            temperature=args.temperature,
            save_path=args.save_path,
            log_dir=args.log_dir,
            checkpoint_dir=args.checkpoint_dir,
            early_stopping_patience=args.early_stopping_patience
        )

    else:
        print(f"Unknown task: {args.task}")
        sys.exit(1)


if __name__ == "__main__":
    main()
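
# Example invocations (a sketch; assumes the sibling modules train_agent.py,
# test_agent.py, lightbulb_custom.py, and distillation_pipeline.py are
# importable from the working directory):
#
#   python main_menu.py --task train_llm_world --model_name gpt2 \
#       --dataset_name wikitext --dataset_config wikitext-2-raw-v1
#
#   python main_menu.py --task advanced_inference --query "Explain MCTS" \
#       --beam_size 5 --n_tokens_predict 3 --mcts_iterations 10
#
#   python main_menu.py --task distill_domain_specific \
#       --query_terms physics chemistry --student_model_name distilgpt2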