Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 3 additions & 9 deletions packages/typegpu/src/data/vectorOps.ts
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wgsl.MatchingBoolInstance<T> can also simplify select signature

Original file line number Diff line number Diff line change
Expand Up @@ -239,9 +239,7 @@ export const VectorOps = {
<T extends wgsl.AnyVecInstance>(
e1: T,
e2: T,
) => T extends wgsl.AnyVec2Instance ? wgsl.v2b
: T extends wgsl.AnyVec3Instance ? wgsl.v3b
: wgsl.v4b
) => wgsl.MatchingBoolInstance<T>
>,

lt: {
Expand Down Expand Up @@ -272,9 +270,7 @@ export const VectorOps = {
<T extends wgsl.AnyNumericVecInstance>(
e1: T,
e2: T,
) => T extends wgsl.AnyVec2Instance ? wgsl.v2b
: T extends wgsl.AnyVec3Instance ? wgsl.v3b
: wgsl.v4b
) => wgsl.MatchingBoolInstance<T>
>,

or: {
Expand Down Expand Up @@ -1243,9 +1239,7 @@ export const VectorOps = {
<T extends wgsl.AnyVecInstance>(
f: T,
t: T,
c: T extends wgsl.AnyVec2Instance ? wgsl.v2b
: T extends wgsl.AnyVec3Instance ? wgsl.v3b
: wgsl.v4b,
c: wgsl.MatchingBoolInstance<T>,
) => T
>,

Expand Down
10 changes: 10 additions & 0 deletions packages/typegpu/src/data/wgslTypes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,16 @@ export type AnyIntegerVecInstance = v2i | v2u | v3i | v3u | v4i | v4u;

export type AnyBooleanVecInstance = v2b | v3b | v4b;

export type MatchingBoolInstance<T extends AnyVecInstance | number> = T extends
AnyVecInstance
? T['kind'] extends `vec${infer TDim extends 2 | 3 | 4}${string}` ? {
2: v2b;
3: v3b;
4: v4b;
}[TDim]
: never
: boolean;

export type AnySignedVecInstance =
| v2i
| v2f
Expand Down
150 changes: 114 additions & 36 deletions packages/typegpu/src/std/boolean.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@ import {
type AnyVecInstance,
type AnyWgslData,
isVecInstance,
type MatchingBoolInstance,
type v2b,
type v3b,
type v4b,
} from '../data/wgslTypes.ts';
import { $internal } from '../shared/symbols.ts';
import { unify } from '../tgsl/conversion.ts';
import { sub } from './operators.ts';

function correspondingBooleanVectorSchema(dataType: AnyData) {
Expand Down Expand Up @@ -48,8 +50,20 @@ export const allEq = dualImpl({
codegenImpl: (lhs, rhs) => stitch`all(${lhs} == ${rhs})`,
});

const cpuEq = <T extends AnyVecInstance>(lhs: T, rhs: T) =>
VectorOps.eq[lhs.kind](lhs, rhs);
function cpuEq(lhs: number, rhs: number): boolean;
function cpuEq<T extends AnyVecInstance | number>(
lhs: T,
rhs: T,
): MatchingBoolInstance<T>;
function cpuEq<T extends AnyVecInstance | number>(
lhs: T,
rhs: T,
): MatchingBoolInstance<T> {
if (typeof lhs !== 'number' && typeof rhs !== 'number') {
return VectorOps.eq[lhs.kind](lhs, rhs) as MatchingBoolInstance<T>;
}
return (lhs === rhs) as MatchingBoolInstance<T>;
}

/**
* Checks **component-wise** whether `lhs == rhs`.
Expand All @@ -62,10 +76,13 @@ const cpuEq = <T extends AnyVecInstance>(lhs: T, rhs: T) =>
*/
export const eq = dualImpl({
name: 'eq',
signature: (...argTypes) => ({
argTypes,
returnType: correspondingBooleanVectorSchema(argTypes[0]),
}),
signature: (...args) => {
const uargs = unify(args) ?? args;
return ({
argTypes: uargs,
returnType: correspondingBooleanVectorSchema(uargs[0]),
});
},
normalImpl: cpuEq,
codegenImpl: (lhs, rhs) => stitch`(${lhs} == ${rhs})`,
});
Expand All @@ -89,23 +106,43 @@ export const ne = dualImpl({
codegenImpl: (lhs, rhs) => stitch`(${lhs} != ${rhs})`,
});

const cpuLt = <T extends AnyNumericVecInstance>(lhs: T, rhs: T) =>
VectorOps.lt[lhs.kind](lhs, rhs);
function cpuLt(lhs: number, rhs: number): boolean;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At some point we decided that if a plain js operator works, then we do not write overloads for std operators so as not to encourage users to complicate their code.
I assume we withdraw that decision to allow for overloaded functions, right? If so, the JSDocs would need updating

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not convinced about these new overloads. The following now becomes type safe, and it will generate invalid wgsl. It would be fine if we handled mixed comparisons though. Maybe we could make unify cast numbers to vectors by wrapping them?

function cmpLt(a: number | AnyNumericVecInstance, b: number | AnyNumericVecInstance) {
  "kernel";
  return lt(a, b);
}

cmpLt(3, vec2f());

function cpuLt<T extends AnyNumericVecInstance | number>(
lhs: T,
rhs: T,
): MatchingBoolInstance<T>;
function cpuLt<T extends AnyNumericVecInstance | number>(
lhs: T,
rhs: T,
): MatchingBoolInstance<T> {
if (typeof lhs !== 'number' && typeof rhs !== 'number') {
return VectorOps.lt[lhs.kind](lhs, rhs) as MatchingBoolInstance<T>;
}
return (lhs < rhs) as MatchingBoolInstance<T>;
}

/**
* Checks **component-wise** whether `lhs < rhs`.
* This function does **not** return `bool`, for that use-case, wrap the result in `all`.
* @example
* @example ```ts
* lt(vec2f(0.0, 0.0), vec2f(0.0, 1.0)) // returns vec2b(false, true)
* lt(vec3u(0, 1, 2), vec3u(2, 1, 0)) // returns vec3b(true, false, false)
* all(lt(vec4i(1, 2, 3, 4), vec4i(2, 3, 4, 5))) // returns true
* ```
*
* Also accepts scalar values for the sake of generic use, but it's better to use the `<` operator.
* @example ```ts
* lt(1, 2) // returns true
* ```
*/
export const lt = dualImpl({
name: 'lt',
signature: (...argTypes) => ({
argTypes,
returnType: correspondingBooleanVectorSchema(argTypes[0]),
}),
signature: (...args) => {
const uargs = unify(args) ?? args;
return ({
argTypes: uargs,
returnType: correspondingBooleanVectorSchema(uargs[0]),
});
},
normalImpl: cpuLt,
codegenImpl: (lhs, rhs) => stitch`(${lhs} < ${rhs})`,
});
Expand All @@ -120,15 +157,30 @@ export const lt = dualImpl({
*/
export const le = dualImpl({
name: 'le',
signature: (...argTypes) => ({
argTypes,
returnType: correspondingBooleanVectorSchema(argTypes[0]),
}),
signature: (...args) => {
const uargs = unify(args) ?? args;
return ({
argTypes: uargs,
returnType: correspondingBooleanVectorSchema(uargs[0]),
});
},
normalImpl: <T extends AnyNumericVecInstance>(lhs: T, rhs: T) =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function's interface is now inconsistent with other boolean functions

cpuOr(cpuLt(lhs, rhs), cpuEq(lhs, rhs)),
codegenImpl: (lhs, rhs) => stitch`(${lhs} <= ${rhs})`,
});

function cpuGt(lhs: number, rhs: number): boolean;
function cpuGt<T extends AnyNumericVecInstance | number>(
lhs: T,
rhs: T,
): MatchingBoolInstance<T>;
function cpuGt<T extends AnyNumericVecInstance | number>(
lhs: T,
rhs: T,
): MatchingBoolInstance<T> {
return cpuAnd(cpuNot(cpuLt(lhs, rhs)), cpuNot(cpuEq(lhs, rhs)));
}

/**
* Checks **component-wise** whether `lhs > rhs`.
* This function does **not** return `bool`, for that use-case, wrap the result in `all`.
Expand All @@ -139,12 +191,14 @@ export const le = dualImpl({
*/
export const gt = dualImpl({
name: 'gt',
signature: (...argTypes) => ({
argTypes,
returnType: correspondingBooleanVectorSchema(argTypes[0]),
}),
normalImpl: <T extends AnyNumericVecInstance>(lhs: T, rhs: T) =>
cpuAnd(cpuNot(cpuLt(lhs, rhs)), cpuNot(cpuEq(lhs, rhs))),
signature: (...args) => {
const uargs = unify(args) ?? args;
return ({
argTypes: uargs,
returnType: correspondingBooleanVectorSchema(uargs[0]),
});
},
normalImpl: cpuGt,
codegenImpl: (lhs, rhs) => stitch`(${lhs} > ${rhs})`,
});

Expand All @@ -158,19 +212,28 @@ export const gt = dualImpl({
*/
export const ge = dualImpl({
name: 'ge',
signature: (...argTypes) => ({
argTypes: argTypes,
returnType: correspondingBooleanVectorSchema(argTypes[0]),
}),
signature: (...args) => {
const uargs = unify(args) ?? args;
return ({
argTypes: uargs,
returnType: correspondingBooleanVectorSchema(uargs[0]),
});
},
normalImpl: <T extends AnyNumericVecInstance>(lhs: T, rhs: T) =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here as well

cpuNot(cpuLt(lhs, rhs)),
codegenImpl: (lhs, rhs) => stitch`(${lhs} >= ${rhs})`,
});

// logical ops

const cpuNot = <T extends AnyBooleanVecInstance>(value: T): T =>
VectorOps.neg[value.kind](value);
function cpuNot<T extends boolean>(value: boolean): boolean;
function cpuNot<T extends AnyBooleanVecInstance | boolean>(value: T): T;
function cpuNot<T extends AnyBooleanVecInstance | boolean>(value: T): T {
if (typeof value === 'boolean') {
return (!value) as T;
}
return VectorOps.neg[value.kind](value) as T;
}

/**
* Returns **component-wise** `!value`.
Expand All @@ -185,8 +248,14 @@ export const not = dualImpl({
codegenImpl: (arg) => stitch`!(${arg})`,
});

const cpuOr = <T extends AnyBooleanVecInstance>(lhs: T, rhs: T) =>
VectorOps.or[lhs.kind](lhs, rhs);
function cpuOr(lhs: boolean, rhs: boolean): boolean;
function cpuOr<T extends AnyBooleanVecInstance | boolean>(lhs: T, rhs: T): T;
function cpuOr<T extends AnyBooleanVecInstance | boolean>(lhs: T, rhs: T): T {
if (typeof lhs !== 'boolean' && typeof rhs !== 'boolean') {
return VectorOps.or[lhs.kind](lhs, rhs) as T;
}
return lhs || rhs;
}

/**
* Returns **component-wise** logical `or` result.
Expand All @@ -196,13 +265,19 @@ const cpuOr = <T extends AnyBooleanVecInstance>(lhs: T, rhs: T) =>
*/
export const or = dualImpl({
name: 'or',
signature: (...argTypes) => ({ argTypes, returnType: argTypes[0] }),
signature: (...args) => {
const uargs = unify(args) ?? args;
return ({ argTypes: uargs, returnType: uargs[0] });
},
normalImpl: cpuOr,
codegenImpl: (lhs, rhs) => stitch`(${lhs} | ${rhs})`,
});

const cpuAnd = <T extends AnyBooleanVecInstance>(lhs: T, rhs: T) =>
cpuNot(cpuOr(cpuNot(lhs), cpuNot(rhs)));
function cpuAnd(lhs: boolean, rhs: boolean): boolean;
function cpuAnd<T extends AnyBooleanVecInstance | boolean>(lhs: T, rhs: T): T;
function cpuAnd<T extends AnyBooleanVecInstance | boolean>(lhs: T, rhs: T): T {
return cpuNot(cpuOr(cpuNot(lhs), cpuNot(rhs)));
}

/**
* Returns **component-wise** logical `and` result.
Expand All @@ -212,7 +287,10 @@ const cpuAnd = <T extends AnyBooleanVecInstance>(lhs: T, rhs: T) =>
*/
export const and = dualImpl({
name: 'and',
signature: (...argTypes) => ({ argTypes, returnType: argTypes[0] }),
signature: (...args) => {
const uargs = unify(args) ?? args;
return ({ argTypes: uargs, returnType: uargs[0] });
},
normalImpl: cpuAnd,
codegenImpl: (lhs, rhs) => stitch`(${lhs} & ${rhs})`,
});
Expand Down
Loading