diff --git a/nbs/022_tslearner.ipynb b/nbs/022_tslearner.ipynb index 1d3af33c..f0520bd1 100644 --- a/nbs/022_tslearner.ipynb +++ b/nbs/022_tslearner.ipynb @@ -265,10 +265,10 @@ " \n", " \n", " 0\n", - " 1.523830\n", - " 0.266667\n", - " 1.407878\n", - " 0.300000\n", + " 1.464314\n", + " 0.233333\n", + " 1.400173\n", + " 0.166667\n", " 00:00\n", " \n", " \n", @@ -339,8 +339,8 @@ " \n", " \n", " 0\n", - " 1.563760\n", - " 0.166667\n", + " 1.438072\n", + " 0.200000\n", " 00:00\n", " \n", " \n", @@ -600,10 +600,10 @@ " \n", " \n", " 0\n", - " 209.704529\n", - " 13.806342\n", - " 207.336456\n", - " 13.982669\n", + " 221.817291\n", + " 14.270400\n", + " 209.151230\n", + " 14.046944\n", " 00:01\n", " \n", " \n", @@ -860,10 +860,10 @@ " \n", " \n", " 0\n", - " 4226.109375\n", - " 49.230492\n", - " 8007.046387\n", - " 74.881180\n", + " 4114.624023\n", + " 48.891418\n", + " 7991.095703\n", + " 74.791130\n", " 00:00\n", " \n", " \n", @@ -932,9 +932,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "/Users/nacho/notebooks/tsai/nbs/022_tslearner.ipynb saved at 2024-02-11 00:40:14\n", + "/Users/nacho/notebooks/tsai/nbs/022_tslearner.ipynb saved at 2024-02-11 10:55:07\n", "Correct notebook to script conversion! 😃\n", - "Sunday 11/02/24 00:40:17 CET\n" + "Sunday 11/02/24 10:55:10 CET\n" ] }, { diff --git a/nbs/076_models.MultiRocketPlus.ipynb b/nbs/076_models.MultiRocketPlus.ipynb index 895dc26b..fe07d6d5 100644 --- a/nbs/076_models.MultiRocketPlus.ipynb +++ b/nbs/076_models.MultiRocketPlus.ipynb @@ -69,24 +69,22 @@ "source": [ "#| export\n", "def _LPVV(o, dim=2):\n", - " \"Longest stretch of positive values (-1, 1)\"\n", + " \"Longest stretch of positive values along a dimension(-1, 1)\"\n", "\n", " seq_len = o.shape[dim]\n", - "\n", - " # Convert tensor to binary format (1 for positive values, 0 for non-positive values)\n", " binary_tensor = (o > 0).float()\n", "\n", - " # Find the changes in the binary tensor\n", " diff = torch.cat([torch.ones_like(binary_tensor.narrow(dim, 0, 1)),\n", - " binary_tensor.narrow(dim, 1, binary_tensor.shape[dim]-1) - binary_tensor.narrow(dim, 0, binary_tensor.shape[dim]-1)], dim=dim)\n", + " binary_tensor.narrow(dim, 1, seq_len-1) - binary_tensor.narrow(dim, 0, seq_len-1)], dim=dim)\n", "\n", - " # Create groups of positive values\n", " groups = (diff > 0).cumsum(dim)\n", "\n", - " # Count the number of values in each group\n", - " counts = torch.zeros_like(binary_tensor).scatter_add_(dim, groups * binary_tensor.long(), binary_tensor)\n", + " # Ensure groups are within valid index bounds\n", + " groups = groups * binary_tensor.long()\n", + " valid_groups = groups.where(groups < binary_tensor.size(dim), torch.tensor(0, device=groups.device))\n", + "\n", + " counts = torch.zeros_like(binary_tensor).scatter_add_(dim, valid_groups, binary_tensor)\n", "\n", - " # The longest stretch of positive values is the maximum count\n", " longest_stretch = counts.max(dim)[0]\n", "\n", " return torch.nan_to_num(2 * (longest_stretch / seq_len) - 1)\n", @@ -118,6 +116,15 @@ " return (o_pos).float().mean(dim) * 2 - 1" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from tsai.imports import default_device" + ] + }, { "cell_type": "code", "execution_count": null, @@ -127,54 +134,54 @@ "name": "stdout", "output_type": "stream", "text": [ - "tensor([[[[-0.0924, 0.0842, 0.5685, 0.3900],\n", - " [ 0.2364, 0.3018, -0.0449, 0.2081],\n", - " [ 0.6782, 0.1842, 0.6873, -0.0590],\n", - " [ 0.1263, 0.2636, 0.3605, -0.0281],\n", - " [ 0.5618, 0.3535, 0.5403, -0.1791]],\n", + "tensor([[[[ 0.5644, -0.0509, -0.0390, 0.4091],\n", + " [ 0.0517, -0.1471, 0.6458, 0.5593],\n", + " [ 0.4516, -0.0821, 0.1271, 0.0592],\n", + " [ 0.4151, 0.4376, 0.0763, 0.3780],\n", + " [ 0.2653, -0.1817, 0.0156, 0.4993]],\n", "\n", - " [[ 0.2201, 0.1868, 0.1791, -0.1343],\n", - " [ 0.3556, -0.1194, -0.2201, 0.4859],\n", - " [ 0.1115, 0.6232, 0.4436, 0.3880],\n", - " [ 0.6350, 0.1362, 0.5869, -0.1968],\n", - " [ 0.0876, 0.4583, 0.0266, 0.3174]],\n", + " [[-0.0779, 0.0858, 0.1982, 0.3224],\n", + " [ 0.1130, 0.0714, -0.1779, 0.5360],\n", + " [-0.1848, -0.2270, -0.0925, -0.1217],\n", + " [ 0.2820, -0.0205, -0.2777, 0.3755],\n", + " [-0.2490, 0.2613, 0.4237, 0.4534]],\n", "\n", - " [[-0.1895, 0.1921, 0.2437, -0.1854],\n", - " [-0.1534, -0.2986, 0.2977, 0.3019],\n", - " [ 0.4613, 0.4243, 0.0115, 0.2684],\n", - " [-0.0923, 0.2066, 0.4980, 0.6450],\n", - " [-0.0348, -0.0297, 0.5451, 0.1900]]],\n", + " [[-0.0162, 0.6368, 0.0016, 0.1467],\n", + " [ 0.6035, -0.1365, 0.6930, 0.6943],\n", + " [ 0.2790, 0.3818, -0.0731, 0.0167],\n", + " [ 0.6442, 0.3443, 0.4829, -0.0944],\n", + " [ 0.2932, 0.6952, 0.5541, 0.5946]]],\n", "\n", "\n", - " [[[ 0.0524, 0.3093, -0.1079, 0.6815],\n", - " [-0.0642, -0.1675, -0.0548, -0.2654],\n", - " [ 0.3172, 0.2939, -0.2412, -0.0502],\n", - " [ 0.1145, -0.0048, 0.0118, 0.1329],\n", - " [ 0.1715, 0.0915, -0.0179, 0.1825]],\n", + " [[[ 0.6757, 0.5740, 0.3071, 0.4400],\n", + " [-0.2344, -0.1056, 0.4773, 0.2432],\n", + " [ 0.2595, -0.1528, -0.0866, 0.6201],\n", + " [ 0.0657, 0.1220, 0.4849, 0.4254],\n", + " [ 0.3399, -0.1609, 0.3465, 0.2389]],\n", "\n", - " [[ 0.3505, 0.1599, 0.4867, 0.0462],\n", - " [-0.1878, 0.2045, 0.0392, -0.0331],\n", - " [-0.2096, 0.6557, 0.6754, 0.4057],\n", - " [ 0.6317, 0.1402, -0.2868, 0.2319],\n", - " [-0.1239, -0.2330, 0.4047, 0.0263]],\n", + " [[-0.0765, 0.0516, 0.0028, 0.4381],\n", + " [ 0.5212, -0.2781, -0.0896, -0.0301],\n", + " [ 0.6857, 0.3583, 0.5869, 0.3418],\n", + " [ 0.3002, 0.5135, 0.6011, 0.6499],\n", + " [-0.2807, -0.2888, 0.3965, 0.6585]],\n", "\n", - " [[ 0.3576, 0.6521, 0.6509, 0.0302],\n", - " [ 0.6389, 0.3282, 0.6566, 0.3341],\n", - " [-0.0629, -0.1169, 0.0781, 0.2252],\n", - " [ 0.4982, 0.2185, 0.4328, 0.5555],\n", - " [ 0.3052, 0.0192, 0.6695, -0.2008]]]])\n", - "tensor([[[ 0.6000, 1.0000, 0.2000, -0.2000],\n", - " [ 1.0000, 0.2000, 0.2000, -0.2000],\n", - " [-0.6000, -0.2000, 1.0000, 0.6000]],\n", + " [[-0.1368, 0.6677, 0.1439, 0.1434],\n", + " [-0.1820, 0.1041, -0.1211, 0.6103],\n", + " [ 0.5808, 0.4588, 0.4572, 0.3713],\n", + " [ 0.2389, -0.1392, 0.1371, -0.1570],\n", + " [ 0.2840, 0.1214, -0.0059, 0.5064]]]], device='mps:0')\n", + "tensor([[[ 1.0000, -0.6000, 0.6000, 1.0000],\n", + " [-0.6000, -0.2000, -0.6000, -0.2000],\n", + " [ 0.6000, 0.2000, -0.2000, 0.2000]],\n", "\n", - " [[ 0.2000, -0.6000, -0.6000, -0.2000],\n", - " [-0.6000, 0.6000, 0.2000, 0.2000],\n", - " [-0.2000, -0.2000, 1.0000, 0.6000]]])\n" + " [[ 0.2000, -0.6000, -0.2000, 1.0000],\n", + " [ 0.2000, -0.2000, 0.2000, 0.2000],\n", + " [ 0.2000, 0.2000, -0.2000, 0.2000]]], device='mps:0')\n" ] } ], "source": [ - "o = torch.rand(2, 3, 5, 4) - .3\n", + "o = torch.rand(2, 3, 5, 4).to(default_device()) - .3\n", "print(o)\n", "\n", "output = _LPVV(o, dim=2)\n", @@ -190,13 +197,13 @@ "name": "stdout", "output_type": "stream", "text": [ - "tensor([[[0.4007, 0.2374, 0.5392, 0.2991],\n", - " [0.2820, 0.3511, 0.3091, 0.3971],\n", - " [0.4613, 0.2744, 0.3192, 0.3513]],\n", + "tensor([[[0.3496, 0.4376, 0.2162, 0.3810],\n", + " [0.1975, 0.1395, 0.3109, 0.4218],\n", + " [0.4550, 0.5145, 0.4329, 0.3631]],\n", "\n", - " [[0.1639, 0.2316, 0.0118, 0.3323],\n", - " [0.4911, 0.2901, 0.4015, 0.1775],\n", - " [0.4500, 0.3045, 0.4976, 0.2862]]])\n" + " [[0.3352, 0.3480, 0.4040, 0.3935],\n", + " [0.5023, 0.3078, 0.3968, 0.5221],\n", + " [0.3679, 0.3380, 0.2460, 0.4079]]], device='mps:0')\n" ] } ], @@ -214,13 +221,13 @@ "name": "stdout", "output_type": "stream", "text": [ - "tensor([[[ 0.8910, 1.0000, 0.9592, 0.3842],\n", - " [ 1.0000, 0.8432, 0.6978, 0.5650],\n", - " [-0.0094, 0.4297, 1.0000, 0.7668]],\n", + "tensor([[[ 1.0000, -0.0270, 0.9138, 1.0000],\n", + " [-0.1286, 0.2568, 0.0630, 0.8654],\n", + " [ 0.9823, 0.8756, 0.9190, 0.8779]],\n", "\n", - " [[ 0.8217, 0.6025, -0.9458, 0.5190],\n", - " [ 0.3065, 0.6655, 0.6970, 0.9109],\n", - " [ 0.9325, 0.8248, 1.0000, 0.7015]]])\n" + " [[ 0.7024, 0.2482, 0.8983, 1.0000],\n", + " [ 0.6168, 0.2392, 0.8931, 0.9715],\n", + " [ 0.5517, 0.8133, 0.7065, 0.8244]]], device='mps:0')\n" ] } ], @@ -238,13 +245,13 @@ "name": "stdout", "output_type": "stream", "text": [ - "tensor([[[-0.3959, -0.5251, -0.1553, -0.8672],\n", - " [-0.4361, -0.4860, -0.5935, -0.6560],\n", - " [-1.0035, -0.8021, -0.3616, -0.5121]],\n", + "tensor([[[-0.3007, -1.0097, -0.6697, -0.2381],\n", + " [-1.0466, -0.9316, -0.9705, -0.3738],\n", + " [-0.2786, -0.2314, -0.3366, -0.4569]],\n", "\n", - " [[-0.7634, -0.7910, -1.1640, -0.7275],\n", - " [-0.8157, -0.6291, -0.4723, -0.7292],\n", - " [-0.3052, -0.5596, -0.0048, -0.6224]]])\n" + " [[-0.5574, -0.8893, -0.3883, -0.2130],\n", + " [-0.5401, -0.8574, -0.4009, -0.1767],\n", + " [-0.6861, -0.5149, -0.7555, -0.4102]]], device='mps:0')\n" ] } ], @@ -614,9 +621,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "/Users/nacho/notebooks/tsai/nbs/076_models.MultiRocketPlus.ipynb saved at 2024-02-11 01:26:06\n", + "/Users/nacho/notebooks/tsai/nbs/076_models.MultiRocketPlus.ipynb saved at 2024-02-11 10:53:13\n", "Correct notebook to script conversion! 😃\n", - "Sunday 11/02/24 01:26:09 CET\n" + "Sunday 11/02/24 10:53:16 CET\n" ] }, { diff --git a/nbs/models/test.pth b/nbs/models/test.pth index 2b022d10..19a2967b 100644 Binary files a/nbs/models/test.pth and b/nbs/models/test.pth differ diff --git a/tsai/models/MultiRocketPlus.py b/tsai/models/MultiRocketPlus.py index eab80a94..e0adc21a 100644 --- a/tsai/models/MultiRocketPlus.py +++ b/tsai/models/MultiRocketPlus.py @@ -19,24 +19,22 @@ def forward(self, x): return x.view(x.size(0), -1) # %% ../../nbs/076_models.MultiRocketPlus.ipynb 5 def _LPVV(o, dim=2): - "Longest stretch of positive values (-1, 1)" + "Longest stretch of positive values along a dimension(-1, 1)" seq_len = o.shape[dim] - - # Convert tensor to binary format (1 for positive values, 0 for non-positive values) binary_tensor = (o > 0).float() - # Find the changes in the binary tensor diff = torch.cat([torch.ones_like(binary_tensor.narrow(dim, 0, 1)), - binary_tensor.narrow(dim, 1, binary_tensor.shape[dim]-1) - binary_tensor.narrow(dim, 0, binary_tensor.shape[dim]-1)], dim=dim) + binary_tensor.narrow(dim, 1, seq_len-1) - binary_tensor.narrow(dim, 0, seq_len-1)], dim=dim) - # Create groups of positive values groups = (diff > 0).cumsum(dim) - # Count the number of values in each group - counts = torch.zeros_like(binary_tensor).scatter_add_(dim, groups * binary_tensor.long(), binary_tensor) + # Ensure groups are within valid index bounds + groups = groups * binary_tensor.long() + valid_groups = groups.where(groups < binary_tensor.size(dim), torch.tensor(0, device=groups.device)) + + counts = torch.zeros_like(binary_tensor).scatter_add_(dim, valid_groups, binary_tensor) - # The longest stretch of positive values is the maximum count longest_stretch = counts.max(dim)[0] return torch.nan_to_num(2 * (longest_stretch / seq_len) - 1) @@ -67,7 +65,7 @@ def _PPV(o_pos, dim=2): "Proportion of Positive Values (-1, 1)" return (o_pos).float().mean(dim) * 2 - 1 -# %% ../../nbs/076_models.MultiRocketPlus.ipynb 10 +# %% ../../nbs/076_models.MultiRocketPlus.ipynb 11 class MultiRocketFeaturesPlus(nn.Module): fitting = False @@ -239,7 +237,7 @@ def get_indices(self, kernel_size, max_num_kernels): len(indices), max_num_kernels, False))] return indices, pos_values -# %% ../../nbs/076_models.MultiRocketPlus.ipynb 11 +# %% ../../nbs/076_models.MultiRocketPlus.ipynb 12 class MultiRocketBackbonePlus(nn.Module): def __init__(self, c_in, seq_len, num_features=50_000, max_dilations_per_kernel=32, kernel_size=9, max_num_channels=None, max_num_kernels=84, use_diff=True): super(MultiRocketBackbonePlus, self).__init__() @@ -266,7 +264,7 @@ def forward(self, x): output = self.branch_x(x) return output -# %% ../../nbs/076_models.MultiRocketPlus.ipynb 12 +# %% ../../nbs/076_models.MultiRocketPlus.ipynb 13 class MultiRocketPlus(nn.Sequential): def __init__(self, c_in, c_out, seq_len, d=None, num_features=50_000, max_dilations_per_kernel=32, kernel_size=9, max_num_channels=None, max_num_kernels=84,