Skip to content
This repository was archived by the owner on Mar 12, 2020. It is now read-only.

Commit

Permalink
TF.NET backend in progress
Browse files Browse the repository at this point in the history
  • Loading branch information
Deepak Kumar committed Apr 16, 2019
1 parent 029d84f commit bf470be
Show file tree
Hide file tree
Showing 13 changed files with 162 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="TensorFlow.NET" Version="0.5.2" />
<PackageReference Include="TensorFlow.NET" Version="0.6.0" />
</ItemGroup>

<ItemGroup>
Expand Down
100 changes: 61 additions & 39 deletions Backends/SiaNet.Backend.TensorFlow/SiaNetBackend.cs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ public Engine.DataType GetDataType(SiaTensor x)

// Creates a trainable TF variable initialized from <paramref name="data"/> and
// wraps its value tensor as a SiaTensor.
// NOTE(review): `shape` and `name` are currently ignored — the variable takes the
// flat shape of `data`; confirm whether a reshape to `shape` is intended here.
public SiaTensor CreateVariable(float[] data, long[] shape, string name = "")
{
    return Out(tf.Variable<Array>(data).value());
}

public SiaTensor Reshape(SiaTensor x, params long[] shape)
Expand Down Expand Up @@ -348,7 +348,7 @@ public SiaTensor Round(SiaTensor x)

// Element-wise sine of <paramref name="x"/>.
public SiaTensor Sin(SiaTensor x)
{
    return Out(tf.sin(In(x)));
}

public SiaTensor Cos(SiaTensor x)
Expand All @@ -358,37 +358,37 @@ public SiaTensor Cos(SiaTensor x)

// Element-wise tangent of <paramref name="x"/>.
public SiaTensor Tan(SiaTensor x)
{
    return Out(tf.tan(In(x)));
}

// Element-wise inverse sine (arcsin) of <paramref name="x"/>.
public SiaTensor Asin(SiaTensor x)
{
    return Out(tf.asin(In(x)));
}

// Element-wise inverse cosine (arccos) of <paramref name="x"/>.
public SiaTensor Acos(SiaTensor x)
{
    return Out(tf.acos(In(x)));
}

// Element-wise inverse tangent (arctan) of <paramref name="x"/>.
public SiaTensor Atan(SiaTensor x)
{
    return Out(tf.atan(In(x)));
}

// Element-wise hyperbolic sine of <paramref name="x"/>.
public SiaTensor Sinh(SiaTensor x)
{
    return Out(tf.sinh(In(x)));
}

// Element-wise hyperbolic cosine of <paramref name="x"/>.
public SiaTensor Cosh(SiaTensor x)
{
    return Out(tf.cosh(In(x)));
}

// Element-wise hyperbolic tangent of <paramref name="x"/>.
public SiaTensor Tanh(SiaTensor x)
{
    return Out(tf.tanh(In(x)));
}

public SiaTensor Sigmoid(SiaTensor x)
Expand All @@ -398,47 +398,60 @@ public SiaTensor Sigmoid(SiaTensor x)

// Element-wise power: raises every element of <paramref name="x"/> to
// <paramref name="value"/>. The scalar exponent is broadcast to x's shape
// via In(value, x.Shape).
public SiaTensor Pow(SiaTensor x, float value)
{
    return Out(tf.pow(In(x), In(value, x.Shape)));
}

// Element-wise square of <paramref name="x"/>, implemented as pow(x, 2).
public SiaTensor Square(SiaTensor x)
{
    return Out(tf.pow(In(x), 2));
}

// Clamps every element of <paramref name="x"/> into the range [min, max].
// Both bounds are broadcast to x's shape before the clip.
public SiaTensor Clip(SiaTensor x, float min, float max)
{
    return Out(tf._clip_by_value(In(x), In(min, x.Shape), In(max, x.Shape)));
}

// Sum of all elements of <paramref name="x"/> as a scalar.
// NOTE(review): relies on an implicit Tensor -> float conversion in TF.NET;
// confirm the scalar is actually materialized here.
public float Sum(SiaTensor x)
{
    return tf.reduce_sum(In(x));
}

