Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support dynamic native library loading in .NET standard 2.0. #738

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 15 additions & 13 deletions LLama.Examples/Program.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using LLama.Native;
using LLama.Native;
using Spectre.Console;
using System.Runtime.InteropServices;

AnsiConsole.MarkupLineInterpolated(
$"""
Expand All @@ -16,23 +17,24 @@ __ __ ____ __

""");

// Configure native library to use. This must be done before any other llama.cpp methods are called!
NativeLibraryConfig
.Instance
.WithCuda();

// Configure logging. Change this to `true` to see log messages from llama.cpp
var showLLamaCppLogs = false;
NativeLibraryConfig
.Instance
.All
.WithLogCallback((level, message) =>
{
if (showLLamaCppLogs)
Console.WriteLine($"[llama {level}]: {message.TrimEnd('\n')}");
});
{
if (showLLamaCppLogs)
Console.WriteLine($"[llama {level}]: {message.TrimEnd('\n')}");
});

// Configure native library to use. This must be done before any other llama.cpp methods are called!
NativeLibraryConfig
.All
.WithCuda();
//.WithAutoDownload() // An experimental feature
//.DryRun(out var loadedllamaLibrary, out var loadedLLavaLibrary);

// Calling this method forces loading to occur now.
NativeApi.llama_empty_call();

await ExampleRunner.Run();

await ExampleRunner.Run();
29 changes: 29 additions & 0 deletions LLama/Abstractions/INativeLibrary.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
using LLama.Native;
using System;
using System.Collections.Generic;
using System.Text;

namespace LLama.Abstractions
{
    /// <summary>
    /// Descriptor of a native library that LLamaSharp may load at runtime.
    /// </summary>
    public interface INativeLibrary
    {
        /// <summary>
        /// Metadata of this library, or null when no metadata is available.
        /// </summary>
        NativeLibraryMetadata? Metadata { get; }

        /// <summary>
        /// Prepare the native library file and return the local path(s) of it.
        /// If a returned path is relative, LLamaSharp will search for it in the search directories you set.
        /// </summary>
        /// <param name="systemInfo">The system information of the current machine.</param>
        /// <param name="logCallback">The log callback.</param>
        /// <returns>
        /// The relative paths of the library. You could return multiple paths to try them one by one. If no file is available, please return an empty array.
        /// </returns>
        IEnumerable<string> Prepare(SystemInfo systemInfo, NativeLogConfig.LLamaLogCallback? logCallback = null);
    }
}
22 changes: 22 additions & 0 deletions LLama/Abstractions/INativeLibrarySelectingPolicy.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
using LLama.Native;
using System;
using System.Collections.Generic;
using System.Text;

namespace LLama.Abstractions
{
    /// <summary>
    /// Decides the selected native library that should be loaded according to the configurations.
    /// </summary>
    public interface INativeLibrarySelectingPolicy
    {
        /// <summary>
        /// Select the native library.
        /// </summary>
        /// <param name="description">The user's native library configuration, describing which libraries are acceptable.</param>
        /// <param name="systemInfo">The system information of the current machine.</param>
        /// <param name="logCallback">The log callback.</param>
        /// <returns>The information of the selected native library files, in order by priority from the beginning to the end.</returns>
        IEnumerable<INativeLibrary> Apply(NativeLibraryConfig.Description description, SystemInfo systemInfo, NativeLogConfig.LLamaLogCallback? logCallback = null);
    }
}
5 changes: 3 additions & 2 deletions LLama/LLamaSharp.csproj
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFrameworks>netstandard2.0;net6.0;net7.0;net8.0</TargetFrameworks>
<TargetFrameworks>net6.0;net7.0;net8.0;netstandard2.0</TargetFrameworks>
<RootNamespace>LLama</RootNamespace>
<Nullable>enable</Nullable>
<LangVersion>10</LangVersion>
<LangVersion>12</LangVersion>
<Platforms>AnyCPU;x64;Arm64</Platforms>
<AllowUnsafeBlocks>True</AllowUnsafeBlocks>

Expand Down Expand Up @@ -42,6 +42,7 @@
</ItemGroup>

