import os
import matplotlib.pyplot as plt
import numpy as np
import requests
import scipy.io
import torch as tc
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, random_split
def download_data(overwrite=False):
    """Download the four OSF files listed below into a local ``data/`` directory.

    Args:
        overwrite: if False (default), files that already exist in ``data/``
            are skipped; if True they are downloaded again.

    Raises:
        requests.HTTPError: if any download responds with an error status.
    """
    files = [
        ("regularized_model_final.pth", "https://osf.io/kc7sb/download"),
        ("unregularized_model_final.pth", "https://osf.io/9vsy5/download"),
        ("condsForSimJ2moMuscles.mat", "https://osf.io/wak7e/download"),
        ("m1_reaching_data.mat", "https://osf.io/p2x4n/download"),
    ]
    # exist_ok: a second call must not crash just because data/ is already there.
    os.makedirs("data", exist_ok=True)
    for name, url in files:
        dest = os.path.join("data", name)
        if os.path.exists(dest) and not overwrite:
            continue
        resp = requests.get(url, timeout=60)
        # Fail loudly rather than saving an HTML error page under a .pth/.mat name.
        resp.raise_for_status()
        with open(dest, "wb") as f:
            f.write(resp.content)
def tensor(x):
    """Convert array-like *x* (including nested lists of arrays) to a float32 tensor.

    The input is first materialized with ``np.array`` so ragged-free nested
    sequences of NumPy arrays are stacked into one array before conversion.
    """
    stacked = np.array(x)
    return tc.tensor(stacked, dtype=tc.float32)
def normalize(x: tc.Tensor):
    """Min-max rescale *x* to [0, 1], then z-score it using global statistics.

    Both steps reduce over every element of *x* (no per-axis statistics);
    ``Tensor.std`` is used with its default (Bessel-corrected) estimator.
    Note: z-scoring is invariant to the preceding positive affine rescale, so
    the min-max step does not change the result beyond rounding — but it does
    divide by zero when *x* is constant.
    """
    lo = x.min()
    span = x.max() - lo
    scaled = (x - lo) / span
    centered = scaled - scaled.mean()
    return centered / scaled.std()
def load_data(path):
    """Load ``condsForSim`` from the .mat file at *path*; return (inputs, muscle targets).

    Each of the struct fields ``goEnvelope``, ``plan`` and ``muscle`` is
    gathered over the 2-D struct array into one tensor indexed
    ``[row, col, ...]`` (the ``dim=3`` concatenation below implies each field
    is itself 2-D). ``plan`` is normalized, concatenated with ``goEnvelope``
    (plan channels first, go channel last), and the result is normalized again.
    Finally a fixed column (3), time window (46:296) and muscle channels
    ([3, 4]) are selected — NOTE(review): the meaning of these indices is not
    visible here; confirm against the dataset description.
    """
    sim: np.ndarray = scipy.io.loadmat(path)["condsForSim"]
    n_rows, n_cols = sim.shape

    def field(name):
        # Stack one struct field across the whole struct array.
        return tensor([[sim[r, q][name] for q in range(n_cols)] for r in range(n_rows)])

    go = field("goEnvelope")
    plan = normalize(field("plan"))
    muscle = field("muscle")
    inputs = normalize(tc.cat([plan, go], dim=3))

    col, window, channels = 3, slice(46, 296), [3, 4]
    return inputs[:, col, window, :], muscle[:, col, window, channels]
def _as_numpy(v):
    """Return *v* as a NumPy array, detaching and moving torch tensors to CPU first."""
    if isinstance(v, tc.Tensor):
        return v.detach().cpu().numpy()
    return np.asarray(v)


