import torch
import torch.nn as nn
from models.gating_network import GatingNetwork
from models.vision_expert import VisionExpert
from models.audio_expert import AudioExpert
from models.sensor_expert import SensorExpert

class MoEModel(nn.Module):
    def __init__(self, input_dim, num_experts):
        super().__init__()
        # input_dim is the size of the concatenated expert features fed to the gate.
        self.gating_network = GatingNetwork(input_dim=input_dim, num_experts=num_experts)
        # One expert per modality; each is assumed to output 128-dim features.
        self.experts = nn.ModuleList([VisionExpert(), AudioExpert(), SensorExpert()])
        self.fc_final = nn.Linear(128, 10)  # Assuming 10 possible actions

    def forward(self, vision_input, audio_input, sensor_input):
        # Each expert encodes its own modality into a feature vector
        # (assumed 128-dim, to match fc_final above).
        vision_features = self.experts[0](vision_input)
        audio_features = self.experts[1](audio_input)
        sensor_features = self.experts[2](sensor_input)

        # The gating network sees the concatenated expert features and
        # produces one weight per expert: shape (batch, num_experts).
        combined_features = torch.cat((vision_features, audio_features, sensor_features), dim=1)
        gating_weights = self.gating_network(combined_features)

        # Weighted mixture of the per-expert features:
        # (batch, num_experts) x (batch, num_experts, 128) -> (batch, 128).
        expert_outputs = torch.stack((vision_features, audio_features, sensor_features), dim=1)
        final_output = torch.einsum('ij,ijk->ik', gating_weights, expert_outputs)

        return self.fc_final(final_output)
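

# Usage sketch (illustrative only): the input shapes below are hypothetical and
# depend on how VisionExpert, AudioExpert, and SensorExpert are implemented;
# input_dim assumes each expert emits 128-dim features (3 * 128 = 384) and that
# GatingNetwork returns weights of shape (batch, num_experts).
if __name__ == "__main__":
    model = MoEModel(input_dim=3 * 128, num_experts=3)
    vision_input = torch.randn(4, 3, 224, 224)  # hypothetical image batch
    audio_input = torch.randn(4, 1, 16000)      # hypothetical audio waveform batch
    sensor_input = torch.randn(4, 64)           # hypothetical sensor reading batch
    action_logits = model(vision_input, audio_input, sensor_input)
    print(action_logits.shape)  # expected: torch.Size([4, 10])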