-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.py
More file actions
38 lines (32 loc) · 1.31 KB
/
test.py
File metadata and controls
38 lines (32 loc) · 1.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
import torch.nn as nn # ← add this
import numpy as np
import pickle
from keras.preprocessing.sequence import pad_sequences
# ---- load artifacts ----
# NOTE(security): pickle.load and torch.load unpickle the file contents, which
# can run arbitrary code. Only point these paths at artifacts you trust.
# The tokenizer file is opened in a `with` block so the handle is closed as
# soon as it has been read instead of being left to the garbage collector.
with open('models/tokenizer.pickle', 'rb') as tokenizer_file:
    tok = pickle.load(tokenizer_file)
# Pretrained embedding matrix; its 2-D shape is unpacked by GenreLSTM below.
emb = np.load('models/embedding_matrix.npy')
# Checkpoint passed to model.load_state_dict(); mapped to CPU so no GPU is needed.
state = torch.load('models/genre_classifier.pth', map_location='cpu')
# ---- model definition ----
class GenreLSTM(nn.Module):
    """Bidirectional LSTM classifier on top of a frozen, pretrained embedding.

    The submodule names (``embedding``, ``lstm``, ``fc``) determine the
    state-dict keys, so they must stay as they are for the saved checkpoint
    to load.
    """

    def __init__(self, emb, hidden_size=128, num_classes=10):
        """Build the network.

        Args:
            emb: 2-D array-like of shape (vocab_size, embedding_dim) holding
                the pretrained embedding weights. It is copied and converted
                to float32.
            hidden_size: Hidden units per LSTM direction. The default (128)
                reproduces the original fixed architecture.
            num_classes: Number of output logits. The default (10) reproduces
                the original fixed architecture.

        Raises:
            ValueError: If ``emb.shape`` does not unpack into two dimensions.
        """
        super().__init__()
        _, embed_dim = emb.shape
        # torch.tensor copies, so the module never aliases the caller's array.
        weights = torch.tensor(emb, dtype=torch.float32)
        # freeze=True marks the weight as requires_grad=False; the parameter
        # is still registered, so it appears in the state dict as
        # "embedding.weight" (name matches checkpoint).
        self.embedding = nn.Embedding.from_pretrained(weights, freeze=True)
        self.lstm = nn.LSTM(
            embed_dim, hidden_size, batch_first=True, bidirectional=True
        )
        # A bidirectional LSTM concatenates both directions, hence 2 * hidden.
        self.fc = nn.Linear(2 * hidden_size, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return one row of logits per input sequence.

        Args:
            x: Integer token indices, laid out as (batch, seq_len) since the
                LSTM is built with ``batch_first=True``.

        Returns:
            Tensor of shape (batch, num_classes) with raw, unactivated scores.
        """
        out, _ = self.lstm(self.embedding(x))
        # Average over the sequence axis. No mask is applied, so every
        # timestep (including any padding positions) contributes equally.
        return self.fc(out.mean(1))
# ---- load weights & test ----
# Smoke test: build the model, load the checkpoint, and push one example
# sentence through it to confirm the forward pass runs end to end.
model = GenreLSTM(emb)
model.load_state_dict(state)  # should load cleanly now
model.eval()  # switch to inference mode
demo = "A detective must solve a brutal murder in a futuristic city."
seq = tok.texts_to_sequences([demo])
# NOTE(review): maxlen=200 presumably matches the length used at training
# time - confirm against the training script.
pad = pad_sequences(seq, maxlen=200)
with torch.no_grad():
    # pad_sequences yields int32 by default; build the index tensor as int64
    # explicitly, which every nn.Embedding version accepts.
    print(torch.sigmoid(model(torch.tensor(pad, dtype=torch.long))).shape)
print("✅ model runs")