miittnnss commited on
Commit
0deb359
1 Parent(s): d1a03bc

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +21 -1
pipeline.py CHANGED
@@ -1,9 +1,29 @@
1
  import torch
2
  from PIL import Image
3
  from torchvision import transforms
 
4
 
5
  class Generator(nn.Module):
6
- # The Generator class remains the same as provided
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  class PreTrainedPipeline():
9
  def __init__(self, path=""):
 
1
  import torch
2
  from PIL import Image
3
  from torchvision import transforms
4
+ import torch.nn as nn
5
 
6
  class Generator(nn.Module):
7
+ def __init__(self, input_size, output_channels):
8
+ super(Generator, self).__init__()
9
+
10
+ # Define the architecture of the generator
11
+ self.model = nn.Sequential(
12
+ nn.Linear(input_size, 128), # Input layer
13
+ nn.LeakyReLU(0.2), # Activation function
14
+ nn.Linear(128, 256), # Hidden layer
15
+ nn.BatchNorm1d(256), # Batch normalization
16
+ nn.LeakyReLU(0.2), # Activation function
17
+ nn.Linear(256, 512), # Hidden layer
18
+ nn.BatchNorm1d(512), # Batch normalization
19
+ nn.LeakyReLU(0.2), # Activation function
20
+ nn.Linear(512, output_channels), # Output layer
21
+ nn.Tanh() # Tanh activation for output
22
+ )
23
+
24
+ def forward(self, x):
25
+ # Forward pass through the generator network
26
+ return self.model(x)
27
 
28
  class PreTrainedPipeline():
29
  def __init__(self, path=""):