runthebandsup commited on
Commit
bc1fc68
·
verified ·
1 Parent(s): 23aa755

Fix checkpoint loading to extract model_state_dict properly

Browse files
Files changed (1) hide show
  1. app.py +5 -1
app.py CHANGED
@@ -18,7 +18,11 @@ try:
18
 
19
  num_classes = len(label_to_class)
20
  model = FineGrainedClassifier(num_classes=num_classes)
21
- model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
 
 
 
 
22
  model.eval()
23
 
24
  # Load text tokenizer
 
18
 
19
  num_classes = len(label_to_class)
20
  model = FineGrainedClassifier(num_classes=num_classes)
21
+
22
+ # Load checkpoint and extract model_state_dict
23
+ checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
24
+ state_dict = checkpoint.get('model_state_dict', checkpoint)
25
+ model.load_state_dict(state_dict)
26
  model.eval()
27
 
28
  # Load text tokenizer