Commit
•
7ab54e2
1
Parent(s):
48bf805
Update README for PyTorchModelHubMixin (#4)
Browse files- Update README for PyTorchModelHubMixin (f42c49a9cbca3ec6b51173f78779a846a28dca94)
Co-authored-by: Gerald Woo <gorold@users.noreply.huggingface.co>
README.md
CHANGED
@@ -45,13 +45,13 @@ A simple example to get started:
|
|
45 |
|
46 |
```python
|
47 |
import torch
|
|
|
48 |
import pandas as pd
|
49 |
from gluonts.dataset.pandas import PandasDataset
|
50 |
from gluonts.dataset.split import split
|
51 |
-
from huggingface_hub import hf_hub_download
|
52 |
|
53 |
from uni2ts.eval_util.plot import plot_single
|
54 |
-
from uni2ts.model.moirai import MoiraiForecast
|
55 |
|
56 |
|
57 |
SIZE = "small" # model size: choose from {'small', 'base', 'large'}
|
@@ -85,9 +85,7 @@ test_data = test_template.generate_instances(
|
|
85 |
|
86 |
# Prepare pre-trained model by downloading model weights from huggingface hub
|
87 |
model = MoiraiForecast.load_from_checkpoint(
|
88 |
-
|
89 |
-
repo_id=f"Salesforce/moirai-R-{SIZE}", filename="model.ckpt"
|
90 |
-
),
|
91 |
prediction_length=PDT,
|
92 |
context_length=CTX,
|
93 |
patch_size=PSZ,
|
@@ -95,7 +93,6 @@ model = MoiraiForecast.load_from_checkpoint(
|
|
95 |
target_dim=1,
|
96 |
feat_dynamic_real_dim=ds.num_feat_dynamic_real,
|
97 |
past_feat_dynamic_real_dim=ds.num_past_feat_dynamic_real,
|
98 |
-
map_location="cuda:0" if torch.cuda.is_available() else "cpu",
|
99 |
)
|
100 |
|
101 |
predictor = model.create_predictor(batch_size=BSZ)
|
@@ -117,6 +114,7 @@ plot_single(
|
|
117 |
name="pred",
|
118 |
show_label=True,
|
119 |
)
|
|
|
120 |
```
|
121 |
|
122 |
## The Moirai Family
|
|
|
45 |
|
46 |
```python
|
47 |
import torch
|
48 |
+
import matplotlib.pyplot as plt
|
49 |
import pandas as pd
|
50 |
from gluonts.dataset.pandas import PandasDataset
|
51 |
from gluonts.dataset.split import split
|
|
|
52 |
|
53 |
from uni2ts.eval_util.plot import plot_single
|
54 |
+
from uni2ts.model.moirai import MoiraiForecast, MoiraiModule
|
55 |
|
56 |
|
57 |
SIZE = "small" # model size: choose from {'small', 'base', 'large'}
|
|
|
85 |
|
86 |
# Prepare pre-trained model by downloading model weights from huggingface hub
|
87 |
model = MoiraiForecast.load_from_checkpoint(
|
88 |
+
module=MoiraiModule.from_pretrained(f"Salesforce/moirai-1.0-R-{SIZE}"),
|
|
|
|
|
89 |
prediction_length=PDT,
|
90 |
context_length=CTX,
|
91 |
patch_size=PSZ,
|
|
|
93 |
target_dim=1,
|
94 |
feat_dynamic_real_dim=ds.num_feat_dynamic_real,
|
95 |
past_feat_dynamic_real_dim=ds.num_past_feat_dynamic_real,
|
|
|
96 |
)
|
97 |
|
98 |
predictor = model.create_predictor(batch_size=BSZ)
|
|
|
114 |
name="pred",
|
115 |
show_label=True,
|
116 |
)
|
117 |
+
plt.show()
|
118 |
```
|
119 |
|
120 |
## The Moirai Family
|