def plot_data(
    inp: tc.Tensor,
    mus: tc.Tensor,
    hid: "tc.Tensor | None" = None,
    fname="W1D1",
    conds=(0, 1, 2),
):
    """Plot input, output and (optionally) hidden traces for a few conditions.

    Arrays are indexed ``[condition, time, channel]``. One column is drawn per
    condition, with rows for inputs, outputs and (if given) hidden traces.
    The figure is saved to *fname* and closed.

    Args:
        inp: inputs; the last channel is drawn 5x larger (in this file it is
            the go signal appended after the plan channels).
        mus: outputs (targets or predictions).
        hid: optional hidden-state traces.
        fname: path passed to ``plt.savefig``.
        conds: condition indices to plot; indices not present in *inp*
            (e.g. when plotting a small batch) are skipped.

    Tensors that require grad are accepted (they are detached first).
    """
    inp = _as_numpy(inp)
    mus = _as_numpy(mus)
    hid = None if hid is None else _as_numpy(hid)

    n_cond = len(inp)
    cond = [c for c in conds if -n_cond <= c < n_cond]
    R = 2 if hid is None else 3
    C = max(len(cond), 1)
    N = inp.shape[2]

    plt.figure(figsize=(4 * C, 3 * R))
    for j, c in enumerate(cond):
        plt.subplot(R, C, j + 1)
        plt.title(f"input ({N - 1} plan + 1 go)")
        for i in range(N):
            # Enlarge the last channel so it stands out against the others.
            scale = 5 if i == N - 1 else 1
            plt.plot(scale * inp[c, :, i])
        plt.subplot(R, C, j + 1 + C)
        plt.title("output (muscle)")
        for i in range(mus.shape[2]):
            plt.plot(mus[c, :, i])
        if hid is not None:
            plt.subplot(R, C, j + 1 + 2 * C)
            plt.title("hidden")
            for i in range(hid.shape[2]):
                plt.plot(hid[c, :, i])
    plt.tight_layout()
    plt.savefig(fname)
    plt.close()
# =================================
def make_loaders(x, y, B=20, n_train=20, seed=None):
    """Randomly split paired samples into a train loader and a validation loader.

    Args:
        x, y: tensors indexed by sample along dim 0, with equal lengths.
        B: batch size for both loaders.
        n_train: number of samples in the training split; all remaining
            samples go to validation.
        seed: optional seed for the split, making it reproducible. ``None``
            uses torch's global RNG.

    Returns:
        ``(train_loader, val_loader)``. The train loader shuffles each epoch.
        The validation loader keeps a fixed order.

    Raises:
        ValueError: if *x* and *y* differ in length, or *n_train* does not
            leave at least one sample in each split.
    """
    from torch.utils.data import TensorDataset

    n = len(x)
    if len(y) != n:
        raise ValueError(f"x and y must have the same length, got {n} and {len(y)}")
    if not 0 < n_train < n:
        raise ValueError(f"n_train must be in [1, {n - 1}], got {n_train}")

    # Split over the actual sample count instead of a hard-coded 27.
    split_kwargs = {} if seed is None else {"generator": tc.Generator().manual_seed(seed)}
    i1, i2 = random_split(range(n), [n_train, n - n_train], **split_kwargs)
    idx1, idx2 = list(i1.indices), list(i2.indices)

    d1 = TensorDataset(x[idx1], y[idx1])
    d2 = TensorDataset(x[idx2], y[idx2])
    l1 = DataLoader(d1, B, shuffle=True)
    l2 = DataLoader(d2, B, shuffle=False)
    return l1, l2
# =================================
def rectified_tanh(x):
    """Apply ``tanh`` to the strictly positive entries of *x* and zero all others."""
    positive = x > 0
    return tc.where(positive, tc.tanh(x), tc.zeros_like(x))
def mm(A, B: tc.Tensor):
    """Return ``(A @ B^T)^T``: apply matrix *A* to every row of *B*.

    For ``A`` of shape (m, k) and a 2-D ``B`` of shape (n, k) the result has
    shape (n, m).
    """
    product = A @ B.transpose(0, 1)
    return product.transpose(0, 1)
