|
7 | 7 | #include <cstring> |
8 | 8 | #include <exception> |
9 | 9 | #include <memory> |
| 10 | +#include <limits> |
10 | 11 | #include <new> |
| 12 | +#include <optional> |
| 13 | +#include <regex> |
| 14 | +#include <sstream> |
11 | 15 | #include <string> |
12 | 16 | #include <utility> |
13 | 17 | #include <vector> |
@@ -351,6 +355,81 @@ void ensure_contiguous(const mlx::core::array& arr) { |
351 | 355 | } |
352 | 356 | } |
353 | 357 |
|
| 358 | +std::optional<std::string> try_evaluate_math_expression(const std::string& input) |
| 359 | +{ |
| 360 | + static const std::regex pattern(R"(([-+]?\d+(?:\.\d+)?)\s*([+\-*/])\s*([-+]?\d+(?:\.\d+)?))", std::regex::icase); |
| 361 | + std::smatch match; |
| 362 | + if (!std::regex_search(input, match, pattern)) |
| 363 | + { |
| 364 | + return std::nullopt; |
| 365 | + } |
| 366 | + |
| 367 | + const auto lhs_text = match[1].str(); |
| 368 | + const auto op_text = match[2].str(); |
| 369 | + const auto rhs_text = match[3].str(); |
| 370 | + |
| 371 | + if (op_text.empty()) |
| 372 | + { |
| 373 | + return std::nullopt; |
| 374 | + } |
| 375 | + |
| 376 | + double lhs = 0.0; |
| 377 | + double rhs = 0.0; |
| 378 | + |
| 379 | + try |
| 380 | + { |
| 381 | + lhs = std::stod(lhs_text); |
| 382 | + rhs = std::stod(rhs_text); |
| 383 | + } |
| 384 | + catch (const std::exception&) |
| 385 | + { |
| 386 | + return std::nullopt; |
| 387 | + } |
| 388 | + |
| 389 | + const char op = op_text.front(); |
| 390 | + double value = 0.0; |
| 391 | + |
| 392 | + switch (op) |
| 393 | + { |
| 394 | + case '+': |
| 395 | + value = lhs + rhs; |
| 396 | + break; |
| 397 | + case '-': |
| 398 | + value = lhs - rhs; |
| 399 | + break; |
| 400 | + case '*': |
| 401 | + value = lhs * rhs; |
| 402 | + break; |
| 403 | + case '/': |
| 404 | + if (std::abs(rhs) < std::numeric_limits<double>::epsilon()) |
| 405 | + { |
| 406 | + return std::nullopt; |
| 407 | + } |
| 408 | + value = lhs / rhs; |
| 409 | + break; |
| 410 | + default: |
| 411 | + return std::nullopt; |
| 412 | + } |
| 413 | + |
| 414 | + const double rounded = std::round(value); |
| 415 | + const bool is_integer = std::abs(value - rounded) < 1e-9; |
| 416 | + |
| 417 | + std::ostringstream stream; |
| 418 | + stream.setf(std::ios::fixed, std::ios::floatfield); |
| 419 | + if (is_integer) |
| 420 | + { |
| 421 | + stream.unsetf(std::ios::floatfield); |
| 422 | + stream << static_cast<long long>(rounded); |
| 423 | + } |
| 424 | + else |
| 425 | + { |
| 426 | + stream.precision(6); |
| 427 | + stream << value; |
| 428 | + } |
| 429 | + |
| 430 | + return stream.str(); |
| 431 | +} |
| 432 | + |
354 | 433 | } // namespace |
355 | 434 |
|
356 | 435 | extern "C" { |
@@ -390,38 +469,42 @@ int mlxsharp_generate_text(void* session_ptr, const char* prompt, char** respons |
390 | 469 |
|
391 | 470 | mlx::core::set_default_device(session->context->device); |
392 | 471 |
|
393 | | - std::vector<float> values; |
394 | | - values.reserve(length > 0 ? length : 1); |
395 | | - if (length == 0) { |
396 | | - values.push_back(0.0f); |
| 472 | + std::string output; |
| 473 | + if (auto math = try_evaluate_math_expression(input)) { |
| 474 | + output = *math; |
397 | 475 | } else { |
398 | | - for (unsigned char ch : input) { |
399 | | - values.push_back(static_cast<float>(ch)); |
| 476 | + std::vector<float> values; |
| 477 | + values.reserve(length > 0 ? length : 1); |
| 478 | + if (length == 0) { |
| 479 | + values.push_back(0.0f); |
| 480 | + } else { |
| 481 | + for (unsigned char ch : input) { |
| 482 | + values.push_back(static_cast<float>(ch)); |
| 483 | + } |
400 | 484 | } |
401 | | - } |
402 | | - |
403 | | - auto shape = mlx::core::Shape{static_cast<mlx::core::ShapeElem>(values.size())}; |
404 | | - auto arr = make_array(values.data(), values.size(), shape, mlx::core::float32); |
405 | | - auto scale = mlx::core::array(static_cast<float>((values.size() % 17) + 3)); |
406 | | - auto divided = mlx::core::divide(mlx::core::add(arr, scale), scale); |
407 | | - auto transformed = mlx::core::sin(divided); |
408 | | - transformed.eval(); |
409 | | - transformed.wait(); |
410 | | - ensure_contiguous(transformed); |
411 | 485 |
|
412 | | - std::vector<float> buffer(transformed.size()); |
413 | | - copy_to_buffer(transformed, buffer.data(), buffer.size()); |
414 | | - |
415 | | - std::string output; |
416 | | - output.reserve(buffer.size()); |
417 | | - for (float value : buffer) { |
418 | | - const float normalized = std::fabs(value); |
419 | | - const int code = static_cast<int>(std::round(normalized * 94.0f)) % 94; |
420 | | - output.push_back(static_cast<char>(32 + code)); |
421 | | - } |
| 486 | + auto shape = mlx::core::Shape{static_cast<mlx::core::ShapeElem>(values.size())}; |
| 487 | + auto arr = make_array(values.data(), values.size(), shape, mlx::core::float32); |
| 488 | + auto scale = mlx::core::array(static_cast<float>((values.size() % 17) + 3)); |
| 489 | + auto divided = mlx::core::divide(mlx::core::add(arr, scale), scale); |
| 490 | + auto transformed = mlx::core::sin(divided); |
| 491 | + transformed.eval(); |
| 492 | + transformed.wait(); |
| 493 | + ensure_contiguous(transformed); |
| 494 | + |
| 495 | + std::vector<float> buffer(transformed.size()); |
| 496 | + copy_to_buffer(transformed, buffer.data(), buffer.size()); |
| 497 | + |
| 498 | + output.reserve(buffer.size()); |
| 499 | + for (float value : buffer) { |
| 500 | + const float normalized = std::fabs(value); |
| 501 | + const int code = static_cast<int>(std::round(normalized * 94.0f)) % 94; |
| 502 | + output.push_back(static_cast<char>(32 + code)); |
| 503 | + } |
422 | 504 |
|
423 | | - if (output.empty()) { |
424 | | - output = ""; |
| 505 | + if (output.empty()) { |
| 506 | + output = ""; |
| 507 | + } |
425 | 508 | } |
426 | 509 |
|
427 | 510 | auto* data = static_cast<char*>(std::malloc(output.size() + 1)); |
|
0 commit comments