// Sum over a single axis, keeping the reduced dimension (keepdims=true).
// Negative axes count from the end, numpy-style.
public SiaTensor Sum(SiaTensor x, int dim)
{
    dim = dim < 0 ? x.DimCount + dim : dim;
    return Out(tf.sum(In(x), dim, true));
}

// Sum over several axes by reducing one axis at a time. Because the
// single-axis overload keeps reduced dimensions (keepdims=true), axis
// indices remain valid across iterations, and it also normalizes negative
// axes itself — so no pre-normalization is needed here.
// (Fixes: previous version computed a normalized `dim` local and never used it.)
public SiaTensor Sum(SiaTensor x, params int[] dims)
{
    foreach (var item in dims)
    {
        x = Sum(x, item);
    }

    return x;
}

// Maximum over all elements of <paramref name="x"/> as a scalar.
// NOTE(review): relies on an implicit Tensor -> float conversion in TF.NET;
// confirm the scalar is actually materialized here.
public float Max(SiaTensor x)
{
    return math_ops.reduce_max(In(x));
}

// Maximum over a single axis, keeping the reduced dimension (keepdims=true).
// Negative axes count from the end, numpy-style.
public SiaTensor Max(SiaTensor x, int dim)
{
    dim = dim < 0 ? x.DimCount + dim : dim;
    return Out(math_ops.reduce_max(In(x), new int[] { dim }, true));
}

// Maximum over several axes in one reduce_max call, keepdims=true.
// Negative axes are normalized in place before the reduction.
public SiaTensor Max(SiaTensor x, params int[] dims)
{
    for (int i = 0; i < dims.Length; i++)
    {
        dims[i] = dims[i] < 0 ? x.DimCount + dims[i] : dims[i];
    }

    return Out(math_ops.reduce_max(In(x), dims, true));
}

public float Min(SiaTensor x)
Expand All @@ -458,62 +471,70 @@ public SiaTensor Min(SiaTensor x, params int[] dim)

// Mean of all elements of <paramref name="x"/> as a scalar.
// NOTE(review): relies on an implicit Tensor -> float conversion in TF.NET;
// confirm the scalar is actually materialized here.
public float Mean(SiaTensor x)
{
    return math_ops.reduce_mean(In(x));
}

// Mean over a single axis, keeping the reduced dimension (keepdims=true).
// Negative axes count from the end, numpy-style.
public SiaTensor Mean(SiaTensor x, int dim)
{
    dim = dim < 0 ? x.DimCount + dim : dim;
    return Out(math_ops.reduce_mean(In(x), new int[] { dim }, true));
}

// Mean over several axes in one reduce_mean call, keepdims=true.
// Negative axes are normalized in place before the reduction.
public SiaTensor Mean(SiaTensor x, params int[] dims)
{
    for (int i = 0; i < dims.Length; i++)
    {
        dims[i] = dims[i] < 0 ? x.DimCount + dims[i] : dims[i];
    }

    return Out(math_ops.reduce_mean(In(x), dims, true));
}

// Index of the maximum element along <paramref name="dim"/>.
// Negative axes count from the end, numpy-style.
public SiaTensor Argmax(SiaTensor x, int dim = 0)
{
    dim = dim < 0 ? x.DimCount + dim : dim;
    return Out(tf.arg_max(In(x), dim));
}

// Index of the minimum element along <paramref name="dim"/>.
// Negative axes count from the end, numpy-style.
public SiaTensor Argmin(SiaTensor x, int dim = 0)
{
    dim = dim < 0 ? x.DimCount + dim : dim;
    return Out(tf.arg_min(In(x), dim));
}

// Element-wise maximum of two tensors.
public SiaTensor Maximum(SiaTensor a, SiaTensor b)
{
    return Out(tf.maximum(In(a), In(b)));
}

