From aaa7712dd80cd07dc9e5407868067ac3e79042b1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Mateusz=20Burzy=C5=84ski?= <mateuszburzynski@gmail.com>
Date: Thu, 7 Mar 2024 11:37:51 +0100
Subject: [PATCH] Implement guard insertions

---
 .../__tests__/source-edits/add-guard.test.ts  | 378 ++++++++++++++++++
 new-packages/ts-project/__tests__/utils.ts    |  20 +-
 new-packages/ts-project/src/index.ts          |  90 +++--
 new-packages/ts-project/src/utils.ts          |  30 +-
 4 files changed, 478 insertions(+), 40 deletions(-)
 create mode 100644 new-packages/ts-project/__tests__/source-edits/add-guard.test.ts

diff --git a/new-packages/ts-project/__tests__/source-edits/add-guard.test.ts b/new-packages/ts-project/__tests__/source-edits/add-guard.test.ts
new file mode 100644
index 00000000..83b91b89
--- /dev/null
+++ b/new-packages/ts-project/__tests__/source-edits/add-guard.test.ts
@@ -0,0 +1,378 @@
+import { expect, test } from 'vitest';
+import { createTestProject, testdir, ts } from '../utils';
+
+test('should be possible to add a guard to a transition', async () => {
+  const tmpPath = await testdir({
+    'tsconfig.json': JSON.stringify({}),
+    'index.ts': ts`
+      import { createMachine } from "xstate";
+
+      createMachine({
+        on: {
+          FOO: "a",
+        },
+        states: {
+          a: {},
+        },
+      });
+    `,
+  });
+
+  const project = await createTestProject(tmpPath);
+
+  const textEdits = project.editDigraph(
+    {
+      fileName: 'index.ts',
+      machineIndex: 0,
+    },
+    {
+      type: 'add_guard',
+      path: [],
+      transitionPath: ['on', 'FOO', 0],
+      name: 'isItTooLate',
+    },
+  );
+  expect(await project.applyTextEdits(textEdits)).toMatchInlineSnapshot(`
+    {
+      "index.ts": "import { createMachine } from "xstate";
+
+    createMachine({
+      on: {
+        FOO: {
+          target: "a",
+          guard: "isItTooLate"
+        },
+      },
+      states: {
+        a: {},
+      },
+    });",
+    }
+  `);
+});
+
+test('should be possible to add a guard to an object transition', async () => {
+  const tmpPath = await testdir({
+    'tsconfig.json': JSON.stringify({}),
+    'index.ts': ts`
+      import { createMachine } from "xstate";
+
+      createMachine({
+        on: {
+          FOO: {
+            target: "a",
+          },
+        },
+        states: {
+          a: {},
+        },
+      });
+    `,
+  });
+
+  const project = await createTestProject(tmpPath);
+
+  const textEdits = project.editDigraph(
+    {
+      fileName: 'index.ts',
+      machineIndex: 0,
+    },
+    {
+      type: 'add_guard',
+      path: [],
+      transitionPath: ['on', 'FOO', 0],
+      name: 'isItTooLate',
+    },
+  );
+  expect(await project.applyTextEdits(textEdits)).toMatchInlineSnapshot(`
+    {
+      "index.ts": "import { createMachine } from "xstate";
+
+    createMachine({
+      on: {
+        FOO: {
+          target: "a",
+          guard: "isItTooLate",
+        },
+      },
+      states: {
+        a: {},
+      },
+    });",
+    }
+  `);
+});
+
+test('should be possible to add a guard to the first transition for a given event', async () => {
+  const tmpPath = await testdir({
+    'tsconfig.json': JSON.stringify({}),
+    'index.ts': ts`
+      import { createMachine } from "xstate";
+
+      createMachine({
+        on: {
+          FOO: ["a", "b", "c"],
+        },
+        states: {
+          a: {},
+          b: {},
+          c: {},
+        },
+      });
+    `,
+  });
+
+  const project = await createTestProject(tmpPath);
+
+  const textEdits = project.editDigraph(
+    {
+      fileName: 'index.ts',
+      machineIndex: 0,
+    },
+    {
+      type: 'add_guard',
+      path: [],
+      transitionPath: ['on', 'FOO', 0],
+      name: 'isItTooLate',
+    },
+  );
+  expect(await project.applyTextEdits(textEdits)).toMatchInlineSnapshot(`
+    {
+      "index.ts": "import { createMachine } from "xstate";
+
+    createMachine({
+      on: {
+        FOO: [{
+          target: "a",
+          guard: "isItTooLate"
+        }, "b", "c"],
+      },
+      states: {
+        a: {},
+        b: {},
+        c: {},
+      },
+    });",
+    }
+  `);
+});
+
+test('should be possible to add a guard to the last transition for a given event', async () => {
+  const tmpPath = await testdir({
+    'tsconfig.json': JSON.stringify({}),
+    'index.ts': ts`
+      import { createMachine } from "xstate";
+
+      createMachine({
+        on: {
+          FOO: ["a", "b", "c"],
+        },
+        states: {
+          a: {},
+          b: {},
+          c: {},
+        },
+      });
+    `,
+  });
+
+  const project = await createTestProject(tmpPath);
+
+  const textEdits = project.editDigraph(
+    {
+      fileName: 'index.ts',
+      machineIndex: 0,
+    },
+    {
+      type: 'add_guard',
+      path: [],
+      transitionPath: ['on', 'FOO', 2],
+      name: 'isItTooLate',
+    },
+  );
+  expect(await project.applyTextEdits(textEdits)).toMatchInlineSnapshot(`
+    {
+      "index.ts": "import { createMachine } from "xstate";
+
+    createMachine({
+      on: {
+        FOO: ["a", "b", {
+          target: "c",
+          guard: "isItTooLate"
+        }],
+      },
+      states: {
+        a: {},
+        b: {},
+        c: {},
+      },
+    });",
+    }
+  `);
+});
+
+test('should be possible to add a guard to a middle transition for a given event', async () => {
+  const tmpPath = await testdir({
+    'tsconfig.json': JSON.stringify({}),
+    'index.ts': ts`
+      import { createMachine } from "xstate";
+
+      createMachine({
+        on: {
+          FOO: ["a", "b", "c"],
+        },
+        states: {
+          a: {},
+          b: {},
+          c: {},
+        },
+      });
+    `,
+  });
+
+  const project = await createTestProject(tmpPath);
+
+  const textEdits = project.editDigraph(
+    {
+      fileName: 'index.ts',
+      machineIndex: 0,
+    },
+    {
+      type: 'add_guard',
+      path: [],
+      transitionPath: ['on', 'FOO', 1],
+      name: 'isItTooLate',
+    },
+  );
+  expect(await project.applyTextEdits(textEdits)).toMatchInlineSnapshot(`
+    {
+      "index.ts": "import { createMachine } from "xstate";
+
+    createMachine({
+      on: {
+        FOO: ["a", {
+          target: "b",
+          guard: "isItTooLate"
+        }, "c"],
+      },
+      states: {
+        a: {},
+        b: {},
+        c: {},
+      },
+    });",
+    }
+  `);
+});
+
+test(`should be possible to add a guard to invoke's onDone`, async () => {
+  const tmpPath = await testdir({
+    'tsconfig.json': JSON.stringify({}),
+    'index.ts': ts`
+      import { createMachine } from "xstate";
+
+      createMachine({
+        states: {
+          a: {
+            invoke: {
+              src: "callDavid",
+              onDone: "b",
+            },
+          },
+          b: {},
+        },
+      });
+    `,
+  });
+
+  const project = await createTestProject(tmpPath);
+
+  const textEdits = project.editDigraph(
+    {
+      fileName: 'index.ts',
+      machineIndex: 0,
+    },
+    {
+      type: 'add_guard',
+      path: ['a'],
+      transitionPath: ['invoke', 0, 'onDone', 0],
+      name: 'isHeBusy',
+    },
+  );
+  expect(await project.applyTextEdits(textEdits)).toMatchInlineSnapshot(`
+    {
+      "index.ts": "import { createMachine } from "xstate";
+
+    createMachine({
+      states: {
+        a: {
+          invoke: {
+            src: "callDavid",
+            onDone: {
+              target: "b",
+              guard: "isHeBusy"
+            },
+          },
+        },
+        b: {},
+      },
+    });",
+    }
+  `);
+});
+
+test(`should be possible to add a guard to a transition defined for an empty event`, async () => {
+  const tmpPath = await testdir({
+    'tsconfig.json': JSON.stringify({}),
+    'index.ts': ts`
+      import { createMachine } from "xstate";
+
+      createMachine({
+        initial: "a",
+        states: {
+          a: {
+            on: {
+              "": "b",
+            },
+          },
+          b: {},
+        },
+      });
+    `,
+  });
+
+  const project = await createTestProject(tmpPath);
+
+  const textEdits = project.editDigraph(
+    {
+      fileName: 'index.ts',
+      machineIndex: 0,
+    },
+    {
+      type: 'add_guard',
+      path: ['a'],
+      transitionPath: ['on', '', 0],
+      name: 'isItHalfEmpty',
+    },
+  );
+  expect(await project.applyTextEdits(textEdits)).toMatchInlineSnapshot(`
+    {
+      "index.ts": "import { createMachine } from "xstate";
+
+    createMachine({
+      initial: "a",
+      states: {
+        a: {
+          on: {
+            "": {
+              target: "b",
+              guard: "isItHalfEmpty"
+            },
+          },
+        },
+        b: {},
+      },
+    });",
+    }
+  `);
+});
diff --git a/new-packages/ts-project/__tests__/utils.ts b/new-packages/ts-project/__tests__/utils.ts
index d16de6ad..bfae0abb 100644
--- a/new-packages/ts-project/__tests__/utils.ts
+++ b/new-packages/ts-project/__tests__/utils.ts
@@ -499,7 +499,25 @@ function produceNewDigraphUsingEdit(
     }
     case 'remove_action':
     case 'edit_action':
