"""Dynamic PyTorch model builder driven by XML layer configs, with a small
embedding-based retrieval helper and an NZLII search client."""

import logging
import os
import stat
import xml.etree.ElementTree as ET
from collections import defaultdict
from typing import Any, Dict, List

import numpy as np
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate import Accelerator
from colorama import Fore, Style, init
from sentence_transformers import SentenceTransformer
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoModel, AutoTokenizer


init(autoreset=True)
logging.basicConfig(
    level=logging.INFO,
    format=f"{Fore.GREEN}%(asctime)s - %(levelname)s - %(message)s{Style.RESET_ALL}",
)

# Default locations for the input XML data and for generated output artifacts.
file_path = 'data/'
output_path = 'output/'

if not os.path.exists(output_path):
    try:
        os.makedirs(output_path, exist_ok=True)
        os.chmod(output_path, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)
    except PermissionError:
        logging.error(f"Permission denied: '{output_path}'")


def ensure_file(file_path):
    """Create an empty, world-writable file at file_path if it does not exist."""
    if not os.path.exists(file_path):
        with open(file_path, 'w'):
            pass
        os.chmod(file_path, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)


class MagicStateLayer(nn.Module):
    """Adds a learnable state vector to its input."""

    def __init__(self, size):
        super().__init__()
        self.state = nn.Parameter(torch.randn(size))

    def forward(self, x):
        return x + self.state


class MemoryAugmentationLayer(nn.Module):
    """Adds a learnable memory vector to its input."""

    def __init__(self, size):
        super().__init__()
        self.memory = nn.Parameter(torch.randn(size))

    def forward(self, x):
        return x + self.memory


class HybridAttentionLayer(nn.Module):
    """Self-attention over a single-token sequence built from each feature vector."""

    def __init__(self, size):
        super().__init__()
        # batch_first=True so the (batch, seq, embed) layout produced in forward()
        # is interpreted per sample instead of being read as (seq, batch, embed).
        self.attention = nn.MultiheadAttention(size, num_heads=8, batch_first=True)

    def forward(self, x):
        x = x.unsqueeze(1)  # (batch, features) -> (batch, 1, features)
        attn_output, _ = self.attention(x, x, x)
        return attn_output.squeeze(1)


class DynamicFlashAttentionLayer(nn.Module):
    """Same single-token self-attention as HybridAttentionLayer (standard
    nn.MultiheadAttention; no dedicated flash-attention kernel is used here)."""

    def __init__(self, size):
        super().__init__()
        self.attention = nn.MultiheadAttention(size, num_heads=8, batch_first=True)

    def forward(self, x):
        x = x.unsqueeze(1)
        attn_output, _ = self.attention(x, x, x)
        return attn_output.squeeze(1)


class DynamicModel(nn.Module):
    """A model whose sections and layers are described by parameter dicts
    (typically parsed from XML files, see parse_xml_file)."""

    def __init__(self, sections: Dict[str, List[Dict[str, Any]]]):
        super().__init__()
        self.sections = nn.ModuleDict({
            section_name: nn.ModuleList([self.create_layer(lp) for lp in layers])
            for section_name, layers in sections.items()
        })

    def create_layer(self, lp):
        layers = [nn.Linear(lp['input_size'], lp['output_size'])]
        if lp.get('batch_norm', True):
            layers.append(nn.BatchNorm1d(lp['output_size']))
        activation = lp.get('activation', 'relu')
        if activation == 'relu':
            layers.append(nn.ReLU(inplace=True))
        elif activation == 'tanh':
            layers.append(nn.Tanh())
        elif activation == 'sigmoid':
            layers.append(nn.Sigmoid())
        elif activation == 'leaky_relu':
            layers.append(nn.LeakyReLU(negative_slope=0.01, inplace=True))
        elif activation == 'elu':
            layers.append(nn.ELU(alpha=1.0, inplace=True))
        if dropout := lp.get('dropout', 0.1):
            layers.append(nn.Dropout(p=dropout))
        if lp.get('memory_augmentation', True):
            layers.append(MemoryAugmentationLayer(lp['output_size']))
        if lp.get('hybrid_attention', True):
            layers.append(HybridAttentionLayer(lp['output_size']))
        if lp.get('dynamic_flash_attention', True):
            layers.append(DynamicFlashAttentionLayer(lp['output_size']))
        if lp.get('magic_state', True):
            layers.append(MagicStateLayer(lp['output_size']))
        return nn.Sequential(*layers)

    def forward(self, x, section_name=None):
        if section_name:
            for layer in self.sections[section_name]:
                x = layer(x)
        else:
            for layers in self.sections.values():
                for layer in layers:
                    x = layer(x)
        return x


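# Illustrative sketch (not called anywhere in this script): how a DynamicModel
# can be built directly from an in-memory sections dict. The section name
# "example_section" and the layer settings below are hypothetical, not part of
# any shipped configuration.
def _dynamic_model_example():
    sections = {
        'example_section': [
            {'input_size': 128, 'output_size': 256, 'activation': 'relu'},
            {'input_size': 256, 'output_size': 64, 'activation': 'tanh', 'dropout': 0.0},
        ]
    }
    model = DynamicModel(sections)
    x = torch.randn(4, 128)  # batch of 4 feature vectors
    out = model(x, section_name='example_section')
    return out.shape  # expected: torch.Size([4, 64])

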
def parse_xml_file(file_path):
    """Parse layer parameter dicts from the <label> elements of an XML file."""
    tree = ET.parse(file_path)
    root = tree.getroot()
    layers = []
    for layer in root.findall('.//label'):
        lp = {
            'input_size': int(layer.get('input_size', 128)),
            'output_size': int(layer.get('output_size', 256)),
            'activation': layer.get('activation', 'relu').lower()
        }
        if lp['activation'] not in ['relu', 'tanh', 'sigmoid', 'leaky_relu', 'elu', 'none']:
            raise ValueError(f"Unsupported activation function: {lp['activation']}")
        if lp['input_size'] <= 0 or lp['output_size'] <= 0:
            raise ValueError("Layer dimensions must be positive integers")
        layers.append(lp)
    if not layers:
        layers.append({'input_size': 128, 'output_size': 256, 'activation': 'relu'})
    return layers


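# The parser above assumes XML files shaped roughly like the following
# (hypothetical example; only <label> elements and these attributes are read,
# the enclosing element name does not matter):
#
#   <config>
#       <label input_size="128" output_size="256" activation="relu"/>
#       <label input_size="256" output_size="64" activation="tanh"/>
#   </config>

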
def create_model_from_folder(folder_path):
    """Build a DynamicModel from every XML file found under folder_path.

    Layer configs are grouped into sections named after the directory that
    contains each XML file."""
    sections = defaultdict(list)
    if not os.path.exists(folder_path):
        logging.warning(f"Folder {folder_path} does not exist. Creating model with default configuration.")
        return DynamicModel({})
    xml_files_found = False
    for root, dirs, files in os.walk(folder_path):
        for file in files:
            if file.endswith('.xml'):
                xml_files_found = True
                file_path = os.path.join(root, file)
                try:
                    section_name = os.path.basename(root).replace('.', '_')
                    sections[section_name].extend(parse_xml_file(file_path))
                except Exception as e:
                    logging.error(f"Error processing {file_path}: {str(e)}")
    if not xml_files_found:
        logging.warning("No XML files found. Creating model with default configuration.")
        return DynamicModel({})
    return DynamicModel(dict(sections))


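# Assumed on-disk layout (illustrative only; any nesting under the data folder
# works, since the section name is simply the basename of the directory that
# holds each XML file):
#
#   data/
#       encoder/layers.xml   -> section "encoder"
#       decoder/layers.xml   -> section "decoder"

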
def create_embeddings_and_stores(folder_path, model_name="sentence-transformers/all-MiniLM-L6-v2"):
    """Embed the text content of every XML element under folder_path.

    Returns a list of (1, hidden_size) numpy embeddings and the parallel list
    of source texts."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    model.eval()
    doc_store = []
    embeddings_list = []
    for root, dirs, files in os.walk(folder_path):
        for file in files:
            if file.endswith('.xml'):
                file_path = os.path.join(root, file)
                try:
                    # Parse once and keep the walk variable `root` intact.
                    xml_root = ET.parse(file_path).getroot()
                    for elem in xml_root.iter():
                        if elem.text and elem.text.strip():
                            text = elem.text.strip()
                            inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
                            with torch.no_grad():
                                embeddings = model(**inputs).last_hidden_state.mean(dim=1).cpu().numpy()
                            embeddings_list.append(embeddings)
                            doc_store.append(text)
                except Exception as e:
                    logging.error(f"Error processing {file_path}: {str(e)}")
    return embeddings_list, doc_store


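# Sketch only: the SentenceTransformer import above offers a shorter route to
# closely related embeddings (its built-in pooling may differ slightly from the
# manual mean pooling used above). This helper is illustrative and is not
# called anywhere in this script.
def _embed_with_sentence_transformer(texts, model_name="sentence-transformers/all-MiniLM-L6-v2"):
    st_model = SentenceTransformer(model_name)
    # encode() returns a (len(texts), hidden_size) numpy array by default.
    return st_model.encode(texts)

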
def query_embeddings(query, embeddings_list, doc_store, model_name="sentence-transformers/all-MiniLM-L6-v2"):
    """Return the (up to 5) stored texts whose embeddings best match the query."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    model.eval()
    inputs = tokenizer(query, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        query_embedding = model(**inputs).last_hidden_state.mean(dim=1).cpu().numpy()
    # Reduce each dot product to a plain float so argsort yields integer indices.
    similarities = [float(np.dot(query_embedding.ravel(), emb.ravel())) for emb in embeddings_list]
    top_k_indices = np.argsort(similarities)[-5:][::-1]
    return [doc_store[i] for i in top_k_indices]


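# Note: query_embeddings ranks by raw dot products. A hypothetical
# cosine-similarity variant (not used above) would normalise both vectors first:
def _cosine_similarity(a, b):
    a, b = np.ravel(a), np.ravel(b)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-12))

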
def fetch_nzlii_data(query):
    """Search the NZLII sinosrch endpoint and return a list of result dicts."""
    base_url = "https://nzlii.org/cgi-bin/sinosrch.cgi"
    params = {"method": "auto", "query": query, "meta": "/nz", "results": "50", "format": "json"}
    try:
        response = requests.get(base_url, params=params, headers={"Accept": "application/json"}, timeout=10)
        response.raise_for_status()
        return [{"title": r.get("title", ""), "citation": r.get("citation", ""),
                 "date": r.get("date", ""), "court": r.get("court", ""),
                 "summary": r.get("summary", ""), "url": r.get("url", "")}
                for r in response.json().get("results", [])]
    except requests.exceptions.RequestException as e:
        logging.error(f"Failed to fetch data from NZLII API: {str(e)}")
        return []


class CustomModel(nn.Module):
    """Text classifier head on top of a pretrained transformer encoder."""

    def __init__(self, model_name="distilbert-base-uncased"):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.encoder = AutoModel.from_pretrained(model_name)
        self.hidden_size = self.encoder.config.hidden_size
        self.dropout = nn.Dropout(p=0.3)
        self.fc1 = nn.Linear(self.hidden_size, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 32)
        self.fc4 = nn.Linear(32, 16)
        # Recurrent "memory" head; defined here but not used in forward().
        self.memory = nn.LSTM(self.hidden_size, 64, bidirectional=True, batch_first=True)
        self.memory_fc1 = nn.Linear(64 * 2, 32)
        self.memory_fc2 = nn.Linear(32, 16)

    def forward(self, data):
        tokens = self.tokenizer(data, return_tensors="pt", truncation=True, padding=True)
        outputs = self.encoder(**tokens)
        x = outputs.last_hidden_state.mean(dim=1)
        x = self.dropout(F.relu(self.fc1(x)))
        x = self.dropout(F.relu(self.fc2(x)))
        x = self.dropout(F.relu(self.fc3(x)))
        x = self.fc4(x)
        return x

    def training_step(self, data, labels, optimizer, criterion):
        optimizer.zero_grad()
        outputs = self.forward(data)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        return loss.item()

    def validation_step(self, data, labels, criterion):
        with torch.no_grad():
            outputs = self.forward(data)
            loss = criterion(outputs, labels)
        return loss.item()

    def predict(self, data):
        self.eval()
        with torch.no_grad():
            return self.forward(data)


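# Illustrative sketch: how CustomModel's training_step / validation_step (which
# main() below does not exercise) are intended to be driven. The texts and
# labels here are hypothetical placeholders; this helper is never called.
def _custom_model_training_sketch():
    model = CustomModel()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss()
    train_texts = ["first example sentence", "second example sentence"]
    train_labels = torch.tensor([0, 1])  # class indices in [0, 16)
    val_texts = ["held-out sentence"]
    val_labels = torch.tensor([1])
    for epoch in range(2):
        model.train()
        train_loss = model.training_step(train_texts, train_labels, optimizer, criterion)
        model.eval()
        val_loss = model.validation_step(val_texts, val_labels, criterion)
        logging.info(f"Sketch epoch {epoch + 1}: train={train_loss:.4f}, val={val_loss:.4f}")

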
def main():
    folder_path = 'data'
    model = create_model_from_folder(folder_path)
    logging.info(f"Created dynamic PyTorch model with sections: {list(model.sections.keys())}")
    embeddings_list, doc_store = create_embeddings_and_stores(folder_path)

    if not any(p.requires_grad for p in model.parameters()):
        logging.warning("Model has no trainable parameters (no layer configs found); skipping training.")
    else:
        accelerator = Accelerator()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        criterion = nn.CrossEntropyLoss()
        num_epochs = 10
        dataset = TensorDataset(torch.randn(100, 128), torch.randint(0, 2, (100,)))
        dataloader = DataLoader(dataset, batch_size=16, shuffle=True)
        model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
        for epoch in range(num_epochs):
            model.train()
            total_loss = 0
            for batch_data, batch_labels in dataloader:
                optimizer.zero_grad()
                outputs = model(batch_data)
                loss = criterion(outputs, batch_labels)
                accelerator.backward(loss)
                optimizer.step()
                total_loss += loss.item()
            avg_loss = total_loss / len(dataloader)
            logging.info(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")

    query = "example query text"
    results = query_embeddings(query, embeddings_list, doc_store)
    logging.info(f"Query results: {results}")
    nzlii_data = fetch_nzlii_data(query)
    logging.info(f"NZLII search results: {nzlii_data}")


if __name__ == "__main__":
    main()