Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade to Sentis 2.1 #18

Merged
merged 19 commits into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion BlazeDetectionSample/Face/Packages/manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"com.unity.ide.visualstudio": "2.0.22",
"com.unity.inputsystem": "1.9.0",
"com.unity.render-pipelines.universal": "17.0.3",
"com.unity.sentis": "2.0.0",
"com.unity.sentis": "2.1.0",
"com.unity.test-framework": "1.4.4",
"com.unity.timeline": "1.8.7",
"com.unity.ugui": "2.0.0",
Expand Down
10 changes: 5 additions & 5 deletions BlazeDetectionSample/Face/Packages/packages-lock.json
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
"url": "https://packages.unity.com"
},
"com.unity.collections": {
"version": "2.4.1",
"depth": 2,
"version": "2.4.3",
"depth": 1,
"source": "registry",
"dependencies": {
"com.unity.burst": "1.8.13",
Expand Down Expand Up @@ -137,12 +137,12 @@
"url": "https://packages.unity.com"
},
"com.unity.sentis": {
"version": "2.0.0",
"version": "2.1.0",
"depth": 0,
"source": "registry",
"dependencies": {
"com.unity.burst": "1.8.12",
"com.unity.collections": "2.2.1",
"com.unity.burst": "1.8.17",
"com.unity.collections": "2.4.3",
"com.unity.modules.imageconversion": "1.0.0"
},
"url": "https://packages.unity.com"
Expand Down
2 changes: 1 addition & 1 deletion BlazeDetectionSample/Hand/Packages/manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"com.unity.ide.visualstudio": "2.0.22",
"com.unity.inputsystem": "1.9.0",
"com.unity.render-pipelines.universal": "17.0.3",
"com.unity.sentis": "2.0.0",
"com.unity.sentis": "2.1.0",
"com.unity.test-framework": "1.4.4",
"com.unity.timeline": "1.8.7",
"com.unity.ugui": "2.0.0",
Expand Down
10 changes: 5 additions & 5 deletions BlazeDetectionSample/Hand/Packages/packages-lock.json
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
"url": "https://packages.unity.com"
},
"com.unity.collections": {
"version": "2.4.1",
"depth": 2,
"version": "2.4.3",
"depth": 1,
"source": "registry",
"dependencies": {
"com.unity.burst": "1.8.13",
Expand Down Expand Up @@ -137,12 +137,12 @@
"url": "https://packages.unity.com"
},
"com.unity.sentis": {
"version": "2.0.0",
"version": "2.1.0",
"depth": 0,
"source": "registry",
"dependencies": {
"com.unity.burst": "1.8.12",
"com.unity.collections": "2.2.1",
"com.unity.burst": "1.8.17",
"com.unity.collections": "2.4.3",
"com.unity.modules.imageconversion": "1.0.0"
},
"url": "https://packages.unity.com"
Expand Down
2 changes: 1 addition & 1 deletion BlazeDetectionSample/Pose/Packages/manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"com.unity.ide.visualstudio": "2.0.22",
"com.unity.inputsystem": "1.9.0",
"com.unity.render-pipelines.universal": "17.0.3",
"com.unity.sentis": "2.0.0",
"com.unity.sentis": "2.1.0",
"com.unity.test-framework": "1.4.4",
"com.unity.timeline": "1.8.7",
"com.unity.ugui": "2.0.0",
Expand Down
10 changes: 5 additions & 5 deletions BlazeDetectionSample/Pose/Packages/packages-lock.json
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
"url": "https://packages.unity.com"
},
"com.unity.collections": {
"version": "2.4.1",
"depth": 2,
"version": "2.4.3",
"depth": 1,
"source": "registry",
"dependencies": {
"com.unity.burst": "1.8.13",
Expand Down Expand Up @@ -137,12 +137,12 @@
"url": "https://packages.unity.com"
},
"com.unity.sentis": {
"version": "2.0.0",
"version": "2.1.0",
"depth": 0,
"source": "registry",
"dependencies": {
"com.unity.burst": "1.8.12",
"com.unity.collections": "2.2.1",
"com.unity.burst": "1.8.17",
"com.unity.collections": "2.4.3",
"com.unity.modules.imageconversion": "1.0.0"
},
"url": "https://packages.unity.com"
Expand Down
17 changes: 14 additions & 3 deletions BoardGameAISample/Assets/Prefabs/Piece.prefab
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ Transform:
m_PrefabInstance: {fileID: 0}
m_PrefabAsset: {fileID: 0}
m_GameObject: {fileID: 5233670122910542190}
serializedVersion: 2
m_LocalRotation: {x: 0, y: 0, z: 0, w: 1}
m_LocalPosition: {x: 1.776, y: 0, z: 0}
m_LocalScale: {x: 0.8, y: 0.3, z: 0.8}
m_ConstrainProportionsScale: 0
m_Children: []
m_Father: {fileID: 0}
m_RootOrder: 0
m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0}
--- !u!33 &7037451497112201883
MeshFilter:
Expand All @@ -60,6 +60,9 @@ MeshRenderer:
m_ReflectionProbeUsage: 1
m_RayTracingMode: 2
m_RayTraceProcedural: 0
m_RayTracingAccelStructBuildFlagsOverride: 0
m_RayTracingAccelStructBuildFlags: 1
m_SmallMeshCulling: 1
m_RenderingLayerMask: 1
m_RendererPriority: 0
m_Materials:
Expand Down Expand Up @@ -93,10 +96,18 @@ MeshCollider:
m_PrefabAsset: {fileID: 0}
m_GameObject: {fileID: 5233670122910542190}
m_Material: {fileID: 0}
m_IncludeLayers:
serializedVersion: 2
m_Bits: 0
m_ExcludeLayers:
serializedVersion: 2
m_Bits: 0
m_LayerOverridePriority: 0
m_IsTrigger: 0
m_ProvidesContacts: 0
m_Enabled: 1
serializedVersion: 4
m_Convex: 1
serializedVersion: 5
m_Convex: 0
m_CookingOptions: 30
m_Mesh: {fileID: 10207, guid: 0000000000000000e000000000000000, type: 0}
--- !u!114 &5346768096612052479
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ void GetRenderersActive(bool includeInactive)
bounds.Encapsulate(transform.TransformPoint(new Vector3(x, -y, -z)));
bounds.Encapsulate(transform.TransformPoint(new Vector3(-x, -y, -z)));

renderers = FindObjectsOfType<Renderer>(includeInactive).Where(r => {
renderers = FindObjectsByType<Renderer>(includeInactive ? FindObjectsInactive.Include : FindObjectsInactive.Exclude, FindObjectsSortMode.None).Where(r => {
return r.bounds.Intersects(bounds) && r.gameObject.name.Contains(objectsSearchFilter) && ((shaderSearchFilter == null) ? true : r.sharedMaterials.Any(m => m.shader == shaderSearchFilter));
}).ToArray();
}
Expand Down
90 changes: 44 additions & 46 deletions BoardGameAISample/Assets/Scripts/Othello.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.Sentis;
using UnityEngine.UI;
Expand All @@ -13,8 +11,7 @@ public class Othello : MonoBehaviour
{
// Sentis
montplaisir marked this conversation as resolved.
Show resolved Hide resolved
public ModelAsset model;
IWorker m_Engine;
Ops m_Ops;
Worker m_Engine;

// Board logic
public AudioClip pieceDown, illegalMoveBuzzer;
Expand Down Expand Up @@ -57,9 +54,9 @@ enum CameraMove { TO_BOARD, TO_OPPONENTS, NONE };
const int kRED = -1; // Spirit
int m_CurrentTurn = kBLACK; // keep track of who's turn it is to play

TensorFloat m_Data = TensorFloat.Zeros(new TensorShape(1, 1, kBoardDimension, kBoardDimension));
float[] m_LegalMoves = new float[kBoardDimension * kBoardDimension + 1];
TensorFloat m_MoveProbabilities = null;
Tensor<float> m_Data;
Tensor<float> m_LegalMoves;
Tensor<float> m_MoveProbabilities = null;
GameObject[] m_Pieces = new GameObject[kBoardDimension * kBoardDimension];

int m_PassesInARow = 0;
Expand All @@ -73,9 +70,30 @@ enum CameraMove { TO_BOARD, TO_OPPONENTS, NONE };
void Start()
{
// Load in the neural network that will make the move predictions for the spirit + create inference engine
m_Engine = WorkerFactory.CreateWorker(BackendType.CPU, ModelLoader.Load(model));
// ops for tensor handling
m_Ops = new CPUOps();
var othelloModel = ModelLoader.Load(model);

var graph = new FunctionalGraph();
var inputs = graph.AddInputs(othelloModel);
var outputs = Functional.Forward(othelloModel, inputs);
var boardState = outputs[0];
var bestMove = outputs[1];

// Ensure legal moves are considered when computing best move.
var legal = graph.AddInput(DataType.Float, new TensorShape(kBoardDimension * kBoardDimension + 1));
// Convert outputs to probabilities
bestMove = Functional.Exp(bestMove * m_AIDifficultyTemperature);
// Mask out illegal moves
bestMove = (0.0001f + bestMove) * legal;
// Normalize probabilities so they sum to 1
var redSum = Functional.ReduceSum(bestMove, new int[] { 1 }, true);
bestMove /= redSum;

var bestMoveModel = graph.Compile(boardState, bestMove);

m_Engine = new Worker(bestMoveModel, BackendType.CPU);

m_Data = new Tensor<float>(new TensorShape(1, 1, kBoardDimension, kBoardDimension));
m_LegalMoves = new Tensor<float>(new TensorShape(kBoardDimension * kBoardDimension + 1));

CreateBoard();
}
Expand Down Expand Up @@ -121,7 +139,6 @@ void NextMove()
ComputerMove();
}
}

}

Vector3 GetPiecePosition(int y, int x)
Expand Down Expand Up @@ -187,10 +204,10 @@ void ResetBoard()
m_Data[(kBoardDimension / 2 - 1), (kBoardDimension / 2)] = kBLACK;
m_Data[(kBoardDimension / 2), (kBoardDimension / 2 - 1)] = kBLACK;
m_Data[(kBoardDimension / 2 - 1), (kBoardDimension / 2 - 1)] = kRED;

m_PassesInARow = 0;
m_LastWinning = 0;

SetColors(kRED);
m_CurrentTurn = kBLACK;
SetSubtitle("Let's play. You begin.");
Expand Down Expand Up @@ -218,7 +235,7 @@ int SelectRandomMove()

void FlipBoard()
{
for (int i = 0; i < m_Data.shape.length; i++)
for (int i = 0; i < m_Data.shape.length; i++)
m_Data[i] *= -1;
}

Expand All @@ -228,42 +245,24 @@ void ComputerMove()
// The network is always form the point of view that the current player = 1 and opponent = -1
FlipBoard();

m_Engine.Execute(m_Data);
m_Engine.Schedule(m_Data, m_LegalMoves);

m_Data.MakeReadable();

// predict best move:
var bestMove = m_Engine.PeekOutput("best_move") as TensorFloat;
// estimate who is winning:
var boardState = m_Engine.PeekOutput("board_state") as TensorFloat;
using var boardState = (m_Engine.PeekOutput(0) as Tensor<float>).ReadbackAndClone();
// predict best move:
m_MoveProbabilities?.Dispose();
m_MoveProbabilities = (m_Engine.PeekOutput(1) as Tensor<float>).ReadbackAndClone();

boardState.MakeReadable();
float boardValue = boardState[0, 0];
bool blackIsWinning = -m_CurrentTurn * boardValue < 0;

//convert the boardValue [-1,1] into a more human readable number:
int percent = (int)(Mathf.Pow(Mathf.Abs(boardValue), 10f) * 50 + 50);

TensorFloat legal = new TensorFloat(new TensorShape(1, kBoardDimension * kBoardDimension + 1), m_LegalMoves);

DisplayPhrases(blackIsWinning, percent);

// Convert outputs to probabilities:
bestMove = m_Ops.Exp(m_Ops.Mul(bestMove, m_AIDifficultyTemperature));
// Mask out illegal moves:
bestMove = m_Ops.Mul(m_Ops.Add(0.0001f, bestMove), legal);
// Normalize probabilities so they sum to 1
bestMove = m_Ops.Div(bestMove, m_Ops.ReduceSum(bestMove, new int[] { 1 }, true));

bestMove.MakeReadable();

m_MoveProbabilities = bestMove;
m_MoveProbabilities.TakeOwnership();

DisplayProbabilities();

legal?.Dispose();

Invoke("MakeMove", m_PauseTime);
}

Expand Down Expand Up @@ -351,7 +350,7 @@ void ClearProbabilityDisplay()
}
}