<ItemGroup Condition="'$(TargetFramework)' == 'netstandard2.0'">
<PackageReference Include="NativeLibrary.NetStandard" Version="0.1.1" PrivateAssets="all" />
<PackageReference Include="IsExternalInit" Version="1.0.3" PrivateAssets="all" />
<PackageReference Include="System.Memory" Version="4.5.5" PrivateAssets="all" />
<PackageReference Include="System.Linq.Async" Version="6.0.1" />
Expand Down
21 changes: 17 additions & 4 deletions LLama/Native/LLamaContextParams.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
using System;
using System;
using System.Runtime.InteropServices;

#if NETSTANDARD
using NativeLibraryNetStandard;
#endif

namespace LLama.Native
{
/// <summary>
Expand Down Expand Up @@ -180,10 +184,19 @@ public bool flash_attention
public static LLamaContextParams Default()
{
return llama_context_default_params();

[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
static extern LLamaContextParams llama_context_default_params();
}

#if NETSTANDARD
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
static extern LLamaContextParams llama_context_default_params_r();
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
delegate LLamaContextParams llama_context_default_params_t();
static LLamaContextParams llama_context_default_params() => NativeLibraryConfig.DynamicLoadingDisabled ?
llama_context_default_params_r() : NativeApi.GetLLamaExport<llama_context_default_params_t>("llama_context_default_params")();
#else
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
static extern LLamaContextParams llama_context_default_params();
#endif
}
}

47 changes: 47 additions & 0 deletions LLama/Native/LLamaKvCacheView.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
using System;
using System.Runtime.InteropServices;

#if NETSTANDARD
using NativeLibraryNetStandard;
#endif

namespace LLama.Native;

/// <summary>
Expand Down Expand Up @@ -151,6 +155,48 @@ public Span<LLamaSeqId> GetCellSequences(int index)
}

#region native API

#if NETSTANDARD
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern NativeLLamaKvCacheView llama_kv_cache_view_init_r(SafeLLamaContextHandle ctx, int n_seq_max);
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
private delegate NativeLLamaKvCacheView llama_kv_cache_view_init_t(SafeLLamaContextHandle ctx, int n_seq_max);
private static NativeLLamaKvCacheView llama_kv_cache_view_init(SafeLLamaContextHandle ctx, int n_seq_max) =>
NativeLibraryConfig.DynamicLoadingDisabled ? llama_kv_cache_view_init_r(ctx, n_seq_max) : NativeApi.GetLLamaExport<llama_kv_cache_view_init_t>("llama_kv_cache_view_init")(ctx, n_seq_max);

