# main_menu.py
"""Command-line main menu that dispatches to training, testing and inference tasks."""
import argparse
import sys
import os

from train_agent import train_agent
from test_agent import TestAgent, run_test_session
from lightbulb import main as world_model_main
from lightbulb_inf import main as inference_main  # Import the inference main function
from twisted.internet import reactor, task

# Map each main_menu inference task to lightbulb_inf.py's --inference_mode value.
_INFERENCE_MODE_MAP = {
    'inference_llm': 'without_world_model',
    'inference_world_model': 'world_model',
    'advanced_inference': 'world_model_tree_of_thought',
}


def parse_main_args():
    """Parse the main menu's command-line arguments.

    Returns:
        argparse.Namespace holding the required ``task`` plus the optional
        training and inference settings declared below.
    """
    parser = argparse.ArgumentParser(description="Main Menu for Selecting Tasks")
    parser.add_argument('--task', type=str, choices=[
        'train_llm_world', 'train_agent', 'test_agent',
        'inference_llm', 'inference_world_model', 'advanced_inference'
    ], required=True,
        help='Choose task to execute: train_llm_world, train_agent, test_agent, inference_llm, inference_world_model, advanced_inference')

    # Optional arguments for more granular control
    parser.add_argument('--model_name', type=str, default='gpt2', help='Pretrained model name for LLM')
    parser.add_argument('--dataset_name', type=str, default='wikitext', help='Dataset name for training')
    parser.add_argument('--dataset_config', type=str, default='wikitext-2-raw-v1', help='Dataset configuration name')
    parser.add_argument('--batch_size', type=int, default=4, help='Batch size for training')
    parser.add_argument('--num_epochs', type=int, default=3, help='Number of epochs for training')
    parser.add_argument('--max_length', type=int, default=128, help='Maximum sequence length for training')
    parser.add_argument('--mode', type=str, choices=['train', 'inference'], default='train', help='Train or inference mode for LLM')
    parser.add_argument('--query', type=str, default='', help='Query for the test_agent or inference tasks')

    # Inference-only options forwarded verbatim to lightbulb_inf.py. Previously
    # main() read these via getattr() but they were never declared, so the
    # hard-coded fallbacks were always used; the defaults below are those same
    # values, so omitting the flags forwards exactly what was forwarded before.
    parser.add_argument('--beam_size', type=int, default=5,
                        help='Value forwarded to lightbulb_inf.py as --beam_size')
    parser.add_argument('--n_tokens_predict', type=int, default=3,
                        help='Value forwarded to lightbulb_inf.py as --n_tokens_predict')
    parser.add_argument('--mcts_iterations', type=int, default=10,
                        help='Value forwarded to lightbulb_inf.py as --mcts_iterations')
    parser.add_argument('--mcts_exploration_constant', type=float, default=1.414,
                        help='Value forwarded to lightbulb_inf.py as --mcts_exploration_constant')
    parser.add_argument('--load_model', type=str, default=None,
                        help='If given, forwarded to lightbulb_inf.py as --load_model')

    return parser.parse_args()


def _run_train_llm_world(args):
    """Run lightbulb.py's main() with the training options taken from *args*."""
    print("Starting LLM and World Model Training...")
    # lightbulb.main() is invoked without arguments, so it presumably parses
    # sys.argv itself; rebuild an argv for it from our options.
    sys.argv = [
        'lightbulb.py',
        '--mode', args.mode,
        '--model_name', args.model_name,
        '--dataset_name', args.dataset_name,
        '--dataset_config', args.dataset_config,
        '--batch_size', str(args.batch_size),
        '--num_epochs', str(args.num_epochs),
        '--max_length', str(args.max_length),
    ]
    world_model_main()


def _run_train_agent():
    """Run train_agent() inside the Twisted reactor and stop the reactor afterwards.

    Returns:
        True if train_agent completed without an error, False otherwise.
    """
    print("Starting Agent Training...")
    outcome = {'ok': True}

    def _report_failure(failure):
        # print() has no ``exc_info`` keyword: the previous errback raised
        # TypeError and hid the real error. Print Twisted's traceback instead.
        outcome['ok'] = False
        print(f"An error occurred: {failure.getTraceback()}")
        # Returning None marks the failure as handled.

    d = task.deferLater(reactor, 0, train_agent)
    d.addErrback(_report_failure)
    # Stop the reactor whether training succeeded or failed.
    d.addBoth(lambda _: reactor.stop())
    reactor.run()
    return outcome['ok']


def _run_test_agent(args):
    """Answer a single --query directly, or start the interactive test session."""
    print("Starting Test Agent...")
    # NOTE(review): this instance is only used in the single-query branch;
    # confirm whether constructing it in interactive mode is required.
    test_agent = TestAgent()
    if args.query:
        # Directly process a single query
        result = test_agent.process_query(args.query)
        print("\nAgent's response:")
        print(result)
    else:
        # Run the interactive session once the reactor is up.
        reactor.callWhenRunning(run_test_session)
        reactor.run()


def _run_inference(args):
    """Run lightbulb_inf.py's main() in the inference mode selected by ``args.task``."""
    print("Starting Inference Task...")
    selected_inference_mode = _INFERENCE_MODE_MAP[args.task]

    # lightbulb_inf.main() is invoked without arguments, so it presumably
    # parses sys.argv; build the argv it expects.
    lightbulb_inf_args = [
        'lightbulb_inf.py',
        '--mode', 'inference',
        '--model_name', args.model_name,
        '--query', args.query,
        '--max_length', str(args.max_length),
        '--inference_mode', selected_inference_mode,
        '--beam_size', str(args.beam_size),
        '--n_tokens_predict', str(args.n_tokens_predict),
        '--mcts_iterations', str(args.mcts_iterations),
        '--mcts_exploration_constant', str(args.mcts_exploration_constant),
    ]
    # Only forward --load_model when a non-empty value was supplied.
    if args.load_model:
        lightbulb_inf_args += ['--load_model', args.load_model]

    sys.argv = lightbulb_inf_args
    inference_main()


def main():
    """Parse the command line and run the selected task.

    Exits with status 1 if agent training fails or the task is unknown.
    """
    args = parse_main_args()

    if args.task == 'train_llm_world':
        _run_train_llm_world(args)
    elif args.task == 'train_agent':
        if not _run_train_agent():
            sys.exit(1)
    elif args.task == 'test_agent':
        _run_test_agent(args)
    elif args.task in _INFERENCE_MODE_MAP:
        _run_inference(args)
    else:
        # Unreachable while argparse enforces ``choices``; kept as a safeguard.
        print(f"Unknown task: {args.task}")
        sys.exit(1)


if __name__ == "__main__":
    main()