|
|
|
|
|
import argparse |
|
import sys |
|
import os |
|
from train_agent import train_agent |
|
from test_agent import TestAgent, run_test_session |
|
from lightbulb import main as world_model_main |
|
from lightbulb_inf import main as inference_main |
|
from twisted.internet import reactor, task |
|
|
|
def parse_main_args(argv=None):
    """Parse the command-line options for the task-selection menu.

    Args:
        argv: Optional list of argument strings to parse (without the program
            name). When ``None`` (the default) argparse reads ``sys.argv[1:]``,
            which matches the original zero-argument behaviour.

    Returns:
        argparse.Namespace holding ``task`` plus the LLM, training and
        inference options defined below.

    Raises:
        SystemExit: raised by argparse when ``--task`` is missing, a choice is
            invalid, or an unknown option is given.
    """
    tasks = [
        'train_llm_world',
        'train_agent',
        'test_agent',
        'inference_llm',
        'inference_world_model',
        'advanced_inference',
    ]

    parser = argparse.ArgumentParser(description="Main Menu for Selecting Tasks")

    # Help text is derived from the choices so the two can never drift apart.
    parser.add_argument('--task', type=str, choices=tasks,
                        required=True,
                        help='Choose task to execute: ' + ', '.join(tasks))

    parser.add_argument('--model_name', type=str, default='gpt2', help='Pretrained model name for LLM')
    parser.add_argument('--dataset_name', type=str, default='wikitext', help='Dataset name for training')
    parser.add_argument('--dataset_config', type=str, default='wikitext-2-raw-v1', help='Dataset configuration name')
    parser.add_argument('--batch_size', type=int, default=4, help='Batch size for training')
    parser.add_argument('--num_epochs', type=int, default=3, help='Number of epochs for training')
    parser.add_argument('--max_length', type=int, default=128, help='Maximum sequence length for training')
    parser.add_argument('--mode', type=str, choices=['train', 'inference'], default='train', help='Train or inference mode for LLM')
    parser.add_argument('--query', type=str, default='', help='Query for the test_agent or inference tasks')

    # Inference options. main() forwards these to the inference script and
    # used to fall back to hard-coded values because the parser never defined
    # them. The defaults below are those same values, so behaviour is
    # unchanged when the options are omitted.
    parser.add_argument('--beam_size', type=int, default=5,
                        help='Beam size forwarded to the inference script')
    parser.add_argument('--n_tokens_predict', type=int, default=3,
                        help='Value forwarded to the inference script as --n_tokens_predict')
    parser.add_argument('--mcts_iterations', type=int, default=10,
                        help='MCTS iteration count forwarded to the inference script')
    parser.add_argument('--mcts_exploration_constant', type=float, default=1.414,
                        help='MCTS exploration constant forwarded to the inference script')
    parser.add_argument('--load_model', type=str, default=None,
                        help='Optional model to load, forwarded to the inference script as --load_model')

    return parser.parse_args(argv)
|
|
|
# Maps each inference task name to the --inference_mode value forwarded to
# lightbulb_inf's entry point.
_INFERENCE_MODE_BY_TASK = {
    'inference_llm': 'without_world_model',
    'inference_world_model': 'world_model',
    'advanced_inference': 'world_model_tree_of_thought',
}


def _call_with_argv(argv, entry_point):
    """Call ``entry_point()`` with ``sys.argv`` temporarily set to *argv*.

    The imported sub-module ``main`` functions are driven by rewriting
    ``sys.argv`` (presumably they parse it themselves — confirm). The original
    value is restored afterwards, even if the entry point raises, so the
    global is not left clobbered.

    Returns:
        Whatever ``entry_point()`` returns.
    """
    saved_argv = sys.argv
    sys.argv = argv
    try:
        return entry_point()
    finally:
        sys.argv = saved_argv


