-
-
Notifications
You must be signed in to change notification settings - Fork 24
feat: Make std functions more generic #1723
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
99065b5
d8240cb
b9d304a
328e46a
0f88c18
754965b
3786826
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,11 +14,13 @@ import { | |
type AnyVecInstance, | ||
type AnyWgslData, | ||
isVecInstance, | ||
type MatchingBoolInstance, | ||
type v2b, | ||
type v3b, | ||
type v4b, | ||
} from '../data/wgslTypes.ts'; | ||
import { $internal } from '../shared/symbols.ts'; | ||
import { unify } from '../tgsl/conversion.ts'; | ||
import { sub } from './operators.ts'; | ||
|
||
function correspondingBooleanVectorSchema(dataType: AnyData) { | ||
|
@@ -48,8 +50,20 @@ export const allEq = dualImpl({ | |
codegenImpl: (lhs, rhs) => stitch`all(${lhs} == ${rhs})`, | ||
}); | ||
|
||
const cpuEq = <T extends AnyVecInstance>(lhs: T, rhs: T) => | ||
VectorOps.eq[lhs.kind](lhs, rhs); | ||
function cpuEq(lhs: number, rhs: number): boolean; | ||
function cpuEq<T extends AnyVecInstance | number>( | ||
lhs: T, | ||
rhs: T, | ||
): MatchingBoolInstance<T>; | ||
function cpuEq<T extends AnyVecInstance | number>( | ||
lhs: T, | ||
rhs: T, | ||
): MatchingBoolInstance<T> { | ||
if (typeof lhs !== 'number' && typeof rhs !== 'number') { | ||
return VectorOps.eq[lhs.kind](lhs, rhs) as MatchingBoolInstance<T>; | ||
} | ||
return (lhs === rhs) as MatchingBoolInstance<T>; | ||
} | ||
|
||
/** | ||
* Checks **component-wise** whether `lhs == rhs`. | ||
|
@@ -62,10 +76,13 @@ const cpuEq = <T extends AnyVecInstance>(lhs: T, rhs: T) => | |
*/ | ||
export const eq = dualImpl({ | ||
name: 'eq', | ||
signature: (...argTypes) => ({ | ||
argTypes, | ||
returnType: correspondingBooleanVectorSchema(argTypes[0]), | ||
}), | ||
signature: (...args) => { | ||
const uargs = unify(args) ?? args; | ||
return ({ | ||
argTypes: uargs, | ||
returnType: correspondingBooleanVectorSchema(uargs[0]), | ||
}); | ||
}, | ||
normalImpl: cpuEq, | ||
codegenImpl: (lhs, rhs) => stitch`(${lhs} == ${rhs})`, | ||
}); | ||
|
@@ -89,23 +106,43 @@ export const ne = dualImpl({ | |
codegenImpl: (lhs, rhs) => stitch`(${lhs} != ${rhs})`, | ||
}); | ||
|
||
const cpuLt = <T extends AnyNumericVecInstance>(lhs: T, rhs: T) => | ||
VectorOps.lt[lhs.kind](lhs, rhs); | ||
function cpuLt(lhs: number, rhs: number): boolean; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. At some point we decided that if a plain js operator works, then we do not write overloads for std operators so as not to encourage users to complicate their code. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am not convinced about these new overloads. The following now becomes type safe, and it will generate invalid wgsl. It would be fine if we handled mixed comparisons though. Maybe we could make function cmpLt(a: number | AnyNumericVecInstance, b: number | AnyNumericVecInstance) {
"kernel";
return lt(a, b);
}
cmpLt(3, vec2f()); |
||
function cpuLt<T extends AnyNumericVecInstance | number>( | ||
lhs: T, | ||
rhs: T, | ||
): MatchingBoolInstance<T>; | ||
function cpuLt<T extends AnyNumericVecInstance | number>( | ||
lhs: T, | ||
rhs: T, | ||
): MatchingBoolInstance<T> { | ||
if (typeof lhs !== 'number' && typeof rhs !== 'number') { | ||
return VectorOps.lt[lhs.kind](lhs, rhs) as MatchingBoolInstance<T>; | ||
} | ||
return (lhs < rhs) as MatchingBoolInstance<T>; | ||
} | ||
|
||
/** | ||
* Checks **component-wise** whether `lhs < rhs`. | ||
* This function does **not** return `bool`, for that use-case, wrap the result in `all`. | ||
* @example | ||
* @example ```ts | ||
* lt(vec2f(0.0, 0.0), vec2f(0.0, 1.0)) // returns vec2b(false, true) | ||
* lt(vec3u(0, 1, 2), vec3u(2, 1, 0)) // returns vec3b(true, false, false) | ||
* all(lt(vec4i(1, 2, 3, 4), vec4i(2, 3, 4, 5))) // returns true | ||
* ``` | ||
* | ||
* Also accepts scalar values for the sake of generic use, but it's better to use the `<` operator. | ||
* @example ```ts | ||
* lt(1, 2) // returns true | ||
* ``` | ||
*/ | ||
export const lt = dualImpl({ | ||
name: 'lt', | ||
signature: (...argTypes) => ({ | ||
argTypes, | ||
returnType: correspondingBooleanVectorSchema(argTypes[0]), | ||
}), | ||
signature: (...args) => { | ||
const uargs = unify(args) ?? args; | ||
return ({ | ||
argTypes: uargs, | ||
returnType: correspondingBooleanVectorSchema(uargs[0]), | ||
}); | ||
}, | ||
normalImpl: cpuLt, | ||
codegenImpl: (lhs, rhs) => stitch`(${lhs} < ${rhs})`, | ||
}); | ||
|
@@ -120,15 +157,30 @@ export const lt = dualImpl({ | |
*/ | ||
export const le = dualImpl({ | ||
name: 'le', | ||
signature: (...argTypes) => ({ | ||
argTypes, | ||
returnType: correspondingBooleanVectorSchema(argTypes[0]), | ||
}), | ||
signature: (...args) => { | ||
const uargs = unify(args) ?? args; | ||
return ({ | ||
argTypes: uargs, | ||
returnType: correspondingBooleanVectorSchema(uargs[0]), | ||
}); | ||
}, | ||
normalImpl: <T extends AnyNumericVecInstance>(lhs: T, rhs: T) => | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function's interface is now inconsistent with other boolean functions |
||
cpuOr(cpuLt(lhs, rhs), cpuEq(lhs, rhs)), | ||
codegenImpl: (lhs, rhs) => stitch`(${lhs} <= ${rhs})`, | ||
}); | ||
|
||
function cpuGt(lhs: number, rhs: number): boolean; | ||
function cpuGt<T extends AnyNumericVecInstance | number>( | ||
lhs: T, | ||
rhs: T, | ||
): MatchingBoolInstance<T>; | ||
function cpuGt<T extends AnyNumericVecInstance | number>( | ||
lhs: T, | ||
rhs: T, | ||
): MatchingBoolInstance<T> { | ||
return cpuAnd(cpuNot(cpuLt(lhs, rhs)), cpuNot(cpuEq(lhs, rhs))); | ||
} | ||
|
||
/** | ||
* Checks **component-wise** whether `lhs > rhs`. | ||
* This function does **not** return `bool`, for that use-case, wrap the result in `all`. | ||
|
@@ -139,12 +191,14 @@ export const le = dualImpl({ | |
*/ | ||
export const gt = dualImpl({ | ||
name: 'gt', | ||
signature: (...argTypes) => ({ | ||
argTypes, | ||
returnType: correspondingBooleanVectorSchema(argTypes[0]), | ||
}), | ||
normalImpl: <T extends AnyNumericVecInstance>(lhs: T, rhs: T) => | ||
cpuAnd(cpuNot(cpuLt(lhs, rhs)), cpuNot(cpuEq(lhs, rhs))), | ||
signature: (...args) => { | ||
const uargs = unify(args) ?? args; | ||
return ({ | ||
argTypes: uargs, | ||
returnType: correspondingBooleanVectorSchema(uargs[0]), | ||
}); | ||
}, | ||
normalImpl: cpuGt, | ||
codegenImpl: (lhs, rhs) => stitch`(${lhs} > ${rhs})`, | ||
}); | ||
|
||
|
@@ -158,19 +212,28 @@ export const gt = dualImpl({ | |
*/ | ||
export const ge = dualImpl({ | ||
name: 'ge', | ||
signature: (...argTypes) => ({ | ||
argTypes: argTypes, | ||
returnType: correspondingBooleanVectorSchema(argTypes[0]), | ||
}), | ||
signature: (...args) => { | ||
const uargs = unify(args) ?? args; | ||
return ({ | ||
argTypes: uargs, | ||
returnType: correspondingBooleanVectorSchema(uargs[0]), | ||
}); | ||
}, | ||
normalImpl: <T extends AnyNumericVecInstance>(lhs: T, rhs: T) => | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here as well |
||
cpuNot(cpuLt(lhs, rhs)), | ||
codegenImpl: (lhs, rhs) => stitch`(${lhs} >= ${rhs})`, | ||
}); | ||
|
||
// logical ops | ||
|
||
const cpuNot = <T extends AnyBooleanVecInstance>(value: T): T => | ||
VectorOps.neg[value.kind](value); | ||
function cpuNot<T extends boolean>(value: boolean): boolean; | ||
function cpuNot<T extends AnyBooleanVecInstance | boolean>(value: T): T; | ||
function cpuNot<T extends AnyBooleanVecInstance | boolean>(value: T): T { | ||
if (typeof value === 'boolean') { | ||
return (!value) as T; | ||
} | ||
return VectorOps.neg[value.kind](value) as T; | ||
} | ||
|
||
/** | ||
* Returns **component-wise** `!value`. | ||
|
@@ -185,8 +248,14 @@ export const not = dualImpl({ | |
codegenImpl: (arg) => stitch`!(${arg})`, | ||
}); | ||
|
||
const cpuOr = <T extends AnyBooleanVecInstance>(lhs: T, rhs: T) => | ||
VectorOps.or[lhs.kind](lhs, rhs); | ||
function cpuOr(lhs: boolean, rhs: boolean): boolean; | ||
function cpuOr<T extends AnyBooleanVecInstance | boolean>(lhs: T, rhs: T): T; | ||
function cpuOr<T extends AnyBooleanVecInstance | boolean>(lhs: T, rhs: T): T { | ||
if (typeof lhs !== 'boolean' && typeof rhs !== 'boolean') { | ||
return VectorOps.or[lhs.kind](lhs, rhs) as T; | ||
} | ||
return lhs || rhs; | ||
} | ||
|
||
/** | ||
* Returns **component-wise** logical `or` result. | ||
|
@@ -196,13 +265,19 @@ const cpuOr = <T extends AnyBooleanVecInstance>(lhs: T, rhs: T) => | |
*/ | ||
export const or = dualImpl({ | ||
name: 'or', | ||
signature: (...argTypes) => ({ argTypes, returnType: argTypes[0] }), | ||
signature: (...args) => { | ||
const uargs = unify(args) ?? args; | ||
return ({ argTypes: uargs, returnType: uargs[0] }); | ||
}, | ||
normalImpl: cpuOr, | ||
codegenImpl: (lhs, rhs) => stitch`(${lhs} | ${rhs})`, | ||
}); | ||
|
||
const cpuAnd = <T extends AnyBooleanVecInstance>(lhs: T, rhs: T) => | ||
cpuNot(cpuOr(cpuNot(lhs), cpuNot(rhs))); | ||
function cpuAnd(lhs: boolean, rhs: boolean): boolean; | ||
function cpuAnd<T extends AnyBooleanVecInstance | boolean>(lhs: T, rhs: T): T; | ||
function cpuAnd<T extends AnyBooleanVecInstance | boolean>(lhs: T, rhs: T): T { | ||
return cpuNot(cpuOr(cpuNot(lhs), cpuNot(rhs))); | ||
} | ||
|
||
/** | ||
* Returns **component-wise** logical `and` result. | ||
|
@@ -212,7 +287,10 @@ const cpuAnd = <T extends AnyBooleanVecInstance>(lhs: T, rhs: T) => | |
*/ | ||
export const and = dualImpl({ | ||
name: 'and', | ||
signature: (...argTypes) => ({ argTypes, returnType: argTypes[0] }), | ||
signature: (...args) => { | ||
const uargs = unify(args) ?? args; | ||
return ({ argTypes: uargs, returnType: uargs[0] }); | ||
}, | ||
normalImpl: cpuAnd, | ||
codegenImpl: (lhs, rhs) => stitch`(${lhs} & ${rhs})`, | ||
}); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
wgsl.MatchingBoolInstance<T>
can also simplifyselect
signature