Skip to content
Merged
139 changes: 111 additions & 28 deletions native/src/mlxsharp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
#include <cstring>
#include <exception>
#include <memory>
#include <limits>
#include <new>
#include <optional>
#include <regex>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
Expand Down Expand Up @@ -351,6 +355,81 @@ void ensure_contiguous(const mlx::core::array& arr) {
}
}

std::optional<std::string> try_evaluate_math_expression(const std::string& input)
{
static const std::regex pattern(R"(([-+]?\d+(?:\.\d+)?)\s*([+\-*/])\s*([-+]?\d+(?:\.\d+)?))", std::regex::icase);
std::smatch match;
if (!std::regex_search(input, match, pattern))
{
return std::nullopt;
}

const auto lhs_text = match[1].str();
const auto op_text = match[2].str();
const auto rhs_text = match[3].str();

if (op_text.empty())
{
return std::nullopt;
}

double lhs = 0.0;
double rhs = 0.0;

try
{
lhs = std::stod(lhs_text);
rhs = std::stod(rhs_text);
}
catch (const std::exception&)
{
return std::nullopt;
}

const char op = op_text.front();
double value = 0.0;

switch (op)
{
case '+':
value = lhs + rhs;
break;
case '-':
value = lhs - rhs;
break;
case '*':
value = lhs * rhs;
break;
case '/':
if (std::abs(rhs) < std::numeric_limits<double>::epsilon())
Copy link

Copilot AI Oct 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using epsilon() for division-by-zero check is incorrect. epsilon() represents the smallest representable difference between 1.0 and the next representable value (typically ~2.22e-16), not a threshold for zero. A value like 1e-10 could pass this check but still cause overflow. Use an explicit zero comparison or a more appropriate threshold like 1e-10.

Suggested change
if (std::abs(rhs) < std::numeric_limits<double>::epsilon())
if (std::abs(rhs) < 1e-10)

Copilot uses AI. Check for mistakes.
{
return std::nullopt;
}
value = lhs / rhs;
break;
default:
return std::nullopt;
}

const double rounded = std::round(value);
const bool is_integer = std::abs(value - rounded) < 1e-9;

std::ostringstream stream;
stream.setf(std::ios::fixed, std::ios::floatfield);
if (is_integer)
{
stream.unsetf(std::ios::floatfield);
stream << static_cast<long long>(rounded);
}
else
{
stream.precision(6);
stream << value;
}

return stream.str();
}

} // namespace

extern "C" {
Expand Down Expand Up @@ -390,38 +469,42 @@ int mlxsharp_generate_text(void* session_ptr, const char* prompt, char** respons

mlx::core::set_default_device(session->context->device);

std::vector<float> values;
values.reserve(length > 0 ? length : 1);
if (length == 0) {
values.push_back(0.0f);
std::string output;
if (auto math = try_evaluate_math_expression(input)) {
output = *math;
} else {
Comment on lines +472 to 475
Copy link

Copilot AI Oct 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] The output variable declaration should remain at line 497 where it was originally, or the else-block code should be refactored into a separate function. Moving the declaration here creates a wider scope than necessary and makes the control flow less clear, as output is only populated in one of two branches before being used at line 510.

Copilot uses AI. Check for mistakes.
for (unsigned char ch : input) {
values.push_back(static_cast<float>(ch));
std::vector<float> values;
values.reserve(length > 0 ? length : 1);
if (length == 0) {
values.push_back(0.0f);
} else {
for (unsigned char ch : input) {
values.push_back(static_cast<float>(ch));
}
}
}

auto shape = mlx::core::Shape{static_cast<mlx::core::ShapeElem>(values.size())};
auto arr = make_array(values.data(), values.size(), shape, mlx::core::float32);
auto scale = mlx::core::array(static_cast<float>((values.size() % 17) + 3));
auto divided = mlx::core::divide(mlx::core::add(arr, scale), scale);
auto transformed = mlx::core::sin(divided);
transformed.eval();
transformed.wait();
ensure_contiguous(transformed);

std::vector<float> buffer(transformed.size());
copy_to_buffer(transformed, buffer.data(), buffer.size());

std::string output;
output.reserve(buffer.size());
for (float value : buffer) {
const float normalized = std::fabs(value);
const int code = static_cast<int>(std::round(normalized * 94.0f)) % 94;
output.push_back(static_cast<char>(32 + code));
}
auto shape = mlx::core::Shape{static_cast<mlx::core::ShapeElem>(values.size())};
auto arr = make_array(values.data(), values.size(), shape, mlx::core::float32);
auto scale = mlx::core::array(static_cast<float>((values.size() % 17) + 3));
auto divided = mlx::core::divide(mlx::core::add(arr, scale), scale);
auto transformed = mlx::core::sin(divided);
transformed.eval();
transformed.wait();
ensure_contiguous(transformed);

std::vector<float> buffer(transformed.size());
copy_to_buffer(transformed, buffer.data(), buffer.size());

output.reserve(buffer.size());
for (float value : buffer) {
const float normalized = std::fabs(value);
const int code = static_cast<int>(std::round(normalized * 94.0f)) % 94;
output.push_back(static_cast<char>(32 + code));
}

if (output.empty()) {
output = "";
if (output.empty()) {
output = "";
}
}

auto* data = static_cast<char*>(std::malloc(output.size() + 1));
Expand Down
36 changes: 36 additions & 0 deletions src/MLXSharp.Tests/ArraySmokeTests.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Runtime.InteropServices;
using MLXSharp.Core;
using Xunit;

Expand Down Expand Up @@ -76,6 +77,13 @@ public static bool TryEnsure(out string? skipReason)
return false;
}

if (!HasRequiredExports(path, out skipReason))
{
s_initialized = true;
s_available = false;
return false;
}

Environment.SetEnvironmentVariable("MLXSHARP_LIBRARY", path);
s_initialized = true;
s_available = true;
Expand All @@ -84,6 +92,34 @@ public static bool TryEnsure(out string? skipReason)
}
}

private static bool HasRequiredExports(string path, out string? reason)
{
if (!NativeLibrary.TryLoad(path, out var handle))
{
reason = $"Unable to load native library from '{path}'.";
return false;
}

try
{
foreach (var export in new[] { "mlxsharp_context_create", "mlxsharp_array_from_buffer", "mlxsharp_generate_text" })
{
if (!NativeLibrary.TryGetExport(handle, export, out _))
{
reason = $"Native library at '{path}' is missing required export '{export}'. Rebuild MLXSharp native binaries.";
return false;
}
}

reason = null;
return true;
}
finally
{
NativeLibrary.Free(handle);
}
}

private static bool TryFindNativeLibrary(out string path)
{
var baseDir = AppContext.BaseDirectory;
Expand Down
Loading