void MakeMove()
void MakeMove()
{
ClearProbabilityDisplay();
int moveIndex = SelectRandomMove();
Expand Down Expand Up @@ -413,7 +412,7 @@ int FlipColors(int y, int x, int turn, bool checkonly=false)
int enemyPieces = 0;
// check for a line of enemy pieces in direction (dx,dy):
while (Y >= 0 && X >= 0 && X < kBoardDimension && Y < kBoardDimension && m_Data[Y * kBoardDimension + X] == -turn)
{
{
X += dx; Y += dy;
enemyPieces++;
}
Expand Down Expand Up @@ -490,7 +489,7 @@ private void Update()
float mouseY = Input.GetAxis("Mouse Y") * mouseSensititvy;

cameraAngleLR = cameraAngleLR + mouseX;

cameraAngleUp = Mathf.Clamp(cameraAngleUp - mouseY, -45, 45);
Camera.main.transform.localEulerAngles = new Vector3(cameraAngleUp, cameraAngleLR, 0);
}
Expand Down Expand Up @@ -521,7 +520,7 @@ void MouseClicked()
Ray ray = Camera.main.ScreenPointToRay(Input.mousePosition);
if (!Physics.Raycast(ray, out RaycastHit hit, 1000))
return;

GameObject go = hit.collider.gameObject;

if (HasParent(go.transform, "Opponent Easy"))
Expand All @@ -536,7 +535,7 @@ void MouseClicked()
{
LevelOptionSelected(kRED, 2);
}

int index = System.Array.IndexOf(m_Pieces, go);
if (index < 0)
return;
Expand All @@ -556,16 +555,15 @@ void MouseClicked()
{
GetComponent<AudioSource>().PlayOneShot(illegalMoveBuzzer);
Debug.Log("Can't go there");
}
}
}

