diff --git a/nbs/022_tslearner.ipynb b/nbs/022_tslearner.ipynb
index 1d3af33c..f0520bd1 100644
--- a/nbs/022_tslearner.ipynb
+++ b/nbs/022_tslearner.ipynb
@@ -265,10 +265,10 @@
"
\n",
" \n",
" 0 | \n",
- " 1.523830 | \n",
- " 0.266667 | \n",
- " 1.407878 | \n",
- " 0.300000 | \n",
+ " 1.464314 | \n",
+ " 0.233333 | \n",
+ " 1.400173 | \n",
+ " 0.166667 | \n",
" 00:00 | \n",
"
\n",
" \n",
@@ -339,8 +339,8 @@
" \n",
" \n",
" 0 | \n",
- " 1.563760 | \n",
- " 0.166667 | \n",
+ " 1.438072 | \n",
+ " 0.200000 | \n",
" 00:00 | \n",
"
\n",
" \n",
@@ -600,10 +600,10 @@
" \n",
" \n",
" 0 | \n",
- " 209.704529 | \n",
- " 13.806342 | \n",
- " 207.336456 | \n",
- " 13.982669 | \n",
+ " 221.817291 | \n",
+ " 14.270400 | \n",
+ " 209.151230 | \n",
+ " 14.046944 | \n",
" 00:01 | \n",
"
\n",
" \n",
@@ -860,10 +860,10 @@
" \n",
" \n",
" 0 | \n",
- " 4226.109375 | \n",
- " 49.230492 | \n",
- " 8007.046387 | \n",
- " 74.881180 | \n",
+ " 4114.624023 | \n",
+ " 48.891418 | \n",
+ " 7991.095703 | \n",
+ " 74.791130 | \n",
" 00:00 | \n",
"
\n",
" \n",
@@ -932,9 +932,9 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "/Users/nacho/notebooks/tsai/nbs/022_tslearner.ipynb saved at 2024-02-11 00:40:14\n",
+ "/Users/nacho/notebooks/tsai/nbs/022_tslearner.ipynb saved at 2024-02-11 10:55:07\n",
"Correct notebook to script conversion! 😃\n",
- "Sunday 11/02/24 00:40:17 CET\n"
+ "Sunday 11/02/24 10:55:10 CET\n"
]
},
{
diff --git a/nbs/076_models.MultiRocketPlus.ipynb b/nbs/076_models.MultiRocketPlus.ipynb
index 895dc26b..fe07d6d5 100644
--- a/nbs/076_models.MultiRocketPlus.ipynb
+++ b/nbs/076_models.MultiRocketPlus.ipynb
@@ -69,24 +69,22 @@
"source": [
"#| export\n",
"def _LPVV(o, dim=2):\n",
- " \"Longest stretch of positive values (-1, 1)\"\n",
+ " \"Longest stretch of positive values along a dimension(-1, 1)\"\n",
"\n",
" seq_len = o.shape[dim]\n",
- "\n",
- " # Convert tensor to binary format (1 for positive values, 0 for non-positive values)\n",
" binary_tensor = (o > 0).float()\n",
"\n",
- " # Find the changes in the binary tensor\n",
" diff = torch.cat([torch.ones_like(binary_tensor.narrow(dim, 0, 1)),\n",
- " binary_tensor.narrow(dim, 1, binary_tensor.shape[dim]-1) - binary_tensor.narrow(dim, 0, binary_tensor.shape[dim]-1)], dim=dim)\n",
+ " binary_tensor.narrow(dim, 1, seq_len-1) - binary_tensor.narrow(dim, 0, seq_len-1)], dim=dim)\n",
"\n",
- " # Create groups of positive values\n",
" groups = (diff > 0).cumsum(dim)\n",
"\n",
- " # Count the number of values in each group\n",
- " counts = torch.zeros_like(binary_tensor).scatter_add_(dim, groups * binary_tensor.long(), binary_tensor)\n",
+ " # Ensure groups are within valid index bounds\n",
+ " groups = groups * binary_tensor.long()\n",
+ " valid_groups = groups.where(groups < binary_tensor.size(dim), torch.tensor(0, device=groups.device))\n",
+ "\n",
+ " counts = torch.zeros_like(binary_tensor).scatter_add_(dim, valid_groups, binary_tensor)\n",
"\n",
- " # The longest stretch of positive values is the maximum count\n",
" longest_stretch = counts.max(dim)[0]\n",
"\n",
" return torch.nan_to_num(2 * (longest_stretch / seq_len) - 1)\n",
@@ -118,6 +116,15 @@
" return (o_pos).float().mean(dim) * 2 - 1"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from tsai.imports import default_device"
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
@@ -127,54 +134,54 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "tensor([[[[-0.0924, 0.0842, 0.5685, 0.3900],\n",
- " [ 0.2364, 0.3018, -0.0449, 0.2081],\n",
- " [ 0.6782, 0.1842, 0.6873, -0.0590],\n",
- " [ 0.1263, 0.2636, 0.3605, -0.0281],\n",
- " [ 0.5618, 0.3535, 0.5403, -0.1791]],\n",
+ "tensor([[[[ 0.5644, -0.0509, -0.0390, 0.4091],\n",
+ " [ 0.0517, -0.1471, 0.6458, 0.5593],\n",
+ " [ 0.4516, -0.0821, 0.1271, 0.0592],\n",
+ " [ 0.4151, 0.4376, 0.0763, 0.3780],\n",
+ " [ 0.2653, -0.1817, 0.0156, 0.4993]],\n",
"\n",
- " [[ 0.2201, 0.1868, 0.1791, -0.1343],\n",
- " [ 0.3556, -0.1194, -0.2201, 0.4859],\n",
- " [ 0.1115, 0.6232, 0.4436, 0.3880],\n",
- " [ 0.6350, 0.1362, 0.5869, -0.1968],\n",
- " [ 0.0876, 0.4583, 0.0266, 0.3174]],\n",
+ " [[-0.0779, 0.0858, 0.1982, 0.3224],\n",
+ " [ 0.1130, 0.0714, -0.1779, 0.5360],\n",
+ " [-0.1848, -0.2270, -0.0925, -0.1217],\n",
+ " [ 0.2820, -0.0205, -0.2777, 0.3755],\n",
+ " [-0.2490, 0.2613, 0.4237, 0.4534]],\n",
"\n",
- " [[-0.1895, 0.1921, 0.2437, -0.1854],\n",
- " [-0.1534, -0.2986, 0.2977, 0.3019],\n",
- " [ 0.4613, 0.4243, 0.0115, 0.2684],\n",
- " [-0.0923, 0.2066, 0.4980, 0.6450],\n",
- " [-0.0348, -0.0297, 0.5451, 0.1900]]],\n",
+ " [[-0.0162, 0.6368, 0.0016, 0.1467],\n",
+ " [ 0.6035, -0.1365, 0.6930, 0.6943],\n",
+ " [ 0.2790, 0.3818, -0.0731, 0.0167],\n",
+ " [ 0.6442, 0.3443, 0.4829, -0.0944],\n",
+ " [ 0.2932, 0.6952, 0.5541, 0.5946]]],\n",
"\n",
"\n",
- " [[[ 0.0524, 0.3093, -0.1079, 0.6815],\n",
- " [-0.0642, -0.1675, -0.0548, -0.2654],\n",
- " [ 0.3172, 0.2939, -0.2412, -0.0502],\n",
- " [ 0.1145, -0.0048, 0.0118, 0.1329],\n",
- " [ 0.1715, 0.0915, -0.0179, 0.1825]],\n",
+ " [[[ 0.6757, 0.5740, 0.3071, 0.4400],\n",
+ " [-0.2344, -0.1056, 0.4773, 0.2432],\n",
+ " [ 0.2595, -0.1528, -0.0866, 0.6201],\n",
+ " [ 0.0657, 0.1220, 0.4849, 0.4254],\n",
+ " [ 0.3399, -0.1609, 0.3465, 0.2389]],\n",
"\n",
- " [[ 0.3505, 0.1599, 0.4867, 0.0462],\n",
- " [-0.1878, 0.2045, 0.0392, -0.0331],\n",
- " [-0.2096, 0.6557, 0.6754, 0.4057],\n",
- " [ 0.6317, 0.1402, -0.2868, 0.2319],\n",
- " [-0.1239, -0.2330, 0.4047, 0.0263]],\n",
+ " [[-0.0765, 0.0516, 0.0028, 0.4381],\n",
+ " [ 0.5212, -0.2781, -0.0896, -0.0301],\n",
+ " [ 0.6857, 0.3583, 0.5869, 0.3418],\n",
+ " [ 0.3002, 0.5135, 0.6011, 0.6499],\n",
+ " [-0.2807, -0.2888, 0.3965, 0.6585]],\n",
"\n",
- " [[ 0.3576, 0.6521, 0.6509, 0.0302],\n",
- " [ 0.6389, 0.3282, 0.6566, 0.3341],\n",
- " [-0.0629, -0.1169, 0.0781, 0.2252],\n",
- " [ 0.4982, 0.2185, 0.4328, 0.5555],\n",
- " [ 0.3052, 0.0192, 0.6695, -0.2008]]]])\n",
- "tensor([[[ 0.6000, 1.0000, 0.2000, -0.2000],\n",
- " [ 1.0000, 0.2000, 0.2000, -0.2000],\n",
- " [-0.6000, -0.2000, 1.0000, 0.6000]],\n",
+ " [[-0.1368, 0.6677, 0.1439, 0.1434],\n",
+ " [-0.1820, 0.1041, -0.1211, 0.6103],\n",
+ " [ 0.5808, 0.4588, 0.4572, 0.3713],\n",
+ " [ 0.2389, -0.1392, 0.1371, -0.1570],\n",
+ " [ 0.2840, 0.1214, -0.0059, 0.5064]]]], device='mps:0')\n",
+ "tensor([[[ 1.0000, -0.6000, 0.6000, 1.0000],\n",
+ " [-0.6000, -0.2000, -0.6000, -0.2000],\n",
+ " [ 0.6000, 0.2000, -0.2000, 0.2000]],\n",
"\n",
- " [[ 0.2000, -0.6000, -0.6000, -0.2000],\n",
- " [-0.6000, 0.6000, 0.2000, 0.2000],\n",
- " [-0.2000, -0.2000, 1.0000, 0.6000]]])\n"
+ " [[ 0.2000, -0.6000, -0.2000, 1.0000],\n",
+ " [ 0.2000, -0.2000, 0.2000, 0.2000],\n",
+ " [ 0.2000, 0.2000, -0.2000, 0.2000]]], device='mps:0')\n"
]
}
],
"source": [
- "o = torch.rand(2, 3, 5, 4) - .3\n",
+ "o = torch.rand(2, 3, 5, 4).to(default_device()) - .3\n",
"print(o)\n",
"\n",
"output = _LPVV(o, dim=2)\n",
@@ -190,13 +197,13 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "tensor([[[0.4007, 0.2374, 0.5392, 0.2991],\n",
- " [0.2820, 0.3511, 0.3091, 0.3971],\n",
- " [0.4613, 0.2744, 0.3192, 0.3513]],\n",
+ "tensor([[[0.3496, 0.4376, 0.2162, 0.3810],\n",
+ " [0.1975, 0.1395, 0.3109, 0.4218],\n",
+ " [0.4550, 0.5145, 0.4329, 0.3631]],\n",
"\n",
- " [[0.1639, 0.2316, 0.0118, 0.3323],\n",
- " [0.4911, 0.2901, 0.4015, 0.1775],\n",
- " [0.4500, 0.3045, 0.4976, 0.2862]]])\n"
+ " [[0.3352, 0.3480, 0.4040, 0.3935],\n",
+ " [0.5023, 0.3078, 0.3968, 0.5221],\n",
+ " [0.3679, 0.3380, 0.2460, 0.4079]]], device='mps:0')\n"
]
}
],
@@ -214,13 +221,13 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "tensor([[[ 0.8910, 1.0000, 0.9592, 0.3842],\n",
- " [ 1.0000, 0.8432, 0.6978, 0.5650],\n",
- " [-0.0094, 0.4297, 1.0000, 0.7668]],\n",
+ "tensor([[[ 1.0000, -0.0270, 0.9138, 1.0000],\n",
+ " [-0.1286, 0.2568, 0.0630, 0.8654],\n",
+ " [ 0.9823, 0.8756, 0.9190, 0.8779]],\n",
"\n",
- " [[ 0.8217, 0.6025, -0.9458, 0.5190],\n",
- " [ 0.3065, 0.6655, 0.6970, 0.9109],\n",
- " [ 0.9325, 0.8248, 1.0000, 0.7015]]])\n"
+ " [[ 0.7024, 0.2482, 0.8983, 1.0000],\n",
+ " [ 0.6168, 0.2392, 0.8931, 0.9715],\n",
+ " [ 0.5517, 0.8133, 0.7065, 0.8244]]], device='mps:0')\n"
]
}
],
@@ -238,13 +245,13 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "tensor([[[-0.3959, -0.5251, -0.1553, -0.8672],\n",
- " [-0.4361, -0.4860, -0.5935, -0.6560],\n",
- " [-1.0035, -0.8021, -0.3616, -0.5121]],\n",
+ "tensor([[[-0.3007, -1.0097, -0.6697, -0.2381],\n",
+ " [-1.0466, -0.9316, -0.9705, -0.3738],\n",
+ " [-0.2786, -0.2314, -0.3366, -0.4569]],\n",
"\n",
- " [[-0.7634, -0.7910, -1.1640, -0.7275],\n",
- " [-0.8157, -0.6291, -0.4723, -0.7292],\n",
- " [-0.3052, -0.5596, -0.0048, -0.6224]]])\n"
+ " [[-0.5574, -0.8893, -0.3883, -0.2130],\n",
+ " [-0.5401, -0.8574, -0.4009, -0.1767],\n",
+ " [-0.6861, -0.5149, -0.7555, -0.4102]]], device='mps:0')\n"
]
}
],
@@ -614,9 +621,9 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "/Users/nacho/notebooks/tsai/nbs/076_models.MultiRocketPlus.ipynb saved at 2024-02-11 01:26:06\n",
+ "/Users/nacho/notebooks/tsai/nbs/076_models.MultiRocketPlus.ipynb saved at 2024-02-11 10:53:13\n",
"Correct notebook to script conversion! 😃\n",
- "Sunday 11/02/24 01:26:09 CET\n"
+ "Sunday 11/02/24 10:53:16 CET\n"
]
},
{
diff --git a/nbs/models/test.pth b/nbs/models/test.pth
index 2b022d10..19a2967b 100644
Binary files a/nbs/models/test.pth and b/nbs/models/test.pth differ
diff --git a/tsai/models/MultiRocketPlus.py b/tsai/models/MultiRocketPlus.py
index eab80a94..e0adc21a 100644
--- a/tsai/models/MultiRocketPlus.py
+++ b/tsai/models/MultiRocketPlus.py
@@ -19,24 +19,22 @@ def forward(self, x): return x.view(x.size(0), -1)
# %% ../../nbs/076_models.MultiRocketPlus.ipynb 5
def _LPVV(o, dim=2):
- "Longest stretch of positive values (-1, 1)"
+ "Longest stretch of positive values along a dimension(-1, 1)"
seq_len = o.shape[dim]
-
- # Convert tensor to binary format (1 for positive values, 0 for non-positive values)
binary_tensor = (o > 0).float()
- # Find the changes in the binary tensor
diff = torch.cat([torch.ones_like(binary_tensor.narrow(dim, 0, 1)),
- binary_tensor.narrow(dim, 1, binary_tensor.shape[dim]-1) - binary_tensor.narrow(dim, 0, binary_tensor.shape[dim]-1)], dim=dim)
+ binary_tensor.narrow(dim, 1, seq_len-1) - binary_tensor.narrow(dim, 0, seq_len-1)], dim=dim)
- # Create groups of positive values
groups = (diff > 0).cumsum(dim)
- # Count the number of values in each group
- counts = torch.zeros_like(binary_tensor).scatter_add_(dim, groups * binary_tensor.long(), binary_tensor)
+ # Ensure groups are within valid index bounds
+ groups = groups * binary_tensor.long()
+ valid_groups = groups.where(groups < binary_tensor.size(dim), torch.tensor(0, device=groups.device))
+
+ counts = torch.zeros_like(binary_tensor).scatter_add_(dim, valid_groups, binary_tensor)
- # The longest stretch of positive values is the maximum count
longest_stretch = counts.max(dim)[0]
return torch.nan_to_num(2 * (longest_stretch / seq_len) - 1)
@@ -67,7 +65,7 @@ def _PPV(o_pos, dim=2):
"Proportion of Positive Values (-1, 1)"
return (o_pos).float().mean(dim) * 2 - 1
-# %% ../../nbs/076_models.MultiRocketPlus.ipynb 10
+# %% ../../nbs/076_models.MultiRocketPlus.ipynb 11
class MultiRocketFeaturesPlus(nn.Module):
fitting = False
@@ -239,7 +237,7 @@ def get_indices(self, kernel_size, max_num_kernels):
len(indices), max_num_kernels, False))]
return indices, pos_values
-# %% ../../nbs/076_models.MultiRocketPlus.ipynb 11
+# %% ../../nbs/076_models.MultiRocketPlus.ipynb 12
class MultiRocketBackbonePlus(nn.Module):
def __init__(self, c_in, seq_len, num_features=50_000, max_dilations_per_kernel=32, kernel_size=9, max_num_channels=None, max_num_kernels=84, use_diff=True):
super(MultiRocketBackbonePlus, self).__init__()
@@ -266,7 +264,7 @@ def forward(self, x):
output = self.branch_x(x)
return output
-# %% ../../nbs/076_models.MultiRocketPlus.ipynb 12
+# %% ../../nbs/076_models.MultiRocketPlus.ipynb 13
class MultiRocketPlus(nn.Sequential):
def __init__(self, c_in, c_out, seq_len, d=None, num_features=50_000, max_dilations_per_kernel=32, kernel_size=9, max_num_channels=None, max_num_kernels=84,