// Element-wise maximum of a tensor and a scalar; the scalar is broadcast
// to a's shape via In(b, a.Shape).
public SiaTensor Maximum(SiaTensor a, float b)
{
    return Out(tf.maximum(In(a), In(b, a.Shape)));
}

// Element-wise minimum of two tensors.
public SiaTensor Minimum(SiaTensor a, SiaTensor b)
{
    return Out(tf.minimum(In(a), In(b)));
}

// Element-wise minimum of a tensor and a scalar; the scalar is broadcast
// to a's shape via In(b, a.Shape).
public SiaTensor Minimum(SiaTensor a, float b)
{
    return Out(tf.minimum(In(a), In(b, a.Shape)));
}

// Transposes a 2-D tensor (swaps axes 0 and 1).
// NOTE(review): the hard-coded permutation {1, 0} assumes x is rank 2 —
// confirm callers never pass higher-rank tensors to this overload.
public SiaTensor Transpose(SiaTensor x)
{
    return Out(tf.transpose(In(x), new int[] { 1, 0 }));
}

// Permutes the tensor's axes according to <paramref name="dims"/>.
public SiaTensor Transpose(SiaTensor x, params int[] dims)
{
    return Out(tf.transpose(In(x), dims));
}

// Matrix product of <paramref name="a"/> and <paramref name="b"/>.
public SiaTensor Dot(SiaTensor a, SiaTensor b)
{
    return Out(tf.matmul(In(a), In(b)));
}

public SiaTensor Diag(SiaTensor x)
Expand All @@ -523,17 +544,18 @@ public SiaTensor Diag(SiaTensor x)

// Softmax activation along <paramref name="axis"/> (default: last axis).
public SiaTensor Softmax(SiaTensor x, int axis = -1)
{
    return Out(tf.nn.softmax(In(x), axis));
}

// Softplus activation: log(exp(x) + 1), built from the backend's own
// Exp/Log primitives rather than a TF op.
// NOTE(review): `axis` is ignored — softplus is element-wise, so the
// parameter likely exists only for interface symmetry; confirm.
public SiaTensor Softplus(SiaTensor x, int axis = -1)
{
    return Log((Exp(x) + 1));
}

// L2 normalization along <paramref name="axis"/>: divides x by the L2 norm
// computed from sum(square(x)).
// NOTE(review): the extra Max(..., axis) reduction over the summed squares
// does not match the usual x / sqrt(sum(square(x))) formulation (Keras uses
// max only against an epsilon floor, not a second axis reduction) — verify
// this is intentional before relying on it.
public SiaTensor L2Normalize(SiaTensor x, int axis = -1)
{
    var y = Max(Sum(Square(x), axis), axis);
    return x / Sqrt(y);
}

public SiaTensor Im2Col(SiaTensor x, Tuple<int, int> kernalSize, int padding = 1, int stride = 1)
Expand All @@ -558,17 +580,17 @@ public SiaTensor SliceCols(SiaTensor x, long start, long end)

// Materializes the tensor's contents as a managed array of floats.
public Array GetArray(SiaTensor x)
{
    return In(x).Data<float>();
}

// Releases the native resources held by the underlying TF tensor.
public void Dispose(SiaTensor x)
{
    In(x).Dispose();
}

// Returns the activation-function provider bound to this backend instance.
public ActivationFunc GetActFunc()
{
    return new SiaNetActivations(this);
}
}
}
4 changes: 4 additions & 0 deletions Backends/SiaNet.Backend.Torch/SiaNet.Backend.Torch.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
</PropertyGroup>

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'">
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
</PropertyGroup>

<ItemGroup>
<None Include="TorchSharp.csproj" />
</ItemGroup>
Expand Down
38 changes: 37 additions & 1 deletion Examples/BackendTester/Program.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
using NumSharp.Core;
using CNTK;
using NumSharp;
using SiaNet;
using SiaNet.Engine;
using SiaNet.Initializers;
using System;
using System.Collections.Generic;
using System.Linq;