[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern void llama_kv_cache_view_free_r(ref NativeLLamaKvCacheView view);
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
private delegate void llama_kv_cache_view_free_t(ref NativeLLamaKvCacheView view);
private static void llama_kv_cache_view_free(ref NativeLLamaKvCacheView view)
{
if (NativeLibraryConfig.DynamicLoadingDisabled)
{
llama_kv_cache_view_free_r(ref view);
}
else
{
NativeApi.GetLLamaExport<llama_kv_cache_view_free_t>("llama_kv_cache_view_free")(ref view);
}
}

[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern void llama_kv_cache_view_update_r(SafeLLamaContextHandle ctx, ref NativeLLamaKvCacheView view);
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
private delegate void llama_kv_cache_view_update_t(SafeLLamaContextHandle ctx, ref NativeLLamaKvCacheView view);
private static void llama_kv_cache_view_update(SafeLLamaContextHandle ctx, ref NativeLLamaKvCacheView view)
{
if (NativeLibraryConfig.DynamicLoadingDisabled)
{
llama_kv_cache_view_update_r(ctx, ref view);
}
else
{
NativeApi.GetLLamaExport<llama_kv_cache_view_update_t>("llama_kv_cache_view_update")(ctx, ref view);
}
}

#else
/// <summary>
/// Create an empty KV cache view. (use only for debugging purposes)
/// </summary>
Expand All @@ -173,6 +219,7 @@ public Span<LLamaSeqId> GetCellSequences(int index)
/// <param name="view"></param>
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern void llama_kv_cache_view_update(SafeLLamaContextHandle ctx, ref NativeLLamaKvCacheView view);
#endif

/// <summary>
/// Information associated with an individual cell in the KV cache view (llama_kv_cache_view_cell)
Expand Down
21 changes: 17 additions & 4 deletions LLama/Native/LLamaModelQuantizeParams.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
using System;
using System;
using System.Runtime.InteropServices;

#if NETSTANDARD
using NativeLibraryNetStandard;
#endif

namespace LLama.Native
{
/// <summary>
Expand Down Expand Up @@ -97,9 +101,18 @@ public bool keep_split
public static LLamaModelQuantizeParams Default()
{
return llama_model_quantize_default_params();

[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
static extern LLamaModelQuantizeParams llama_model_quantize_default_params();
}

#if NETSTANDARD
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
static extern LLamaModelQuantizeParams llama_model_quantize_default_params_r();
[UnmanagedFunctionPointer(CallingConvention.Cdecl)]
delegate LLamaModelQuantizeParams llama_model_quantize_default_params_t();
static LLamaModelQuantizeParams llama_model_quantize_default_params() => NativeLibraryConfig.DynamicLoadingDisabled ?
llama_model_quantize_default_params_r() : NativeApi.GetLLamaExport<llama_model_quantize_default_params_t>("llama_model_quantize_default_params")();
#else
[DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
static extern LLamaModelQuantizeParams llama_model_quantize_default_params();
#endif
}
}
67 changes: 67 additions & 0 deletions LLama/Native/Load/DefaultNativeLibrarySelectingPolicy.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
using LLama.Abstractions;
using System.Collections.Generic;
using System.Runtime.InteropServices;

namespace LLama.Native
{
    /// <inheritdoc/>
    public class DefaultNativeLibrarySelectingPolicy : INativeLibrarySelectingPolicy
    {
        /// <inheritdoc/>
        public IEnumerable<INativeLibrary> Apply(NativeLibraryConfig.Description description, SystemInfo systemInfo, NativeLogConfig.LLamaLogCallback? logCallback)
        {
            // Show the configuration we're working with
            Log(description.ToString(), LLamaLogLevel.Info, logCallback);

            // If a specific path is requested, only use it, no fall back.
            if (!string.IsNullOrEmpty(description.Path))
            {
                yield return new NativeLibraryFromPath(description.Path);
            }
            else
            {
                if (description.UseCuda)
                {
                    yield return new NativeLibraryWithCuda(systemInfo.CudaMajorVersion, description.Library, description.SkipCheck);
                }

                // CPU candidates are offered when CUDA was not requested, or after CUDA as a fallback.
                if (!description.UseCuda || description.AllowFallback)
                {
                    if (description.AllowFallback)
                    {
                        // Try all of the AVX levels we can support, from highest to lowest.
                        if (description.AvxLevel >= AvxLevel.Avx512)
                            yield return new NativeLibraryWithAvx(description.Library, AvxLevel.Avx512, description.SkipCheck);

                        if (description.AvxLevel >= AvxLevel.Avx2)
                            yield return new NativeLibraryWithAvx(description.Library, AvxLevel.Avx2, description.SkipCheck);

                        if (description.AvxLevel >= AvxLevel.Avx)
                            yield return new NativeLibraryWithAvx(description.Library, AvxLevel.Avx, description.SkipCheck);

                        yield return new NativeLibraryWithAvx(description.Library, AvxLevel.None, description.SkipCheck);
                    }
                    else
                    {
                        // No fallback allowed: offer only the exact AVX level that was configured.
                        yield return new NativeLibraryWithAvx(description.Library, description.AvxLevel, description.SkipCheck);
                    }
                }

                // macOS always gets the mac/fallback library; other platforms get it only as a last resort.
                if (systemInfo.OSPlatform == OSPlatform.OSX || description.AllowFallback)
                {
                    yield return new NativeLibraryWithMacOrFallback(description.Library, description.SkipCheck);
                }
            }
        }

        /// <summary>
        /// Forward a message to the log callback, ensuring it is newline-terminated.
        /// </summary>
        /// <param name="message">The message to log.</param>
        /// <param name="level">The severity to log at.</param>
        /// <param name="logCallback">The log callback; no-op when null.</param>
        private static void Log(string message, LLamaLogLevel level, NativeLogConfig.LLamaLogCallback? logCallback)
        {
            if (!message.EndsWith("\n"))
                message += "\n";

            logCallback?.Invoke(level, message);
        }
    }
}
Loading