Question about why it's compatible
Hi, I am a co-author of Pyramid Flow, and I hope to address the PyTorch version restriction as soon as possible. I just saw your comment:
Change this line:

```python
self.timesteps_per_stage[i_s] = torch.from_numpy(timesteps[:-1])
```

To this:

```python
self.timesteps_per_stage[i_s] = timesteps[:-1]
```

This will allow the model to be compatible with newer versions of PyTorch and other libraries than those pinned in the requirements.
I wonder what the mechanism behind this is, and whether we should remove all `torch.from_numpy` calls like this.
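For reference, one version-agnostic option I am considering (just a sketch, not tested in the repo): `torch.as_tensor` accepts both `np.ndarray` and `torch.Tensor` inputs, so it keeps an explicit conversion while avoiding the strict `np.ndarray` requirement of `torch.from_numpy`.

```python
import numpy as np
import torch

timesteps = np.linspace(1.0, 0.0, 11)

# torch.as_tensor accepts both np.ndarray and torch.Tensor inputs,
# so this works whether np.linspace returned an array or a tensor.
stage_timesteps = torch.as_tensor(timesteps[:-1])
print(type(stage_timesteps))  # <class 'torch.Tensor'>
```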
It seems to be from here:

```python
timesteps = np.linspace(
    timestep_max, timestep_min, training_steps + 1,
)
```
`timestep_max` and `timestep_min` here are singleton tensors, and apparently `np.linspace`, when used with tensors as the endpoints, will return a tensor. I don't know which version of numpy this behavior was added in.
I have one environment with numpy 1.26.4 and it doesn't happen there.
In version 2.0.2, though, it does. For example, this:
```python
import numpy as np
import torch

a = torch.Tensor([0])
b = torch.Tensor([10])
c = np.linspace(a, b, 11)
print(type(c))
```

This prints `<class 'torch.Tensor'>`.
It seems to be specific to this use of `torch.from_numpy`, because the inputs happened to already be tensors: `torch.from_numpy` only accepts an `np.ndarray`, so it raises a `TypeError` when the result of `np.linspace` is already a tensor, and assigning the slice directly avoids that. Maybe this is a specific issue with a newer version of numpy, or really not an issue but a convenience feature to keep things as tensors that just wasn't expected to happen here.
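To make the failure mode concrete, here is a minimal reproduction sketch based on the example above (the exact error wording may differ between versions):

```python
import numpy as np
import torch

a = torch.Tensor([0])
b = torch.Tensor([10])

# Under numpy 2.0.2 this is already a torch.Tensor (see above);
# under numpy 1.26.4 it is an np.ndarray.
timesteps = np.linspace(a, b, 11)

try:
    torch.from_numpy(timesteps[:-1])  # fails when timesteps is a tensor
except TypeError as e:
    print(e)  # e.g. "expected np.ndarray (got Tensor)" on numpy 2.0.2
```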
Here is an alternative which should work for both versions:
```python
timestep_max = self.timesteps[int(timestep_ratio[0] * training_steps)].item()
timestep_min = self.timesteps[min(int(timestep_ratio[1] * training_steps), training_steps - 1)].item()
```

This will convert the singleton tensors into plain numbers, which then behave normally with `np.linspace`.
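For what it's worth, a quick sketch of why the `.item()` variant sidesteps the issue (only checked against the versions mentioned above): with plain Python floats as the endpoints, `np.linspace` returns an `np.ndarray`, so the existing `torch.from_numpy` call keeps working.

```python
import numpy as np
import torch

timestep_max = 1.0   # plain Python floats, as produced by .item()
timestep_min = 0.0
training_steps = 10

timesteps = np.linspace(timestep_max, timestep_min, training_steps + 1)
print(type(timesteps))  # <class 'numpy.ndarray'>

stage_timesteps = torch.from_numpy(timesteps[:-1])  # works as before
print(type(stage_timesteps))  # <class 'torch.Tensor'>
```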
Not sure about other incompatibilities, but the rest of the code works with torch 2.4.1+cu124 apart from this small change. I tested by just installing that plus the requirements file with `==` changed to `>=` for each requirement.
Thanks for the explanation; I will try it and see if it works in my environment.
Looks like PR #39 in the inference repo resolved it in another way :)