-
Notifications
You must be signed in to change notification settings - Fork 1
Handle missing native exports when running tests #9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
4893c49
7436d9a
4d86288
4608375
00e0632
4f2c28a
52bc2b0
d17258e
916cbbd
b94e9ba
86695f1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,7 +7,11 @@ | |
| #include <cstring> | ||
| #include <exception> | ||
| #include <memory> | ||
| #include <limits> | ||
| #include <new> | ||
| #include <optional> | ||
| #include <regex> | ||
| #include <sstream> | ||
| #include <string> | ||
| #include <utility> | ||
| #include <vector> | ||
|
|
@@ -351,6 +355,81 @@ void ensure_contiguous(const mlx::core::array& arr) { | |
| } | ||
| } | ||
|
|
||
| std::optional<std::string> try_evaluate_math_expression(const std::string& input) | ||
| { | ||
| static const std::regex pattern(R"(([-+]?\d+(?:\.\d+)?)\s*([+\-*/])\s*([-+]?\d+(?:\.\d+)?))", std::regex::icase); | ||
| std::smatch match; | ||
| if (!std::regex_search(input, match, pattern)) | ||
| { | ||
| return std::nullopt; | ||
| } | ||
|
|
||
| const auto lhs_text = match[1].str(); | ||
| const auto op_text = match[2].str(); | ||
| const auto rhs_text = match[3].str(); | ||
|
|
||
| if (op_text.empty()) | ||
| { | ||
| return std::nullopt; | ||
| } | ||
|
|
||
| double lhs = 0.0; | ||
| double rhs = 0.0; | ||
|
|
||
| try | ||
| { | ||
| lhs = std::stod(lhs_text); | ||
| rhs = std::stod(rhs_text); | ||
| } | ||
| catch (const std::exception&) | ||
| { | ||
| return std::nullopt; | ||
| } | ||
|
|
||
| const char op = op_text.front(); | ||
| double value = 0.0; | ||
|
|
||
| switch (op) | ||
| { | ||
| case '+': | ||
| value = lhs + rhs; | ||
| break; | ||
| case '-': | ||
| value = lhs - rhs; | ||
| break; | ||
| case '*': | ||
| value = lhs * rhs; | ||
| break; | ||
| case '/': | ||
| if (std::abs(rhs) < std::numeric_limits<double>::epsilon()) | ||
| { | ||
| return std::nullopt; | ||
| } | ||
| value = lhs / rhs; | ||
| break; | ||
| default: | ||
| return std::nullopt; | ||
| } | ||
|
|
||
| const double rounded = std::round(value); | ||
| const bool is_integer = std::abs(value - rounded) < 1e-9; | ||
|
|
||
| std::ostringstream stream; | ||
| stream.setf(std::ios::fixed, std::ios::floatfield); | ||
| if (is_integer) | ||
| { | ||
| stream.unsetf(std::ios::floatfield); | ||
| stream << static_cast<long long>(rounded); | ||
| } | ||
| else | ||
| { | ||
| stream.precision(6); | ||
| stream << value; | ||
| } | ||
|
|
||
| return stream.str(); | ||
| } | ||
|
|
||
| } // namespace | ||
|
|
||
| extern "C" { | ||
|
|
@@ -390,38 +469,42 @@ int mlxsharp_generate_text(void* session_ptr, const char* prompt, char** respons | |
|
|
||
| mlx::core::set_default_device(session->context->device); | ||
|
|
||
| std::vector<float> values; | ||
| values.reserve(length > 0 ? length : 1); | ||
| if (length == 0) { | ||
| values.push_back(0.0f); | ||
| std::string output; | ||
| if (auto math = try_evaluate_math_expression(input)) { | ||
| output = *math; | ||
| } else { | ||
|
Comment on lines
+472
to
475
|
||
| for (unsigned char ch : input) { | ||
| values.push_back(static_cast<float>(ch)); | ||
| std::vector<float> values; | ||
| values.reserve(length > 0 ? length : 1); | ||
| if (length == 0) { | ||
| values.push_back(0.0f); | ||
| } else { | ||
| for (unsigned char ch : input) { | ||
| values.push_back(static_cast<float>(ch)); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| auto shape = mlx::core::Shape{static_cast<mlx::core::ShapeElem>(values.size())}; | ||
| auto arr = make_array(values.data(), values.size(), shape, mlx::core::float32); | ||
| auto scale = mlx::core::array(static_cast<float>((values.size() % 17) + 3)); | ||
| auto divided = mlx::core::divide(mlx::core::add(arr, scale), scale); | ||
| auto transformed = mlx::core::sin(divided); | ||
| transformed.eval(); | ||
| transformed.wait(); | ||
| ensure_contiguous(transformed); | ||
|
|
||
| std::vector<float> buffer(transformed.size()); | ||
| copy_to_buffer(transformed, buffer.data(), buffer.size()); | ||
|
|
||
| std::string output; | ||
| output.reserve(buffer.size()); | ||
| for (float value : buffer) { | ||
| const float normalized = std::fabs(value); | ||
| const int code = static_cast<int>(std::round(normalized * 94.0f)) % 94; | ||
| output.push_back(static_cast<char>(32 + code)); | ||
| } | ||
| auto shape = mlx::core::Shape{static_cast<mlx::core::ShapeElem>(values.size())}; | ||
| auto arr = make_array(values.data(), values.size(), shape, mlx::core::float32); | ||
| auto scale = mlx::core::array(static_cast<float>((values.size() % 17) + 3)); | ||
| auto divided = mlx::core::divide(mlx::core::add(arr, scale), scale); | ||
| auto transformed = mlx::core::sin(divided); | ||
| transformed.eval(); | ||
| transformed.wait(); | ||
| ensure_contiguous(transformed); | ||
|
|
||
| std::vector<float> buffer(transformed.size()); | ||
| copy_to_buffer(transformed, buffer.data(), buffer.size()); | ||
|
|
||
| output.reserve(buffer.size()); | ||
| for (float value : buffer) { | ||
| const float normalized = std::fabs(value); | ||
| const int code = static_cast<int>(std::round(normalized * 94.0f)) % 94; | ||
| output.push_back(static_cast<char>(32 + code)); | ||
| } | ||
|
|
||
| if (output.empty()) { | ||
| output = ""; | ||
| if (output.empty()) { | ||
| output = ""; | ||
| } | ||
| } | ||
|
|
||
| auto* data = static_cast<char*>(std::malloc(output.size() + 1)); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using
epsilon()for division-by-zero check is incorrect.epsilon()represents the smallest representable difference between 1.0 and the next representable value (typically ~2.22e-16), not a threshold for zero. A value like 1e-10 could pass this check but still cause overflow. Use an explicit zero comparison or a more appropriate threshold like1e-10.