@@ -102,7 +102,7 @@ static void byteswap_tensor(ggml_tensor * tensor) {
102
102
#define WHISPER_PRINT_DEBUG (...)
103
103
#endif
104
104
105
- #define WHISPER_USE_FLASH_ATTN
105
+ // #define WHISPER_USE_FLASH_ATTN
106
106
// #define WHISPER_USE_FLASH_FF
107
107
#define WHISPER_MAX_DECODERS 16
108
108
@@ -224,11 +224,11 @@ static const std::map<std::string, std::pair<int, std::string>> g_lang = {
224
224
static const size_t MB = 1ull *1024 *1024 ;
225
225
226
226
static const std::map<e_model, size_t > MEM_REQ_SCRATCH0 = {
227
- { MODEL_TINY, 14ull *MB },
228
- { MODEL_BASE, 18ull *MB },
229
- { MODEL_SMALL, 28ull *MB },
230
- { MODEL_MEDIUM, 36ull *MB },
231
- { MODEL_LARGE, 44ull *MB },
227
+ { MODEL_TINY, 62ull *MB },
228
+ { MODEL_BASE, 80ull *MB },
229
+ { MODEL_SMALL, 120ull *MB },
230
+ { MODEL_MEDIUM, 158ull *MB },
231
+ { MODEL_LARGE, 198ull *MB },
232
232
};
233
233
234
234
static const std::map<e_model, size_t > MEM_REQ_SCRATCH1 = {
@@ -280,11 +280,11 @@ static const std::map<e_model, size_t> MEM_REQ_KV_CROSS = {
280
280
};
281
281
282
282
static const std::map<e_model, size_t > MEM_REQ_ENCODE = {
283
- { MODEL_TINY, 6ull *MB },
284
- { MODEL_BASE, 8ull *MB },
285
- { MODEL_SMALL, 13ull *MB },
286
- { MODEL_MEDIUM, 22ull *MB },
287
- { MODEL_LARGE, 33ull *MB },
283
+ { MODEL_TINY, 30ull *MB },
284
+ { MODEL_BASE, 38ull *MB },
285
+ { MODEL_SMALL, 56ull *MB },
286
+ { MODEL_MEDIUM, 74ull *MB },
287
+ { MODEL_LARGE, 94ull *MB },
288
288
};
289
289
290
290
static const std::map<e_model, size_t > MEM_REQ_DECODE = {
@@ -1554,26 +1554,17 @@ static bool whisper_encode_internal(
1554
1554
1555
1555
struct ggml_tensor * KQ_soft_max = ggml_soft_max (ctx0, KQ_scaled);
1556
1556
1557
- // struct ggml_tensor * V_trans =
1558
- // ggml_permute(ctx0,
1559
- // ggml_cpy(ctx0,
1560
- // Vcur,
1561
- // ggml_new_tensor_3d(ctx0, wctx.wtype, n_state/n_head, n_head, n_ctx)),
1562
- // 1, 2, 0, 3);
1563
-
1564
- // struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
1565
-
1566
1557
struct ggml_tensor * V =
1567
1558
ggml_cpy (ctx0,
1568
1559
ggml_permute (ctx0,
1569
1560
ggml_reshape_3d (ctx0,
1570
1561
Vcur,
1571
1562
n_state/n_head, n_head, n_ctx),
1572
- 0 , 2 , 1 , 3 ),
1573
- ggml_new_tensor_3d (ctx0, wctx.wtype , n_state/n_head, n_ctx , n_head)
1563
+ 1 , 2 , 0 , 3 ),
1564
+ ggml_new_tensor_3d (ctx0, wctx.wtype , n_ctx, n_state/n_head, n_head)
1574
1565
);
1575
1566
1576
- struct ggml_tensor * KQV = ggml_mul_mat (ctx0, ggml_transpose (ctx0, V) , KQ_soft_max);
1567
+ struct ggml_tensor * KQV = ggml_mul_mat (ctx0, V , KQ_soft_max);
1577
1568
#endif
1578
1569
struct ggml_tensor * KQV_merged = ggml_permute (ctx0, KQV, 0 , 2 , 1 , 3 );
1579
1570
0 commit comments