-    case 'add_guard':
+      throw new Error(`Not implemented: ${edit.type}`);
+    case 'add_guard': {
+      const eventTypeData = getEventTypeData(digraphDraft, {
+        sourcePath: edit.path,
+        transitionPath: edit.transitionPath,
+      });
+      const edge =
+        digraphDraft.edges[
+          getEdgeGroup(digraphDraft, eventTypeData)[
+            last(edit.transitionPath) as number
+          ]
+        ];
+      const block = createGuardBlock({
+        sourceId: edit.name,
+        parentId: edge.uniqueId,
+      });
+      registerGuardBlock(digraphDraft, block, edge);
+      break;
+    }
     case 'remove_guard':
     case 'edit_guard':
     case 'add_invoke':
diff --git a/new-packages/ts-project/src/index.ts b/new-packages/ts-project/src/index.ts
index 924911c8..d2a648da 100644
--- a/new-packages/ts-project/src/index.ts
+++ b/new-packages/ts-project/src/index.ts
@@ -528,15 +528,15 @@ function createProjectMachine({
                 // this might become a problem, especially when dealing with copy-pasting
                 // the implementation will have to account for that in the future
                 const newNode: Node = patch.value;
-                const parentNode = findNodeByAstPath(
+                const parentStateNode = findNodeByAstPath(
                   host.ts,
                   createMachineCall,
                   currentState.astPaths.nodes[newNode.parentId!],
                 );
-                assert(host.ts.isObjectLiteralExpression(parentNode));
+                assert(host.ts.isObjectLiteralExpression(parentStateNode));
 
                 codeChanges.insertAtOptionalObjectPath(
-                  parentNode,
+                  parentStateNode,
                   ['states', newNode.data.key],
                   c.object([]),
                   InsertionPriority.States,
@@ -549,6 +549,7 @@ function createProjectMachine({
                 if (patch.path.length > 2) {
                   if (patch.path[2] === 'data' && patch.path[3] === 'actions') {
                     deferredArrayPatches.push(patch);
+                    break;
                   }
                   break;
                 }
@@ -583,23 +584,24 @@ function createProjectMachine({
                   break;
                 }
                 const nodeId = patch.path[1];
-                const node = findNodeByAstPath(
+                const stateNode = findNodeByAstPath(
                   host.ts,
                   createMachineCall,
                   currentState.astPaths.nodes[nodeId],
                 );
-                assert(host.ts.isObjectLiteralExpression(node));
+                assert(host.ts.isObjectLiteralExpression(stateNode));
 
                 if (patch.path[2] === 'data' && patch.path[3] === 'key') {
-                  const parentNode = findNodeByAstPath(
+                  const parenStatetNode = findNodeByAstPath(
                     host.ts,
                     createMachineCall,
                     currentState.astPaths.nodes[nodeId].slice(0, -1),
                   );
-                  assert(host.ts.isObjectLiteralExpression(parentNode));
-                  const prop = parentNode.properties.find(
+                  assert(host.ts.isObjectLiteralExpression(parenStatetNode));
+                  const prop = parenStatetNode.properties.find(
                     (p): p is PropertyAssignment =>
-                      host.ts.isPropertyAssignment(p) && p.initializer === node,
+                      host.ts.isPropertyAssignment(p) &&
+                      p.initializer === stateNode,
                   )!;
                   codeChanges.replacePropertyName(prop, patch.value);
                   break;
@@ -608,7 +610,7 @@ function createProjectMachine({
                   const initialProp = findProperty(
                     undefined,
                     host.ts,
-                    node,
+                    stateNode,
                     'initial',
                   );
                   if (patch.value === undefined) {
@@ -633,7 +635,7 @@ function createProjectMachine({
                   const statesProp = findProperty(
                     undefined,
                     host.ts,
-                    node,
+                    stateNode,
                     'states',
                   );
 
@@ -647,7 +649,7 @@ function createProjectMachine({
                   }
 
                   codeChanges.insertPropertyIntoObject(
-                    node,
+                    stateNode,
                     'initial',
                     c.string(patch.value),
                     InsertionPriority.Initial,
@@ -657,7 +659,7 @@ function createProjectMachine({
                   const typeProp = findProperty(
                     undefined,
                     host.ts,
-                    node,
+                    stateNode,
                     'type',
                   );
                   if (patch.value === 'normal') {
@@ -679,7 +681,7 @@ function createProjectMachine({
                   }
 
                   codeChanges.insertPropertyIntoObject(
-                    node,
+                    stateNode,
                     'type',
                     c.string(patch.value),
                     InsertionPriority.StateType,
@@ -689,7 +691,7 @@ function createProjectMachine({
                   const historyProp = findProperty(
                     undefined,
                     host.ts,
-                    node,
+                    stateNode,
                     'history',
                   );
                   if (patch.value === undefined || patch.value === 'shallow') {
@@ -712,7 +714,7 @@ function createProjectMachine({
 
                   // TODO: insert it after the existing `type` property
                   codeChanges.insertPropertyIntoObject(
-                    node,
+                    stateNode,
                     'history',
                     c.string(patch.value),
                     InsertionPriority.History,
@@ -725,7 +727,7 @@ function createProjectMachine({
                   const descriptionProp = findProperty(
                     undefined,
                     host.ts,
-                    node,
+                    stateNode,
                     'description',
                   );
                   if (!patch.value) {
@@ -751,16 +753,42 @@ function createProjectMachine({
                   }
 
                   codeChanges.insertPropertyIntoObject(
-                    node,
+                    stateNode,
                     'description',
                     element,
                   );
                 }
+                break;
               case 'edges': {
                 if (patch.path[2] === 'data' && patch.path[3] === 'actions') {
                   deferredArrayPatches.push(patch);
                   break;
                 }
+                const edge = currentState.digraph!.edges[patch.path[1]];
+                const transitionNode = findNodeByAstPath(
+                  host.ts,
+                  createMachineCall,
+                  currentState.astPaths.edges[edge.uniqueId],
+                );
+                if (patch.path[2] === 'data' && patch.path[3] === 'guard') {
+                  const guardElement = c.string(
+                    currentState.digraph!.blocks[edge.data.guard!].sourceId,
+                  );
+                  if (!host.ts.isObjectLiteralExpression(transitionNode)) {
+                    codeChanges.wrapIntoObject(transitionNode, {
+                      reuseAs: 'target',
+                      newProperties: [c.property('guard', guardElement)],
+                    });
+                    break;
+                  }
+                  codeChanges.insertPropertyIntoObject(
+                    transitionNode,
+                    'guard',
+                    guardElement,
+                  );
+                  break;
+                }
+                break;
               }
             }
             break;
@@ -782,12 +810,12 @@ function createProjectMachine({
             switch (patch.path[0]) {
               case 'nodes': {
                 const nodeId = patch.path[1];
-                const node = findNodeByAstPath(
+                const stateNode = findNodeByAstPath(
                   host.ts,
                   createMachineCall,
                   currentState.astPaths.nodes[nodeId],
                 );
-                assert(host.ts.isObjectLiteralExpression(node));
+                assert(host.ts.isObjectLiteralExpression(stateNode));
 
                 if (
                   patch.path[2] === 'data' &&
@@ -808,7 +836,7 @@ function createProjectMachine({
                   assert(typeof actionId === 'string');
 
                   codeChanges.insertAtOptionalObjectPath(
-                    node,
+                    stateNode,
                     [patch.path[3], index],
                     c.string(currentState.digraph!.blocks[actionId].sourceId),
                   );
@@ -817,7 +845,7 @@ function createProjectMachine({
               }
               case 'edges': {
                 const edgeId = patch.path[1];
-                const edge = findNodeByAstPath(
+                const transitionNode = findNodeByAstPath(
                   host.ts,
                   createMachineCall,
                   currentState.astPaths.edges[edgeId],
@@ -836,10 +864,10 @@ function createProjectMachine({
                   const actionId = patch.value;
                   assert(typeof actionId === 'string');
 
-                  if (!host.ts.isObjectLiteralExpression(edge)) {
+                  if (!host.ts.isObjectLiteralExpression(transitionNode)) {
                     assert(index === 0);
 
-                    codeChanges.wrapIntoObject(edge, {
+                    codeChanges.wrapIntoObject(transitionNode, {
                       reuseAs: 'target',
                       newProperties: [
                         c.property(
@@ -853,7 +881,7 @@ function createProjectMachine({
                     break;
                   }
                   codeChanges.insertAtOptionalObjectPath(
-                    edge,
+                    transitionNode,
                     [patch.path[3], index],
                     c.string(currentState.digraph!.blocks[actionId].sourceId),
                   );
@@ -868,12 +896,12 @@ function createProjectMachine({
             switch (patch.path[0]) {
               case 'nodes': {
                 const nodeId = patch.path[1];
-                const node = findNodeByAstPath(
+                const stateNode = findNodeByAstPath(
                   host.ts,
                   createMachineCall,
                   currentState.astPaths.nodes[nodeId],
                 );
-                assert(host.ts.isObjectLiteralExpression(node));
+                assert(host.ts.isObjectLiteralExpression(stateNode));
                 if (
                   patch.path[2] === 'data' &&
                   (patch.path[3] === 'entry' || patch.path[3] === 'exit')
@@ -888,7 +916,7 @@ function createProjectMachine({
                     i += insertion.skipped;
 
                     codeChanges.insertAtOptionalObjectPath(
-                      node,
+                      stateNode,
                       [patch.path[3], insertion.index],
                       c.string(
                         currentState.digraph!.blocks[insertion.value].sourceId,
@@ -902,14 +930,14 @@ function createProjectMachine({
               }
               case 'edges': {
                 const edgeId = patch.path[1];
-                const edge = findNodeByAstPath(
+                const transitionNode = findNodeByAstPath(
                   host.ts,
                   createMachineCall,
                   currentState.astPaths.edges[edgeId],
                 );
                 if (patch.path[2] === 'data' && patch.path[3] === 'actions') {
                   // this should always be true - if we are replacing an action within an edge then the edge already has to be an object literal
-                  assert(host.ts.isObjectLiteralExpression(edge));
+                  assert(host.ts.isObjectLiteralExpression(transitionNode));
                   const insertion = consumeArrayInsertionAtIndex(
                     sortedArrayPatches,
                     i,
@@ -920,7 +948,7 @@ function createProjectMachine({
                     i += insertion.skipped;
 
                     codeChanges.insertAtOptionalObjectPath(
-                      edge,
+                      transitionNode,
                       [patch.path[3], insertion.index],
                       c.string(
                         currentState.digraph!.blocks[insertion.value].sourceId,
diff --git a/new-packages/ts-project/src/utils.ts b/new-packages/ts-project/src/utils.ts
index dc0c2519..9fad2f7d 100644
--- a/new-packages/ts-project/src/utils.ts
+++ b/new-packages/ts-project/src/utils.ts
@@ -239,7 +239,7 @@ export function forEachStaticProperty(
     }
     const key = getPropertyKey(ctx, ts, prop);
 
-    if (!key) {
+    if (typeof key !== 'string') {
       // error should already be reported by `getPropertyKey`
       continue;
     }
@@ -259,16 +259,30 @@ export function findNodeByAstPath(
   call: CallExpression,
   path: AstPath,
 ): Expression {
-  let current: Expression | undefined = call.arguments[0];
+  if (!call.arguments[0]) {
+    throw new Error('Invalid node');
+  }
+
+  let current = call.arguments[0];
+
   for (const segment of path) {
-    if (!current || !ts.isObjectLiteralExpression(current)) {
-      throw new Error('Invalid node');
+    if (ts.isObjectLiteralExpression(current)) {
+      const retrieved = current.properties[segment];
+      if (!retrieved || !ts.isPropertyAssignment(retrieved)) {
+        throw new Error('Invalid node');
+      }
+      current = retrieved.initializer;
+      continue;
     }
-    const retrieved: ObjectLiteralElementLike = current.properties[segment];
-    if (!retrieved || !ts.isPropertyAssignment(retrieved)) {
-      throw new Error('Invalid node');
+    if (ts.isArrayLiteralExpression(current)) {
+      const retrieved = current.elements[segment];
+      if (!retrieved) {
+        throw new Error('Invalid node');
+      }
+      current = retrieved;
+      continue;
     }
-    current = retrieved.initializer;
+    throw new Error('Invalid node');
   }
   return current;
 }