def _report_failure(failure):
    """Deferred errback: print a Twisted Failure's message and traceback.

    Returns ``None``, which consumes the failure so the following ``addBoth``
    callback receives ``None``.
    """
    print(f"An error occurred: {failure.getErrorMessage()}")
    failure.printTraceback()


def main():
    """Parse the command line and dispatch to the selected task.

    Tasks:
        train_llm_world: runs ``lightbulb.main`` with ``sys.argv`` rebuilt from
            the training options.
        train_agent: schedules ``train_agent`` on the Twisted reactor and stops
            the reactor once it finishes, whether it succeeded or failed.
        test_agent: with ``--query``, answers it once via
            ``TestAgent.process_query``; otherwise runs ``run_test_session``
            once the reactor is running.
        inference_llm / inference_world_model / advanced_inference: runs
            ``lightbulb_inf.main`` with the matching ``--inference_mode``.

    Exits with status 1 for an unknown task. In practice this cannot happen,
    because argparse restricts ``--task`` to the known choices.
    """
    args = parse_main_args()

    if args.task == 'train_llm_world':
        print("Starting LLM and World Model Training...")
        _call_with_argv([
            'lightbulb_custom.py',
            '--mode', args.mode,
            '--model_name', args.model_name,
            '--dataset_name', args.dataset_name,
            '--dataset_config', args.dataset_config,
            '--batch_size', str(args.batch_size),
            '--num_epochs', str(args.num_epochs),
            '--max_length', str(args.max_length),
        ], world_model_main)

    elif args.task == 'train_agent':
        print("Starting Agent Training...")
        d = task.deferLater(reactor, 0, train_agent)
        # The previous errback passed exc_info=True to print(), which has no
        # such keyword (it is a logging API). That raised TypeError inside the
        # errback and hid the real error.
        d.addErrback(_report_failure)
        # Stop the reactor on success or failure so reactor.run() returns.
        d.addBoth(lambda _: reactor.stop())
        reactor.run()

    elif args.task == 'test_agent':
        print("Starting Test Agent...")
        test_agent = TestAgent()
        if args.query:
            # One-shot mode: answer the single query given on the command line.
            # NOTE(review): assumes process_query returns its answer directly
            # (not a Deferred) — confirm against test_agent.py.
            result = test_agent.process_query(args.query)
            print("\nAgent's response:")
            print(result)
        else:
            # No query given: hand control to run_test_session under the reactor.
            # NOTE(review): nothing here stops the reactor; presumably
            # run_test_session does — confirm.
            reactor.callWhenRunning(run_test_session)
            reactor.run()

    elif args.task in _INFERENCE_MODE_BY_TASK:
        print("Starting Inference Task...")
        selected_inference_mode = _INFERENCE_MODE_BY_TASK[args.task]

        # getattr fallbacks keep this working even if the parser does not
        # define these inference options.
        lightbulb_inf_args = [
            'lightbulb_custom.py',
            '--mode', 'inference',
            '--model_name', args.model_name,
            '--query', args.query,
            '--max_length', str(args.max_length),
            '--inference_mode', selected_inference_mode,
            '--beam_size', str(getattr(args, 'beam_size', 5)),
            '--n_tokens_predict', str(getattr(args, 'n_tokens_predict', 3)),
            '--mcts_iterations', str(getattr(args, 'mcts_iterations', 10)),
            '--mcts_exploration_constant', str(getattr(args, 'mcts_exploration_constant', 1.414)),
        ]

        # Only forward --load_model when a non-empty value was supplied.
        load_model = getattr(args, 'load_model', None)
        if load_model:
            lightbulb_inf_args += ['--load_model', load_model]

        _call_with_argv(lightbulb_inf_args, inference_main)

    else:
        print(f"Unknown task: {args.task}")
        sys.exit(1)
|
|
|
if __name__ == "__main__":
    # Script entry point: dispatch the requested task. main() returns None,
    # so SystemExit(None) exits with status 0, the same as falling off the end.
    raise SystemExit(main())
|
|
|
|