Spaces:
Sleeping
Sleeping
Fix checkpoint loading to extract model_state_dict properly
Browse files
app.py
CHANGED
|
@@ -18,7 +18,11 @@ try:
|
|
| 18 |
|
| 19 |
num_classes = len(label_to_class)
|
| 20 |
model = FineGrainedClassifier(num_classes=num_classes)
|
| 21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
model.eval()
|
| 23 |
|
| 24 |
# Load text tokenizer
|
|
|
|
| 18 |
|
| 19 |
num_classes = len(label_to_class)
|
| 20 |
model = FineGrainedClassifier(num_classes=num_classes)
|
| 21 |
+
|
| 22 |
+
# Load checkpoint and extract model_state_dict
|
| 23 |
+
checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
|
| 24 |
+
state_dict = checkpoint.get('model_state_dict', checkpoint)
|
| 25 |
+
model.load_state_dict(state_dict)
|
| 26 |
model.eval()
|
| 27 |
|
| 28 |
# Load text tokenizer
|