Adapters
English
code
medical
UANN / Models /MoE_model.py
dnnsdunca's picture
Create Models/MoE_model.py
8be16c0 verified
raw
history blame
1.23 kB
import torch
import torch.nn as nn
from models.gating_network import GatingNetwork
from models.vision_expert import VisionExpert
from models.audio_expert import AudioExpert
from models.sensor_expert import SensorExpert
class MoEModel(nn.Module):
    """Mixture-of-experts model over vision, audio and sensor inputs.

    Each modality is encoded once by its dedicated expert. A gating network
    scores the experts from the concatenation of all expert features. The
    expert features are mixed by those scores, and a final linear layer maps
    the mixture to action logits.

    Args:
        input_dim: Input width of the gating network. The gate receives the
            concatenated expert features, so this must equal the sum of the
            experts' output widths (``3 * feature_dim`` when every expert
            emits ``feature_dim`` features — TODO confirm against the expert
            definitions).
        num_experts: Number of gating scores requested from the gating
            network. Must equal the number of experts (``NUM_EXPERTS``).
        feature_dim: Width of each expert's output vector and input width of
            the final layer. Defaults to 128.
        num_actions: Number of output logits. Defaults to 10.

    Raises:
        ValueError: If ``num_experts`` differs from ``NUM_EXPERTS``.
    """

    # The expert set is fixed: one expert per modality (vision, audio, sensor).
    NUM_EXPERTS = 3

    def __init__(self, input_dim, num_experts, feature_dim=128, num_actions=10):
        super().__init__()
        # Fail fast: a mismatch would otherwise only surface as an einsum
        # shape error on the first forward pass.
        if num_experts != self.NUM_EXPERTS:
            raise ValueError(
                f"num_experts must be {self.NUM_EXPERTS} (one per modality), "
                f"got {num_experts}"
            )
        self.gating_network = GatingNetwork(input_dim=input_dim, num_experts=num_experts)
        self.experts = nn.ModuleList([VisionExpert(), AudioExpert(), SensorExpert()])
        self.fc_final = nn.Linear(feature_dim, num_actions)

    def forward(self, vision_input, audio_input, sensor_input):
        """Return action logits of shape ``(batch, num_actions)``.

        Args:
            vision_input: Input for the vision expert.
            audio_input: Input for the audio expert.
            sensor_input: Input for the sensor expert.

        Each expert is assumed to return a ``(batch, feature_dim)`` tensor, and
        all experts must return the same shape so they can be stacked. The
        gating network is assumed to return ``(batch, NUM_EXPERTS)`` weights
        (presumably normalized — verify in ``GatingNetwork``).
        """
        # Encode each modality exactly once, with the expert built for it.
        vision_features = self.experts[0](vision_input)
        audio_features = self.experts[1](audio_input)
        sensor_features = self.experts[2](sensor_input)
        features = (vision_features, audio_features, sensor_features)

        # The gate looks at all modalities together: (batch, sum of widths).
        combined_features = torch.cat(features, dim=1)
        gating_weights = self.gating_network(combined_features)

        # The per-modality features ARE the expert outputs. Each expert only
        # understands its own modality's input, so it must not be re-run on
        # combined_features. Shape: (batch, num_experts, feature_dim).
        expert_outputs = torch.stack(features, dim=1)
        # Gate-weighted sum over the expert axis -> (batch, feature_dim).
        final_output = torch.einsum('ij,ijk->ik', gating_weights, expert_outputs)
        return self.fc_final(final_output)