def copy(x: tc.Tensor):
    """Return a new tensor with *x*'s values that shares neither storage nor autograd history."""
    detached = x.detach()
    return detached.clone()
class MyRNN(nn.Module):
    """Rate RNN with leaky dynamics, integrated with one Euler step per input frame.

    Per step the state is updated as
        h <- h + (1 / tt) * (-h + B x + J act(h) + b)
    and the readout is ``out(act(h))`` on the updated state, where ``act`` is
    ``rectified_tanh``.

    Args:
        I: input size.
        H: hidden size.
        O: output size.
        g: scale of the recurrent weights ``J`` (std ``g / sqrt(H)``).
        h: scale of the input weights ``B`` (std ``h / sqrt(I)``).
        tt: the state moves a fraction ``1 / tt`` toward its drive each step.
    """

    def __init__(s, I, H, O, g, h, tt=5):
        super().__init__()
        s.H, s.tt = H, tt
        # Creation order (out, J, B) is kept so seeded initializations are unchanged.
        s.out = nn.Linear(H, O)
        s.J = nn.Parameter(tc.randn(H, H) * (g / np.sqrt(H)))
        s.B = nn.Parameter(tc.randn(H, I) * (h / np.sqrt(I)))
        s.b = nn.Parameter(tc.zeros(H))
        s.act = rectified_tanh

    def init_h(s, B):
        """Return an all-zero hidden state of shape (B, H) on the parameters' dtype/device."""
        return s.J.new_zeros(B, s.H)

    def forward(s, x: tc.Tensor, h):
        """Advance one Euler step.

        Args:
            x: input at this step, shape (batch, I) as passed by ``run``.
            h: current hidden state, shape (batch, H). It is not modified.

        Returns:
            ``(output, new_h, fr_reg)``. Here ``output = out(act(new_h))`` and
            ``fr_reg`` is the sum of squared firing rates over batch and units.
        """
        # Out-of-place update: an in-place ``h += ...`` would overwrite the
        # caller's tensor and hand back the very same object.
        h = h + (1 / s.tt) * (-h + mm(s.B, x) + mm(s.J, s.act(h)) + s.b)
        fr = s.act(h)
        fr_reg = fr.pow(2).sum()
        return s.out(fr), h, fr_reg

    def run(s, x: tc.Tensor):
        """Unroll the network over the time axis of *x*.

        Args:
            x: inputs of shape (batch, T, I).

        Returns:
            outs: (batch, T, O) readouts (part of the autograd graph).
            hs: (batch, T, H) hidden states before the nonlinearity, detached.
            fr_reg: time-averaged sum of squared firing rates.
            l2_reg: sum of squares of every parameter (weights and biases).

        Raises:
            ValueError: if *x* has no time steps.
        """
        n_batch, n_steps = x.shape[:2]
        if n_steps == 0:
            raise ValueError("x must contain at least one time step")
        h = s.init_h(n_batch)
        outs, hs = [], []
        fr_reg = 0
        for t in range(n_steps):
            o, h, fr_reg_t = s(x[:, t], h)
            fr_reg = fr_reg + fr_reg_t / n_steps
            outs.append(o)
            hs.append(copy(h))
        l2_reg = sum(p.pow(2).sum() for p in s.parameters())
        return tc.stack(outs, dim=1), tc.stack(hs, dim=1), fr_reg, l2_reg
