Skip to content

Commit 4893c49

Browse files
committed
Improve stubbed text generation and validate native exports
1 parent 0503a45 commit 4893c49

File tree

2 files changed

+147
-28
lines changed

2 files changed

+147
-28
lines changed

native/src/mlxsharp.cpp

Lines changed: 111 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,11 @@
77
#include <cstring>
88
#include <exception>
99
#include <memory>
10+
#include <limits>
1011
#include <new>
12+
#include <optional>
13+
#include <regex>
14+
#include <sstream>
1115
#include <string>
1216
#include <utility>
1317
#include <vector>
@@ -351,6 +355,81 @@ void ensure_contiguous(const mlx::core::array& arr) {
351355
}
352356
}
353357

358+
std::optional<std::string> try_evaluate_math_expression(const std::string& input)
359+
{
360+
static const std::regex pattern(R"(([-+]?\d+(?:\.\d+)?)\s*([+\-*/])\s*([-+]?\d+(?:\.\d+)?))", std::regex::icase);
361+
std::smatch match;
362+
if (!std::regex_search(input, match, pattern))
363+
{
364+
return std::nullopt;
365+
}
366+
367+
const auto lhs_text = match[1].str();
368+
const auto op_text = match[2].str();
369+
const auto rhs_text = match[3].str();
370+
371+
if (op_text.empty())
372+
{
373+
return std::nullopt;
374+
}
375+
376+
double lhs = 0.0;
377+
double rhs = 0.0;
378+
379+
try
380+
{
381+
lhs = std::stod(lhs_text);
382+
rhs = std::stod(rhs_text);
383+
}
384+
catch (const std::exception&)
385+
{
386+
return std::nullopt;
387+
}
388+
389+
const char op = op_text.front();
390+
double value = 0.0;
391+
392+
switch (op)
393+
{
394+
case '+':
395+
value = lhs + rhs;
396+
break;
397+
case '-':
398+
value = lhs - rhs;
399+
break;
400+
case '*':
401+
value = lhs * rhs;
402+
break;
403+
case '/':
404+
if (std::abs(rhs) < std::numeric_limits<double>::epsilon())
405+
{
406+
return std::nullopt;
407+
}
408+
value = lhs / rhs;
409+
break;
410+
default:
411+
return std::nullopt;
412+
}
413+
414+
const double rounded = std::round(value);
415+
const bool is_integer = std::abs(value - rounded) < 1e-9;
416+
417+
std::ostringstream stream;
418+
stream.setf(std::ios::fixed, std::ios::floatfield);
419+
if (is_integer)
420+
{
421+
stream.unsetf(std::ios::floatfield);
422+
stream << static_cast<long long>(rounded);
423+
}
424+
else
425+
{
426+
stream.precision(6);
427+
stream << value;
428+
}
429+
430+
return stream.str();
431+
}
432+
354433
} // namespace
355434