namespace BackendTester
Expand All @@ -11,6 +13,15 @@ class Program
{
static void Main(string[] args)
{
int[] shape = new int[] { 6, 3 };
float[] data = new float[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9 };

NDArrayView array = new NDArrayView(shape, data, DeviceDescriptor.CPUDevice);
Variable variable = new Variable(shape, VariableKind.Parameter, CNTK.DataType.Float, array, false, new AxisVector(), false, "", "");
var slicedData = CNTKLib.Slice(variable, AxisVector.Repeat(new Axis(0), 1), IntVector.Repeat(1, 1), IntVector.Repeat(3, 1));
var resultArray = GetArray(slicedData);


Global.UseEngine(SiaNet.Backend.CNTKLib.SiaNetBackend.Instance, DeviceType.CPU);
var K = Global.CurrentBackend;

Expand All @@ -25,5 +36,30 @@ static void Main(string[] args)
Console.ReadLine();

}

// Extracts the dense float contents of a CNTK Variable as a managed array.
// For output variables the graph must be evaluated first; for plain
// variables (parameters/constants) the stored value is read directly.
public static Array GetArray(Variable xvar)
{
Value v = null;
if (xvar.IsOutput)
{
// An output has no stored value — wrap it in a Function and evaluate it.
var f = xvar.ToFunction();

// NOTE(review): `plist` is computed but never used — candidate for removal.
var plist = f.Parameters();
Dictionary<Variable, Value> inputs = new Dictionary<Variable, Value>();
// Null value requests CNTK to allocate the output buffer during Evaluate.
Dictionary<Variable, Value> outputs = new Dictionary<Variable, Value>()
{
{ f, null}
};

f.Evaluate(inputs, outputs, DeviceDescriptor.CPUDevice);
// Single-output function: take the one (and only) populated entry.
v = outputs.FirstOrDefault().Value;
}
else
{
// Non-output variable: its value is already materialized.
v = new Value(xvar.GetValue());
}

// [0]: first (only) sequence in the batch.
return v.GetDenseData<float>(xvar)[0].ToArray();
}
}
}
9 changes: 4 additions & 5 deletions Examples/BasicClassificationWithTitanicDataset/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ class Program
static void Main(string[] args)
{
//Setup Engine
Global.UseEngine(SiaNet.Backend.MxNetLib.SiaNetBackend.Instance, DeviceType.Default);
Global.UseEngine(SiaNet.Backend.ArrayFire.SiaNetBackend.Instance, DeviceType.Default);

var train = LoadTrain(); //Load train data
var test = LoadTest(); //Load test data

var model = new Sequential();
model.EpochEnd += Model_EpochEnd;
model.Add(new Dense(128, ActType.ReLU));
Expand All @@ -31,10 +31,9 @@ static void Main(string[] args)
model.Compile(OptimizerType.Adam, LossType.BinaryCrossEntropy, MetricType.BinaryAccurary);

// Perform training with train and val dataset
model.Train(train, epochs: 100, batchSize: 32);
model.Train(train, epochs: 100, batchSize: 200);

//var prediction = model.Predict(test);
//TOps.Round(prediction).Print();
var prediction = model.Predict(test);
}

private static void Model_EpochEnd(object sender, EpochEndEventArgs e)
Expand Down
2 changes: 1 addition & 1 deletion Examples/BostonHousingRegressionExample/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class Program
static void Main(string[] args)
{
//Setup Engine
Global.UseEngine(SiaNet.Backend.MxNetLib.SiaNetBackend.Instance, DeviceType.CPU);
Global.UseEngine(SiaNet.Backend.ArrayFire.SiaNetBackend.Instance, DeviceType.CPU);

//Load Train and Test CSV data
var ds = LoadTrain("./train.csv");
Expand Down
2 changes: 1 addition & 1 deletion Examples/GettingStarted/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ static void Main(string[] args)
predX.Load(0, 0, 0, 1); //Result should be 0 and 1

var rawPred = model.Predict(predX);
Global.CurrentBackend.Round(rawPred).Print();

Console.ReadLine();
}

Expand Down
Loading

0 comments on commit bf470be

Please sign in to comment.