# ===============================
def train(
    rnn: MyRNN,
    l1: DataLoader,
    l2: DataLoader,
    lr=1e-3,
    epochs=10000,
    l2w=0,
    frw=0,
    max_norm=None,
    plot_every=100,
):
    """Train *rnn* with Adam, taking one full-batch step per epoch.

    Each epoch averages the regularized loss over all batches of *l1*,
    backpropagates once and steps the optimizer. It then records the
    unregularized MSE on *l2*. Every *plot_every* epochs, the log10 loss
    curves are saved to ``train``. A validation batch is also plotted
    alongside the model's predictions for that same batch.

    Args:
        rnn: model whose ``run(x)`` returns ``(outputs, hidden, fr_reg, l2_reg)``.
        l1: training loader yielding ``(x, y)`` batches.
        l2: validation loader yielding ``(x, y)`` batches.
        lr: Adam learning rate.
        epochs: number of optimizer steps.
        l2w: weight of the ``l2_reg`` term.
        frw: weight of the ``fr_reg`` term.
        max_norm: if not None, clip the global gradient norm to this value
            before each step. This may help if gradients blow up, e.g. with a
            large recurrent gain. The default disables clipping.
        plot_every: plotting/logging period in epochs; 0 or None disables it.

    Returns:
        A list of ``[train_loss, val_loss]`` pairs, one per epoch.

    Raises:
        ValueError: if either loader yields no batches.
    """
    if len(l1) == 0 or len(l2) == 0:
        raise ValueError("train and validation loaders must be non-empty")
    opt = tc.optim.Adam(rnn.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    losses = []
    for e in range(epochs):
        rnn.train()
        opt.zero_grad()
        train_loss = 0
        for x, y in l1:
            yp, _, fr_reg, l2_reg = rnn.run(x)
            loss = loss_fn(yp, y) + l2w * l2_reg + frw * fr_reg
            train_loss = train_loss + loss / len(l1)
        train_loss.backward()
        if max_norm is not None:
            nn.utils.clip_grad_norm_(rnn.parameters(), max_norm)
        opt.step()

        rnn.eval()
        val_loss = 0.0
        with tc.no_grad():
            for xv, yv in l2:
                ypv, hsv, _, _ = rnn.run(xv)
                val_loss += loss_fn(ypv, yv).item() / len(l2)
        losses.append([train_loss.item(), val_loss])

        if plot_every and e and e % plot_every == 0:
            ys = np.array(losses).T
            plt.plot(np.log10(ys[0]), label="log10(train loss)")
            plt.plot(np.log10(ys[1]), label="log10(val loss)")
            plt.legend()
            plt.savefig("train")
            plt.close()
            print(f"epoch {e}: train {ys[0, -1]:.4g}  val {ys[1, -1]:.4g}")
            # Plot validation inputs together with the predictions/hidden
            # states computed on those same inputs, so the panels correspond.
            plot_data(xv.numpy(), ypv.numpy(), hsv.numpy())
    return losses
if __name__ == "__main__":
    SEED = None  # set to an int (e.g. 0) for reproducible init, split and shuffling
    PLOT_RAW = False  # plot the raw inputs/targets before training
    DEBUG_PLOT = False  # plot one untrained forward pass before training

    if SEED is not None:
        tc.manual_seed(SEED)
        np.random.seed(SEED)

    path = "data/condsForSimJ2moMuscles.mat"
    if not os.path.exists(path):
        download_data()
    inp, mus = load_data(path)
    print(inp.shape, mus.shape)
    if PLOT_RAW:
        plot_data(inp, mus)

    l1, l2 = make_loaders(inp, mus)
    # Input/output sizes come from the data rather than magic numbers.
    rnn = MyRNN(I=inp.shape[-1], H=10, O=mus.shape[-1], g=4, h=1)

    if DEBUG_PLOT:
        x, y = next(iter(l1))
        # no_grad: the outputs are only plotted, so they must not require grad.
        with tc.no_grad():
            yp, hs, fr_reg, l2_reg = rnn.run(x)
        print(hs.shape)
        plot_data(inp=x, mus=yp, hid=hs)

    train(rnn, l1, l2)
# https://neuroai.neuromatch.io/tutorials/W1D1_Generalization/student/W1D1_Tutorial2.html#coding-exercise-2-2-evaluate-function
# Could you provide the training code for https://github.com/neuromatch/NeuroAI_Course/blob/main/tutorials/W1D1_Generalization/instructor/W1D1_Tutorial2.ipynb ?
# The code above is my attempt; something is not quite right.