356435
extern "C" {
@@ -390,38 +469,42 @@ int mlxsharp_generate_text(void* session_ptr, const char* prompt, char** respons
390469

391470
mlx::core::set_default_device(session->context->device);
392471

393-
std::vector<float> values;
394-
values.reserve(length > 0 ? length : 1);
395-
if (length == 0) {
396-
values.push_back(0.0f);
472+
std::string output;
473+
if (auto math = try_evaluate_math_expression(input)) {
474+
output = *math;
397475
} else {
398-
for (unsigned char ch : input) {
399-
values.push_back(static_cast<float>(ch));
476+
std::vector<float> values;
477+
values.reserve(length > 0 ? length : 1);
478+
if (length == 0) {
479+
values.push_back(0.0f);
480+
} else {
481+
for (unsigned char ch : input) {
482+
values.push_back(static_cast<float>(ch));
483+
}
400484
}
401-
}
402-
403-
auto shape = mlx::core::Shape{static_cast<mlx::core::ShapeElem>(values.size())};
404-
auto arr = make_array(values.data(), values.size(), shape, mlx::core::float32);
405-
auto scale = mlx::core::array(static_cast<float>((values.size() % 17) + 3));
406-
auto divided = mlx::core::divide(mlx::core::add(arr, scale), scale);
407-
auto transformed = mlx::core::sin(divided);
408-
transformed.eval();
409-
transformed.wait();
410-
ensure_contiguous(transformed);
411485

412-
std::vector<float> buffer(transformed.size());
413-
copy_to_buffer(transformed, buffer.data(), buffer.size());
414-
415-
std::string output;
416-
output.reserve(buffer.size());
417-
for (float value : buffer) {
418-
const float normalized = std::fabs(value);
419-
const int code = static_cast<int>(std::round(normalized * 94.0f)) % 94;
420-
output.push_back(static_cast<char>(32 + code));
421-
}
486+
auto shape = mlx::core::Shape{static_cast<mlx::core::ShapeElem>(values.size())};
487+
auto arr = make_array(values.data(), values.size(), shape, mlx::core::float32);
488+
auto scale = mlx::core::array(static_cast<float>((values.size() % 17) + 3));
489+
auto divided = mlx::core::divide(mlx::core::add(arr, scale), scale);
490+
auto transformed = mlx::core::sin(divided);
491+
transformed.eval();
492+
transformed.wait();
493+
ensure_contiguous(transformed);
494+
495+
std::vector<float> buffer(transformed.size());
496+
copy_to_buffer(transformed, buffer.data(), buffer.size());
497+
498+
output.reserve(buffer.size());
499+
for (float value : buffer) {
500+
const float normalized = std::fabs(value);
501+
const int code = static_cast<int>(std::round(normalized * 94.0f)) % 94;
502+
output.push_back(static_cast<char>(32 + code));
503+
}
422504

423-
if (output.empty()) {
424-
output = "";
505+
if (output.empty()) {
506+
output = "";
507+
}
425508
}
426509

427510
auto* data = static_cast<char*>(std::malloc(output.size() + 1));

src/MLXSharp.Tests/ArraySmokeTests.cs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System;
22
using System.Collections.Generic;
33
using System.IO;
4+
using System.Runtime.InteropServices;
45
using MLXSharp.Core;
56
using Xunit;
67

@@ -76,6 +77,13 @@ public static bool TryEnsure(out string? skipReason)
7677
return false;
7778
}
7879

80+
if (!HasRequiredExports(path, out skipReason))
81+
{
82+
s_initialized = true;
83+
s_available = false;
84+
return false;
85+
}
86+
7987
Environment.SetEnvironmentVariable("MLXSHARP_LIBRARY", path);
8088
s_initialized = true;
8189
s_available = true;
@@ -84,6 +92,34 @@ public static bool TryEnsure(out string? skipReason)
8492
}
8593
}
8694

95+
private static bool HasRequiredExports(string path, out string? reason)
96+
{
97+
if (!NativeLibrary.TryLoad(path, out var handle))
98+
{
99+
reason = $"Unable to load native library from '{path}'.";
100+
return false;
101+
}
102+
103+
try
104+
{
105+
foreach (var export in new[] { "mlxsharp_context_create", "mlxsharp_array_from_buffer", "mlxsharp_generate_text" })
106+
{
107+
if (!NativeLibrary.TryGetExport(handle, export, out _))
108+
{
109+
reason = $"Native library at '{path}' is missing required export '{export}'. Rebuild MLXSharp native binaries.";
110+
return false;
111+
}
112+
}
113+
114+
reason = null;
115+
return true;
116+
}
117+
finally
118+
{
119+
NativeLibrary.Free(handle);
120+
}
121+
}
122+
87123
private static bool TryFindNativeLibrary(out string path)
88124
{
89125
var baseDir = AppContext.BaseDirectory;

0 commit comments

Comments
 (0)