private void OnApplicationQuit()
montplaisir marked this conversation as resolved.
Show resolved Hide resolved
{
CancelInvoke();
m_Engine?.Dispose();
m_MoveProbabilities.Dispose();
m_MoveProbabilities?.Dispose();
m_Data.Dispose();

m_Ops.Dispose();
m_LegalMoves.Dispose();
}
}
2 changes: 1 addition & 1 deletion BoardGameAISample/Packages/manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"dependencies": {
"com.unity.ai.navigation": "2.0.0",
"com.unity.ide.visualstudio": "2.0.22",
"com.unity.sentis": "1.3.0-pre.3",
"com.unity.sentis": "2.1.0",
"com.unity.ugui": "2.0.0",
"com.unity.modules.accessibility": "1.0.0",
"com.unity.modules.ai": "1.0.0",
Expand Down
6 changes: 3 additions & 3 deletions BoardGameAISample/Packages/packages-lock.json
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,12 @@
"url": "https://packages.unity.com"
},
"com.unity.sentis": {
"version": "1.3.0-pre.3",
"version": "2.1.0",
"depth": 0,
"source": "registry",
"dependencies": {
"com.unity.burst": "1.8.10",
"com.unity.modules.jsonserialize": "1.0.0",
"com.unity.burst": "1.8.17",
"com.unity.collections": "2.4.3",
"com.unity.modules.imageconversion": "1.0.0"
},
"url": "https://packages.unity.com"
Expand Down
4 changes: 2 additions & 2 deletions BoardGameAISample/ProjectSettings/ProjectVersion.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
m_EditorVersion: 2023.2.0b17
m_EditorVersionWithRevision: 2023.2.0b17 (1d22bd928c99)
m_EditorVersion: 6000.0.23f1
m_EditorVersionWithRevision: 6000.0.23f1 (2fd0cac8cdb0)
Loading