diff --git a/.github/workflows/db-migration-presubmit.yaml b/.github/workflows/db-migration-presubmit.yaml new file mode 100644 index 0000000..c8ccbfb --- /dev/null +++ b/.github/workflows/db-migration-presubmit.yaml @@ -0,0 +1,62 @@ +name: Database schema migration check +on: + # Run whenever code is changed in the main branch, + push: + branches: + - main + # Run on PRs where something changed under the `ent/migrate/migrations/` directory. + pull_request: + paths: + - 'ent/**' +jobs: + migration-check: + services: + # Spin up a postgres:10 container to be used as the dev-database for analysis. + postgres: + image: postgres:10 + env: + POSTGRES_DB: test + POSTGRES_PASSWORD: pass + ports: + - 5432:5432 + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3.0.1 + with: + fetch-depth: 0 # Mandatory unless "latest" is set below. + + # doesn't seem to work - does not recognize migrations files. + # - uses: ariga/atlas-action@v1.0.11 + # with: + # dir: ent/migrate/migrations + # dir-format: golang-migrate # Or: atlas, goose, dbmate + # dev-url: postgres://postgres:pass@localhost:5432/test?sslmode=disable + + - name: Check for migration changes + id: check_migrations + run: | + # List files changed in the PR + echo "Checking for changes between ${{ github.event.pull_request.base.sha }} and ${{ github.sha }}" + changed_files=$(git diff --name-only ${{ github.event.pull_request.base.sha }} ${{ github.sha }}) + echo "Changed files: $changed_files" + + # Check for changes in 'ent/schema' + schema_changes=$(echo "$changed_files" | grep '^ent/schema' || true) + echo "Schema changes: $schema_changes" + + # Check for changes in 'ent/migrate/migrations' + migration_changes=$(echo "$changed_files" | grep '^ent/migrate/migrations' || true) + echo "Migration changes: $migration_changes" + + # If there are schema changes but no migration changes, fail the check + if [ -n "$schema_changes" ] && [ -z "$migration_changes" ]; then + echo "::error::Changes in 'ent/schema' require corresponding changes in 'ent/migrate/migrations'" + exit 1 + else + echo "Check passed: Schema changes are accompanied by migration changes." + fi \ No newline at end of file diff --git a/.github/workflows/logging-presubmit.yml b/.github/workflows/logging-presubmit.yml index 2ae114c..a609112 100644 --- a/.github/workflows/logging-presubmit.yml +++ b/.github/workflows/logging-presubmit.yml @@ -1,4 +1,4 @@ -name: semgrep +name: Logging check on: push: branches: diff --git a/.github/workflows/migration-ci.yaml b/.github/workflows/migration-ci.yaml deleted file mode 100644 index e3e44b1..0000000 --- a/.github/workflows/migration-ci.yaml +++ /dev/null @@ -1,37 +0,0 @@ -name: Atlas CI -on: - # Run whenever code is changed in the main branch, - # change this to your root branch. - # push: - # branches: - # - main - # Run on PRs where something changed under the `ent/migrate/migrations/` directory. - pull_request: - paths: - - 'ent/migrate/migrations/*' -jobs: - lint: - services: - # Spin up a postgres:10 container to be used as the dev-database for analysis. - postgres: - image: postgres:10 - env: - POSTGRES_DB: test - POSTGRES_PASSWORD: pass - ports: - - 5432:5432 - options: >- - --health-cmd pg_isready - --health-interval 10s - --health-timeout 5s - --health-retries 5 - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3.0.1 - with: - fetch-depth: 0 # Mandatory unless "latest" is set below. - - uses: ariga/atlas-action@v0 - with: - dir: ent/migrate/migrations - dir-format: golang-migrate # Or: atlas, goose, dbmate - dev-url: postgres://postgres:pass@localhost:5432/test?sslmode=disable \ No newline at end of file diff --git a/.github/workflows/secret-scanning.yml b/.github/workflows/secret-scanning.yml new file mode 100644 index 0000000..d0c9ff1 --- /dev/null +++ b/.github/workflows/secret-scanning.yml @@ -0,0 +1,19 @@ +name: Secret Scanning +on: + push: + branches: + - main + pull_request: + +jobs: + test: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Secret Scanning + uses: trufflesecurity/trufflehog@main + with: + extra_args: --only-verified \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..f288702 --- /dev/null +++ b/LICENSE @@ -0,0 +1,674 @@ + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU General Public License is a free, copyleft license for +software and other kinds of works. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +the GNU General Public License is intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. We, the Free Software Foundation, use the +GNU General Public License for most of our software; it applies also to +any other work released this way by its authors. You can apply it to +your programs, too. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + To protect your rights, we need to prevent others from denying you +these rights or asking you to surrender the rights. Therefore, you have +certain responsibilities if you distribute copies of the software, or if +you modify it: responsibilities to respect the freedom of others. + + For example, if you distribute copies of such a program, whether +gratis or for a fee, you must pass on to the recipients the same +freedoms that you received. You must make sure that they, too, receive +or can get the source code. And you must show them these terms so they +know their rights. + + Developers that use the GNU GPL protect your rights with two steps: +(1) assert copyright on the software, and (2) offer you this License +giving you legal permission to copy, distribute and/or modify it. + + For the developers' and authors' protection, the GPL clearly explains +that there is no warranty for this free software. For both users' and +authors' sake, the GPL requires that modified versions be marked as +changed, so that their problems will not be attributed erroneously to +authors of previous versions. + + Some devices are designed to deny users access to install or run +modified versions of the software inside them, although the manufacturer +can do so. This is fundamentally incompatible with the aim of +protecting users' freedom to change the software. The systematic +pattern of such abuse occurs in the area of products for individuals to +use, which is precisely where it is most unacceptable. Therefore, we +have designed this version of the GPL to prohibit the practice for those +products. If such problems arise substantially in other domains, we +stand ready to extend this provision to those domains in future versions +of the GPL, as needed to protect the freedom of users. + + Finally, every program is threatened constantly by software patents. +States should not allow patents to restrict development and use of +software on general-purpose computers, but in those that do, we wish to +avoid the special danger that patents applied to a free program could +make it effectively proprietary. To prevent this, the GPL assures that +patents cannot be used to render the program non-free. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Use with the GNU Affero General Public License. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU Affero General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the special requirements of the GNU Affero General Public License, +section 13, concerning interaction through a network will apply to the +combination as such. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU General Public License from time to time. Such new versions will +be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper mail. + + If the program does terminal interaction, make it output a short +notice like this when it starts in an interactive mode: + + Copyright (C) + This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. + This is free software, and you are welcome to redistribute it + under certain conditions; type `show c' for details. + +The hypothetical commands `show w' and `show c' should show the appropriate +parts of the General Public License. Of course, your program's commands +might be different; for a GUI interface, you would use an "about box". + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU GPL, see +. + + The GNU General Public License does not permit incorporating your program +into proprietary programs. If your program is a subroutine library, you +may consider it more useful to permit linking proprietary applications with +the library. If this is what you want to do, use the GNU Lesser General +Public License instead of this License. But first, please read +. diff --git a/README.md b/README.md index 9bdf674..5b98f44 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,23 @@ # registry-backend -The first service to receive API requests +The backend API server for [Comfy Registry](https://comfyregistry.org) and [Comfy CI/CD](https://comfyci.org). -## Local Dev +Join us at our discord: [https://discord.gg/comfycontrib](https://discord.gg/comfycontrib) + +Registry React Frontend [Github](https://github.com/Comfy-Org/registry-web) +Registry CLI [Github](https://github.com/yoland68/comfy-cli) + +## Local Development ### Golang -https://go.dev/doc/install +Install Golang: + + + +Install go packages + +`go get` ### Supabase @@ -31,9 +42,6 @@ These are needed for authenticating Firebase JWT token auth + calling other GCP When testing login with registry, use this: `gcloud config set project dreamboothy-dev` -When testing workspace / VM creation, use this: -`gcloud config set project dreamboothy` - `gcloud auth application-default login` If you are testing creating a node, you need to impersonate a service account because it requires signing cloud storage urls. @@ -60,7 +68,18 @@ This should search all directories and run go generate. This will run all the co Or manually run: -`go run -mod=mod entgo.io/ent/cmd/ent generate --feature sql/upsert --feature sql/lock ./ent/schema` +`go run -mod=mod entgo.io/ent/cmd/ent generate --feature sql/upsert --feature sql/lock --feature sql/modifier ./ent/schema` + +### Generate Migration Files + +Run this command to generate migration files needed for staging/prod database schema changes: + +```shell +atlas migrate diff migration \ + --dir "file://ent/migrate/migrations" \ + --to "ent://ent/schema" \ + --dev-url "docker://postgres/15/test?search_path=public" +``` ## API Spec Change (openapi.yml) @@ -74,7 +93,7 @@ Or manually run: `export PATH="$PATH:$HOME/bin:$HOME/go/bin"` -https://github.com/deepmap/oapi-codegen/issues/795 + `oapi-codegen --config drip/codegen.yaml openapi.yml` @@ -82,11 +101,19 @@ https://github.com/deepmap/oapi-codegen/issues/795 Here are some common errors and how to resolve them. +### Security Scan + +If you are calling the `security-scan` endpoint, you need to add the endpoint url to `docker-compose.yml` and then make sure you have the correct permissions to call that function. + +Check the `security-scan` Cloud Function repo for instructions on how to do that with `gcloud`. + +For non Comfy-Org contributors, you can use your own hosted function or just avoid touching this part. We keep the security scan code private to avoid exploiters taking advantage of it. + ### Firebase Token Errors Usually in localdev, we use dreamboothy-dev Firebase project for authentication. This conflicts with our machine creation logic because all of those machine images are in dreamboothy. TODO(robinhuang): Figure out a solution for this. Either we replicate things in dreamboothy-dev, or we pass project information separately when creating machine images. -### Creating VM instance error: +### Creating VM instance error **Example:** @@ -131,23 +158,23 @@ In order to bypass authentication error, you can add make the following changes package drip_middleware func FirebaseMiddleware(entClient *ent.Client) echo.MiddlewareFunc { - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(ctx echo.Context) error { - userDetails := &UserDetails{ - ID: "test-james-token-id", - Email: "test-james-email@gmail.com", - Name: "James", - } - - authdCtx := context.WithValue(ctx.Request().Context(), UserContextKey, userDetails) - ctx.SetRequest(ctx.Request().WithContext(authdCtx)) - newUserError := db.UpsertUser(ctx.Request().Context(), entClient, userDetails.ID, userDetails.Email, userDetails.Name) - if newUserError != nil { - log.Ctx(ctx).Info().Ctx(ctx.Request().Context()).Err(newUserError).Msg("error User upserted successfully.") - } - return next(ctx) - } - } + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(ctx echo.Context) error { + userDetails := &UserDetails{ + ID: "test-james-token-id", + Email: "test-james-email@gmail.com", + Name: "James", + } + + authdCtx := context.WithValue(ctx.Request().Context(), UserContextKey, userDetails) + ctx.SetRequest(ctx.Request().WithContext(authdCtx)) + newUserError := db.UpsertUser(ctx.Request().Context(), entClient, userDetails.ID, userDetails.Email, userDetails.Name) + if newUserError != nil { + log.Ctx(ctx).Info().Ctx(ctx.Request().Context()).Err(newUserError).Msg("error User upserted successfully.") + } + return next(ctx) + } + } } ``` diff --git a/cloudbuild.yaml b/cloudbuild.yaml index 34e379c..d7efb8d 100644 --- a/cloudbuild.yaml +++ b/cloudbuild.yaml @@ -2,9 +2,38 @@ steps: # build the container image - name: "gcr.io/cloud-builders/docker" args: ["build", "-t", "us-central1-docker.pkg.dev/dreamboothy/registry-backend/registry-backend-image:$SHORT_SHA", "."] - # push container image + + # push container image - name: "gcr.io/cloud-builders/docker" args: ["push", "us-central1-docker.pkg.dev/dreamboothy/registry-backend/registry-backend-image:$SHORT_SHA"] + + # Clone the GitHub repository + - name: "gcr.io/cloud-builders/git" + args: [ "clone", "https://github.com/Comfy-Org/registry-backend.git", "registry-backend" ] + dir: "/workspace" + + # Run database migrations for staging + - name: "gcr.io/google.com/cloudsdktool/cloud-sdk" + entrypoint: "bash" + args: + - "-c" + - | + curl -sSL https://atlasgo.sh | sh + atlas migrate apply --dir "file://ent/migrate/migrations" --url $$STAGING_DB_CONNECTION_STRING + secretEnv: ['STAGING_DB_CONNECTION_STRING'] + dir: "/workspace/registry-backend" + + # Run database migrations for prod + - name: "gcr.io/google.com/cloudsdktool/cloud-sdk" + entrypoint: "bash" + args: + - "-c" + - | + curl -sSL https://atlasgo.sh | sh + atlas migrate apply --dir "file://ent/migrate/migrations" --url $$PROD_SUPABASE_CONNECTION_STRING + secretEnv: [ 'PROD_SUPABASE_CONNECTION_STRING' ] + dir: "/workspace/registry-backend" + # Publish the release - name: 'gcr.io/google.com/cloudsdktool/cloud-sdk:458.0.1' entrypoint: 'bash' @@ -16,5 +45,13 @@ steps: --region=us-central1 --delivery-pipeline=comfy-backend-api-pipeline --images=registry-backend-image-substitute=us-central1-docker.pkg.dev/dreamboothy/registry-backend/registry-backend-image:$SHORT_SHA + +availableSecrets: + secretManager: + - versionName: projects/357148958219/secrets/STAGING_SUPABASE_CONNECTION_STRING/versions/latest + env: 'STAGING_DB_CONNECTION_STRING' + - versionName: projects/357148958219/secrets/PROD_SUPABASE_CONNECTION_STRING/versions/latest + env: 'PROD_SUPABASE_CONNECTION_STRING' + options: machineType: 'E2_HIGHCPU_8' \ No newline at end of file diff --git a/config/config.go b/config/config.go index 3508977..c7e949e 100644 --- a/config/config.go +++ b/config/config.go @@ -1,6 +1,10 @@ package config type Config struct { - ProjectID string - DripEnv string + ProjectID string + DripEnv string + SlackRegistryChannelWebhook string + JWTSecret string + DiscordSecurityChannelWebhook string + SecretScannerURL string } diff --git a/docker-compose.yml b/docker-compose.yml index 58d76c0..439bc47 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -16,4 +16,8 @@ services: GOOGLE_CLOUD_PROJECT: "dreamboothy-dev" # This will be set in prod by GCP. PROJECT_ID: "dreamboothy-dev" CORS_ORIGIN: "http://localhost:3000" + JWT_SECRET: 8zT9YknYUTZRVAkgov86gT1NLezTtwrd # test secret LOG_LEVEL: info # Set the log level here + ALGOLIA_APP_ID: + ALGOLIA_API_KEY: + SECRET_SCANNER_URL: "" diff --git a/drip/api.gen.go b/drip/api.gen.go index eaa9c8d..1fd45d8 100644 --- a/drip/api.gen.go +++ b/drip/api.gen.go @@ -27,11 +27,52 @@ const ( BearerAuthScopes = "BearerAuth.Scopes" ) +// Defines values for NodeStatus. +const ( + NodeStatusActive NodeStatus = "NodeStatusActive" + NodeStatusBanned NodeStatus = "NodeStatusBanned" + NodeStatusDeleted NodeStatus = "NodeStatusDeleted" +) + +// Defines values for NodeVersionStatus. +const ( + NodeVersionStatusActive NodeVersionStatus = "NodeVersionStatusActive" + NodeVersionStatusBanned NodeVersionStatus = "NodeVersionStatusBanned" + NodeVersionStatusDeleted NodeVersionStatus = "NodeVersionStatusDeleted" + NodeVersionStatusFlagged NodeVersionStatus = "NodeVersionStatusFlagged" + NodeVersionStatusPending NodeVersionStatus = "NodeVersionStatusPending" +) + +// Defines values for PublisherStatus. +const ( + PublisherStatusActive PublisherStatus = "PublisherStatusActive" + PublisherStatusBanned PublisherStatus = "PublisherStatusBanned" +) + +// Defines values for WorkflowRunStatus. +const ( + WorkflowRunStatusCompleted WorkflowRunStatus = "WorkflowRunStatusCompleted" + WorkflowRunStatusFailed WorkflowRunStatus = "WorkflowRunStatusFailed" + WorkflowRunStatusStarted WorkflowRunStatus = "WorkflowRunStatusStarted" +) + // ActionJobResult defines model for ActionJobResult. type ActionJobResult struct { + // ActionJobId Identifier of the job this result belongs to + ActionJobId *string `json:"action_job_id,omitempty"` + // ActionRunId Identifier of the run this result belongs to ActionRunId *string `json:"action_run_id,omitempty"` + // Author The author of the commit + Author *string `json:"author,omitempty"` + + // AvgVram The average VRAM used by the job + AvgVram *int `json:"avg_vram,omitempty"` + + // ComfyRunFlags The comfy run flags. E.g. `--low-vram` + ComfyRunFlags *string `json:"comfy_run_flags,omitempty"` + // CommitHash The hash of the commit CommitHash *string `json:"commit_hash,omitempty"` @@ -44,27 +85,41 @@ type ActionJobResult struct { // CommitTime The Unix timestamp when the commit was made CommitTime *int64 `json:"commit_time,omitempty"` + // CudaVersion CUDA version used + CudaVersion *string `json:"cuda_version,omitempty"` + // EndTime The end time of the job as a Unix timestamp. EndTime *int64 `json:"end_time,omitempty"` // GitRepo The repository name GitRepo *string `json:"git_repo,omitempty"` - // GpuType GPU type used - GpuType *string `json:"gpu_type,omitempty"` - // Id Unique identifier for the job result Id *openapi_types.UUID `json:"id,omitempty"` + // JobTriggerUser The user who triggered the job. + JobTriggerUser *string `json:"job_trigger_user,omitempty"` + MachineStats *MachineStats `json:"machine_stats,omitempty"` + // OperatingSystem Operating system used OperatingSystem *string `json:"operating_system,omitempty"` + // PeakVram The peak VRAM used by the job + PeakVram *int `json:"peak_vram,omitempty"` + + // PrNumber The pull request number + PrNumber *string `json:"pr_number,omitempty"` + + // PythonVersion PyTorch version used + PythonVersion *string `json:"python_version,omitempty"` + // PytorchVersion PyTorch version used PytorchVersion *string `json:"pytorch_version,omitempty"` // StartTime The start time of the job as a Unix timestamp. - StartTime *int64 `json:"start_time,omitempty"` - StorageFile *StorageFile `json:"storage_file,omitempty"` + StartTime *int64 `json:"start_time,omitempty"` + Status *WorkflowRunStatus `json:"status,omitempty"` + StorageFile *StorageFile `json:"storage_file,omitempty"` // WorkflowName Name of the workflow WorkflowName *string `json:"workflow_name,omitempty"` @@ -85,9 +140,48 @@ type ErrorResponse struct { Message string `json:"message"` } +// MachineStats defines model for MachineStats. +type MachineStats struct { + // CpuCapacity Total CPU on the machine. + CpuCapacity *string `json:"cpu_capacity,omitempty"` + + // DiskCapacity Total disk capacity on the machine. + DiskCapacity *string `json:"disk_capacity,omitempty"` + + // GpuType The GPU type. eg. NVIDIA Tesla K80 + GpuType *string `json:"gpu_type,omitempty"` + + // InitialCpu Initial CPU available before the job starts. + InitialCpu *string `json:"initial_cpu,omitempty"` + + // InitialDisk Initial disk available before the job starts. + InitialDisk *string `json:"initial_disk,omitempty"` + + // InitialRam Initial RAM available before the job starts. + InitialRam *string `json:"initial_ram,omitempty"` + + // MachineName Name of the machine. + MachineName *string `json:"machine_name,omitempty"` + + // MemoryCapacity Total memory on the machine. + MemoryCapacity *string `json:"memory_capacity,omitempty"` + + // OsVersion The operating system version. eg. Ubuntu Linux 20.04 + OsVersion *string `json:"os_version,omitempty"` + + // PipFreeze The pip freeze output + PipFreeze *string `json:"pip_freeze,omitempty"` + + // VramTimeSeries Time series of VRAM usage. + VramTimeSeries *map[string]interface{} `json:"vram_time_series,omitempty"` +} + // Node defines model for Node. type Node struct { - Author *string `json:"author,omitempty"` + Author *string `json:"author,omitempty"` + + // Category The category of the node. + Category *string `json:"category,omitempty"` Description *string `json:"description,omitempty"` // Downloads The number of downloads of the node. @@ -111,10 +205,17 @@ type Node struct { Rating *float32 `json:"rating,omitempty"` // Repository URL to the node's repository. - Repository *string `json:"repository,omitempty"` - Tags *[]string `json:"tags,omitempty"` + Repository *string `json:"repository,omitempty"` + Status *NodeStatus `json:"status,omitempty"` + + // StatusDetail The status detail of the node. + StatusDetail *string `json:"status_detail,omitempty"` + Tags *[]string `json:"tags,omitempty"` } +// NodeStatus defines model for NodeStatus. +type NodeStatus string + // NodeVersion defines model for NodeVersion. type NodeVersion struct { // Changelog Summary of changes made in this version @@ -130,13 +231,20 @@ type NodeVersion struct { Deprecated *bool `json:"deprecated,omitempty"` // DownloadUrl [Output Only] URL to download this version of the node - DownloadUrl *string `json:"downloadUrl,omitempty"` - Id *string `json:"id,omitempty"` + DownloadUrl *string `json:"downloadUrl,omitempty"` + Id *string `json:"id,omitempty"` + Status *NodeVersionStatus `json:"status,omitempty"` + + // StatusReason The reason for the status change. + StatusReason *string `json:"status_reason,omitempty"` // Version The version identifier, following semantic versioning. Must be unique for the node. Version *string `json:"version,omitempty"` } +// NodeVersionStatus defines model for NodeVersionStatus. +type NodeVersionStatus string + // NodeVersionUpdateRequest defines model for NodeVersionUpdateRequest. type NodeVersionUpdateRequest struct { // Changelog The changelog describing the version changes. @@ -180,6 +288,7 @@ type Publisher struct { Members *[]PublisherMember `json:"members,omitempty"` Name *string `json:"name,omitempty"` SourceCodeRepo *string `json:"source_code_repo,omitempty"` + Status *PublisherStatus `json:"status,omitempty"` Support *string `json:"support,omitempty"` Website *string `json:"website,omitempty"` } @@ -194,6 +303,9 @@ type PublisherMember struct { User *PublisherUser `json:"user,omitempty"` } +// PublisherStatus defines model for PublisherStatus. +type PublisherStatus string + // PublisherUser defines model for PublisherUser. type PublisherUser struct { // Email The email address for this user. @@ -236,6 +348,17 @@ type User struct { Name *string `json:"name,omitempty"` } +// WorkflowRunStatus defines model for WorkflowRunStatus. +type WorkflowRunStatus string + +// AdminUpdateNodeVersionJSONBody defines parameters for AdminUpdateNodeVersion. +type AdminUpdateNodeVersionJSONBody struct { + Status *NodeVersionStatus `json:"status,omitempty"` + + // StatusReason The reason for the status change. + StatusReason *string `json:"status_reason,omitempty"` +} + // GetBranchParams defines parameters for GetBranch. type GetBranchParams struct { // RepoName The repo to filter by. @@ -273,6 +396,24 @@ type ListAllNodesParams struct { // Limit Number of nodes to return per page Limit *int `form:"limit,omitempty" json:"limit,omitempty"` + + // IncludeBanned Number of nodes to return per page + IncludeBanned *bool `form:"include_banned,omitempty" json:"include_banned,omitempty"` +} + +// SearchNodesParams defines parameters for SearchNodes. +type SearchNodesParams struct { + // Page Page number of the nodes list + Page *int `form:"page,omitempty" json:"page,omitempty"` + + // Limit Number of nodes to return per page + Limit *int `form:"limit,omitempty" json:"limit,omitempty"` + + // Search Keyword to search the nodes + Search *string `form:"search,omitempty" json:"search,omitempty"` + + // IncludeBanned Number of nodes to return per page + IncludeBanned *bool `form:"include_banned,omitempty" json:"include_banned,omitempty"` } // InstallNodeParams defines parameters for InstallNode. @@ -281,12 +422,29 @@ type InstallNodeParams struct { Version *string `form:"version,omitempty" json:"version,omitempty"` } +// PostNodeReviewParams defines parameters for PostNodeReview. +type PostNodeReviewParams struct { + // Star number of star given to the node version + Star int `form:"star" json:"star"` +} + +// ListNodeVersionsParams defines parameters for ListNodeVersions. +type ListNodeVersionsParams struct { + Statuses *[]NodeVersionStatus `form:"statuses,omitempty" json:"statuses,omitempty"` +} + // ValidatePublisherParams defines parameters for ValidatePublisher. type ValidatePublisherParams struct { // Username The publisher username to validate. Username string `form:"username" json:"username"` } +// ListNodesForPublisherParams defines parameters for ListNodesForPublisher. +type ListNodesForPublisherParams struct { + // IncludeBanned Number of nodes to return per page + IncludeBanned *bool `form:"include_banned,omitempty" json:"include_banned,omitempty"` +} + // PublishNodeVersionJSONBody defines parameters for PublishNodeVersion. type PublishNodeVersionJSONBody struct { Node Node `json:"node"` @@ -294,8 +452,19 @@ type PublishNodeVersionJSONBody struct { PersonalAccessToken string `json:"personal_access_token"` } +// SecurityScanParams defines parameters for SecurityScan. +type SecurityScanParams struct { + MinAge *time.Duration `form:"minAge,omitempty" json:"minAge,omitempty"` + MaxNodes *int `form:"maxNodes,omitempty" json:"maxNodes,omitempty"` +} + // PostUploadArtifactJSONBody defines parameters for PostUploadArtifact. type PostUploadArtifactJSONBody struct { + // Author The author of the commit + Author string `json:"author"` + + // AvgVram The average amount of VRAM used in the run. + AvgVram *int `json:"avg_vram,omitempty"` BranchName string `json:"branch_name"` // BucketName The name of the bucket where the output files are stored @@ -303,7 +472,10 @@ type PostUploadArtifactJSONBody struct { // ComfyLogsGcsPath The path to ComfyUI logs. eg. gs://bucket-name/logs ComfyLogsGcsPath *string `json:"comfy_logs_gcs_path,omitempty"` - CommitHash string `json:"commit_hash"` + + // ComfyRunFlags The flags used in the comfy run + ComfyRunFlags *string `json:"comfy_run_flags,omitempty"` + CommitHash string `json:"commit_hash"` // CommitMessage The commit message CommitMessage string `json:"commit_message"` @@ -320,12 +492,28 @@ type PostUploadArtifactJSONBody struct { // JobId Unique identifier for the job JobId string `json:"job_id"` + // JobTriggerUser The user who triggered the job + JobTriggerUser string `json:"job_trigger_user"` + MachineStats *MachineStats `json:"machine_stats,omitempty"` + // Os Operating system used in the run Os string `json:"os"` // OutputFilesGcsPaths A comma separated string that contains GCS path(s) to output files. eg. gs://bucket-name/output, gs://bucket-name/output2 OutputFilesGcsPaths *string `json:"output_files_gcs_paths,omitempty"` + // PeakVram The peak amount of VRAM used in the run. + PeakVram *int `json:"peak_vram,omitempty"` + + // PrNumber The pull request number + PrNumber string `json:"pr_number"` + + // PythonVersion The python version used in the run + PythonVersion string `json:"python_version"` + + // PytorchVersion The pytorch version used in the run + PytorchVersion *string `json:"pytorch_version,omitempty"` + // Repo Repository name Repo string `json:"repo"` @@ -333,12 +521,28 @@ type PostUploadArtifactJSONBody struct { RunId string `json:"run_id"` // StartTime The start time of the job as a Unix timestamp. - StartTime int64 `json:"start_time"` + StartTime int64 `json:"start_time"` + Status WorkflowRunStatus `json:"status"` // WorkflowName The name of the workflow WorkflowName string `json:"workflow_name"` } +// ListAllNodeVersionsParams defines parameters for ListAllNodeVersions. +type ListAllNodeVersionsParams struct { + NodeId *string `form:"nodeId,omitempty" json:"nodeId,omitempty"` + Statuses *[]NodeVersionStatus `form:"statuses,omitempty" json:"statuses,omitempty"` + + // Page The page number to retrieve. + Page *int `form:"page,omitempty" json:"page,omitempty"` + + // PageSize The number of items to include per page. + PageSize *int `form:"pageSize,omitempty" json:"pageSize,omitempty"` +} + +// AdminUpdateNodeVersionJSONRequestBody defines body for AdminUpdateNodeVersion for application/json ContentType. +type AdminUpdateNodeVersionJSONRequestBody AdminUpdateNodeVersionJSONBody + // CreatePublisherJSONRequestBody defines body for CreatePublisher for application/json ContentType. type CreatePublisherJSONRequestBody = Publisher @@ -365,6 +569,9 @@ type PostUploadArtifactJSONRequestBody PostUploadArtifactJSONBody // ServerInterface represents all server handlers. type ServerInterface interface { + // Admin Update Node Version Status + // (PUT /admin/nodes/{nodeId}/versions/{versionNumber}) + AdminUpdateNodeVersion(ctx echo.Context, nodeId string, versionNumber string) error // Retrieve all distinct branches for a given repo // (GET /branch) GetBranch(ctx echo.Context, params GetBranchParams) error @@ -374,15 +581,24 @@ type ServerInterface interface { // Retrieves a list of nodes // (GET /nodes) ListAllNodes(ctx echo.Context, params ListAllNodesParams) error + // Reindex all nodes for searching. + // (POST /nodes/reindex) + ReindexNodes(ctx echo.Context) error + // Retrieves a list of nodes + // (GET /nodes/search) + SearchNodes(ctx echo.Context, params SearchNodesParams) error // Retrieve a specific node by ID // (GET /nodes/{nodeId}) GetNode(ctx echo.Context, nodeId string) error // Returns a node version to be installed. // (GET /nodes/{nodeId}/install) InstallNode(ctx echo.Context, nodeId string, params InstallNodeParams) error + // Add review to a specific version of a node + // (POST /nodes/{nodeId}/reviews) + PostNodeReview(ctx echo.Context, nodeId string, params PostNodeReviewParams) error // List all versions of a node // (GET /nodes/{nodeId}/versions) - ListNodeVersions(ctx echo.Context, nodeId string) error + ListNodeVersions(ctx echo.Context, nodeId string, params ListNodeVersionsParams) error // Retrieve a specific version of a node // (GET /nodes/{nodeId}/versions/{versionId}) GetNodeVersion(ctx echo.Context, nodeId string, versionId string) error @@ -404,9 +620,12 @@ type ServerInterface interface { // Update a publisher // (PUT /publishers/{publisherId}) UpdatePublisher(ctx echo.Context, publisherId string) error + // Ban a publisher + // (POST /publishers/{publisherId}/ban) + BanPublisher(ctx echo.Context, publisherId string) error // Retrieve all nodes // (GET /publishers/{publisherId}/nodes) - ListNodesForPublisher(ctx echo.Context, publisherId string) error + ListNodesForPublisher(ctx echo.Context, publisherId string, params ListNodesForPublisherParams) error // Create a new custom node // (POST /publishers/{publisherId}/nodes) CreateNode(ctx echo.Context, publisherId string) error @@ -416,6 +635,9 @@ type ServerInterface interface { // Update a specific node // (PUT /publishers/{publisherId}/nodes/{nodeId}) UpdateNode(ctx echo.Context, publisherId string, nodeId string) error + // Ban a publisher's Node + // (POST /publishers/{publisherId}/nodes/{nodeId}/ban) + BanPublisherNode(ctx echo.Context, publisherId string, nodeId string) error // Retrieve permissions the user has for a given publisher // (GET /publishers/{publisherId}/nodes/{nodeId}/permissions) GetPermissionOnPublisherNodes(ctx echo.Context, publisherId string, nodeId string) error @@ -440,6 +662,9 @@ type ServerInterface interface { // Delete a specific personal access token // (DELETE /publishers/{publisherId}/tokens/{tokenId}) DeletePersonalAccessToken(ctx echo.Context, publisherId string, tokenId string) error + // Security Scan + // (GET /security-scan) + SecurityScan(ctx echo.Context, params SecurityScanParams) error // Receive artifacts (output files) from the ComfyUI GitHub Action // (POST /upload-artifact) PostUploadArtifact(ctx echo.Context) error @@ -449,6 +674,12 @@ type ServerInterface interface { // Retrieve all publishers for a given user // (GET /users/publishers/) ListPublishersForUser(ctx echo.Context) error + // List all node versions given some filters. + // (GET /versions) + ListAllNodeVersions(ctx echo.Context, params ListAllNodeVersionsParams) error + // Retrieve a specific commit by ID + // (GET /workflowresult/{workflowResultId}) + GetWorkflowResult(ctx echo.Context, workflowResultId string) error } // ServerInterfaceWrapper converts echo contexts to parameters. @@ -456,6 +687,32 @@ type ServerInterfaceWrapper struct { Handler ServerInterface } +// AdminUpdateNodeVersion converts echo context to params. +func (w *ServerInterfaceWrapper) AdminUpdateNodeVersion(ctx echo.Context) error { + var err error + // ------------- Path parameter "nodeId" ------------- + var nodeId string + + err = runtime.BindStyledParameterWithOptions("simple", "nodeId", ctx.Param("nodeId"), &nodeId, runtime.BindStyledParameterOptions{ParamLocation: runtime.ParamLocationPath, Explode: false, Required: true}) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter nodeId: %s", err)) + } + + // ------------- Path parameter "versionNumber" ------------- + var versionNumber string + + err = runtime.BindStyledParameterWithOptions("simple", "versionNumber", ctx.Param("versionNumber"), &versionNumber, runtime.BindStyledParameterOptions{ParamLocation: runtime.ParamLocationPath, Explode: false, Required: true}) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter versionNumber: %s", err)) + } + + ctx.Set(BearerAuthScopes, []string{}) + + // Invoke the callback with all the unmarshaled arguments + err = w.Handler.AdminUpdateNodeVersion(ctx, nodeId, versionNumber) + return err +} + // GetBranch converts echo context to params. func (w *ServerInterfaceWrapper) GetBranch(ctx echo.Context) error { var err error @@ -554,11 +811,66 @@ func (w *ServerInterfaceWrapper) ListAllNodes(ctx echo.Context) error { return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter limit: %s", err)) } + // ------------- Optional query parameter "include_banned" ------------- + + err = runtime.BindQueryParameter("form", true, false, "include_banned", ctx.QueryParams(), ¶ms.IncludeBanned) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter include_banned: %s", err)) + } + // Invoke the callback with all the unmarshaled arguments err = w.Handler.ListAllNodes(ctx, params) return err } +// ReindexNodes converts echo context to params. +func (w *ServerInterfaceWrapper) ReindexNodes(ctx echo.Context) error { + var err error + + // Invoke the callback with all the unmarshaled arguments + err = w.Handler.ReindexNodes(ctx) + return err +} + +// SearchNodes converts echo context to params. +func (w *ServerInterfaceWrapper) SearchNodes(ctx echo.Context) error { + var err error + + // Parameter object where we will unmarshal all parameters from the context + var params SearchNodesParams + // ------------- Optional query parameter "page" ------------- + + err = runtime.BindQueryParameter("form", true, false, "page", ctx.QueryParams(), ¶ms.Page) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter page: %s", err)) + } + + // ------------- Optional query parameter "limit" ------------- + + err = runtime.BindQueryParameter("form", true, false, "limit", ctx.QueryParams(), ¶ms.Limit) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter limit: %s", err)) + } + + // ------------- Optional query parameter "search" ------------- + + err = runtime.BindQueryParameter("form", true, false, "search", ctx.QueryParams(), ¶ms.Search) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter search: %s", err)) + } + + // ------------- Optional query parameter "include_banned" ------------- + + err = runtime.BindQueryParameter("form", true, false, "include_banned", ctx.QueryParams(), ¶ms.IncludeBanned) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter include_banned: %s", err)) + } + + // Invoke the callback with all the unmarshaled arguments + err = w.Handler.SearchNodes(ctx, params) + return err +} + // GetNode converts echo context to params. func (w *ServerInterfaceWrapper) GetNode(ctx echo.Context) error { var err error @@ -600,6 +912,31 @@ func (w *ServerInterfaceWrapper) InstallNode(ctx echo.Context) error { return err } +// PostNodeReview converts echo context to params. +func (w *ServerInterfaceWrapper) PostNodeReview(ctx echo.Context) error { + var err error + // ------------- Path parameter "nodeId" ------------- + var nodeId string + + err = runtime.BindStyledParameterWithOptions("simple", "nodeId", ctx.Param("nodeId"), &nodeId, runtime.BindStyledParameterOptions{ParamLocation: runtime.ParamLocationPath, Explode: false, Required: true}) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter nodeId: %s", err)) + } + + // Parameter object where we will unmarshal all parameters from the context + var params PostNodeReviewParams + // ------------- Required query parameter "star" ------------- + + err = runtime.BindQueryParameter("form", true, true, "star", ctx.QueryParams(), ¶ms.Star) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter star: %s", err)) + } + + // Invoke the callback with all the unmarshaled arguments + err = w.Handler.PostNodeReview(ctx, nodeId, params) + return err +} + // ListNodeVersions converts echo context to params. func (w *ServerInterfaceWrapper) ListNodeVersions(ctx echo.Context) error { var err error @@ -611,8 +948,17 @@ func (w *ServerInterfaceWrapper) ListNodeVersions(ctx echo.Context) error { return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter nodeId: %s", err)) } + // Parameter object where we will unmarshal all parameters from the context + var params ListNodeVersionsParams + // ------------- Optional query parameter "statuses" ------------- + + err = runtime.BindQueryParameter("form", true, false, "statuses", ctx.QueryParams(), ¶ms.Statuses) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter statuses: %s", err)) + } + // Invoke the callback with all the unmarshaled arguments - err = w.Handler.ListNodeVersions(ctx, nodeId) + err = w.Handler.ListNodeVersions(ctx, nodeId, params) return err } @@ -730,6 +1076,22 @@ func (w *ServerInterfaceWrapper) UpdatePublisher(ctx echo.Context) error { return err } +// BanPublisher converts echo context to params. +func (w *ServerInterfaceWrapper) BanPublisher(ctx echo.Context) error { + var err error + // ------------- Path parameter "publisherId" ------------- + var publisherId string + + err = runtime.BindStyledParameterWithOptions("simple", "publisherId", ctx.Param("publisherId"), &publisherId, runtime.BindStyledParameterOptions{ParamLocation: runtime.ParamLocationPath, Explode: false, Required: true}) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter publisherId: %s", err)) + } + + // Invoke the callback with all the unmarshaled arguments + err = w.Handler.BanPublisher(ctx, publisherId) + return err +} + // ListNodesForPublisher converts echo context to params. func (w *ServerInterfaceWrapper) ListNodesForPublisher(ctx echo.Context) error { var err error @@ -743,8 +1105,17 @@ func (w *ServerInterfaceWrapper) ListNodesForPublisher(ctx echo.Context) error { ctx.Set(BearerAuthScopes, []string{}) + // Parameter object where we will unmarshal all parameters from the context + var params ListNodesForPublisherParams + // ------------- Optional query parameter "include_banned" ------------- + + err = runtime.BindQueryParameter("form", true, false, "include_banned", ctx.QueryParams(), ¶ms.IncludeBanned) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter include_banned: %s", err)) + } + // Invoke the callback with all the unmarshaled arguments - err = w.Handler.ListNodesForPublisher(ctx, publisherId) + err = w.Handler.ListNodesForPublisher(ctx, publisherId, params) return err } @@ -818,6 +1189,30 @@ func (w *ServerInterfaceWrapper) UpdateNode(ctx echo.Context) error { return err } +// BanPublisherNode converts echo context to params. +func (w *ServerInterfaceWrapper) BanPublisherNode(ctx echo.Context) error { + var err error + // ------------- Path parameter "publisherId" ------------- + var publisherId string + + err = runtime.BindStyledParameterWithOptions("simple", "publisherId", ctx.Param("publisherId"), &publisherId, runtime.BindStyledParameterOptions{ParamLocation: runtime.ParamLocationPath, Explode: false, Required: true}) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter publisherId: %s", err)) + } + + // ------------- Path parameter "nodeId" ------------- + var nodeId string + + err = runtime.BindStyledParameterWithOptions("simple", "nodeId", ctx.Param("nodeId"), &nodeId, runtime.BindStyledParameterOptions{ParamLocation: runtime.ParamLocationPath, Explode: false, Required: true}) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter nodeId: %s", err)) + } + + // Invoke the callback with all the unmarshaled arguments + err = w.Handler.BanPublisherNode(ctx, publisherId, nodeId) + return err +} + // GetPermissionOnPublisherNodes converts echo context to params. func (w *ServerInterfaceWrapper) GetPermissionOnPublisherNodes(ctx echo.Context) error { var err error @@ -1014,6 +1409,31 @@ func (w *ServerInterfaceWrapper) DeletePersonalAccessToken(ctx echo.Context) err return err } +// SecurityScan converts echo context to params. +func (w *ServerInterfaceWrapper) SecurityScan(ctx echo.Context) error { + var err error + + // Parameter object where we will unmarshal all parameters from the context + var params SecurityScanParams + // ------------- Optional query parameter "minAge" ------------- + + err = runtime.BindQueryParameter("form", true, false, "minAge", ctx.QueryParams(), ¶ms.MinAge) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter minAge: %s", err)) + } + + // ------------- Optional query parameter "maxNodes" ------------- + + err = runtime.BindQueryParameter("form", true, false, "maxNodes", ctx.QueryParams(), ¶ms.MaxNodes) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter maxNodes: %s", err)) + } + + // Invoke the callback with all the unmarshaled arguments + err = w.Handler.SecurityScan(ctx, params) + return err +} + // PostUploadArtifact converts echo context to params. func (w *ServerInterfaceWrapper) PostUploadArtifact(ctx echo.Context) error { var err error @@ -1043,51 +1463,112 @@ func (w *ServerInterfaceWrapper) ListPublishersForUser(ctx echo.Context) error { return err } -// This is a simple interface which specifies echo.Route addition functions which -// are present on both echo.Echo and echo.Group, since we want to allow using -// either of them for path registration -type EchoRouter interface { - CONNECT(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route - DELETE(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route - GET(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route - HEAD(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route - OPTIONS(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route - PATCH(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route - POST(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route - PUT(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route - TRACE(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route -} +// ListAllNodeVersions converts echo context to params. +func (w *ServerInterfaceWrapper) ListAllNodeVersions(ctx echo.Context) error { + var err error -// RegisterHandlers adds each server route to the EchoRouter. -func RegisterHandlers(router EchoRouter, si ServerInterface) { - RegisterHandlersWithBaseURL(router, si, "") -} + // Parameter object where we will unmarshal all parameters from the context + var params ListAllNodeVersionsParams + // ------------- Optional query parameter "nodeId" ------------- -// Registers handlers, and prepends BaseURL to the paths, so that the paths -// can be served under a prefix. -func RegisterHandlersWithBaseURL(router EchoRouter, si ServerInterface, baseURL string) { + err = runtime.BindQueryParameter("form", true, false, "nodeId", ctx.QueryParams(), ¶ms.NodeId) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter nodeId: %s", err)) + } - wrapper := ServerInterfaceWrapper{ - Handler: si, + // ------------- Optional query parameter "statuses" ------------- + + err = runtime.BindQueryParameter("form", true, false, "statuses", ctx.QueryParams(), ¶ms.Statuses) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter statuses: %s", err)) } - router.GET(baseURL+"/branch", wrapper.GetBranch) - router.GET(baseURL+"/gitcommit", wrapper.GetGitcommit) - router.GET(baseURL+"/nodes", wrapper.ListAllNodes) - router.GET(baseURL+"/nodes/:nodeId", wrapper.GetNode) - router.GET(baseURL+"/nodes/:nodeId/install", wrapper.InstallNode) - router.GET(baseURL+"/nodes/:nodeId/versions", wrapper.ListNodeVersions) - router.GET(baseURL+"/nodes/:nodeId/versions/:versionId", wrapper.GetNodeVersion) - router.GET(baseURL+"/publishers", wrapper.ListPublishers) + // ------------- Optional query parameter "page" ------------- + + err = runtime.BindQueryParameter("form", true, false, "page", ctx.QueryParams(), ¶ms.Page) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter page: %s", err)) + } + + // ------------- Optional query parameter "pageSize" ------------- + + err = runtime.BindQueryParameter("form", true, false, "pageSize", ctx.QueryParams(), ¶ms.PageSize) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter pageSize: %s", err)) + } + + // Invoke the callback with all the unmarshaled arguments + err = w.Handler.ListAllNodeVersions(ctx, params) + return err +} + +// GetWorkflowResult converts echo context to params. +func (w *ServerInterfaceWrapper) GetWorkflowResult(ctx echo.Context) error { + var err error + // ------------- Path parameter "workflowResultId" ------------- + var workflowResultId string + + err = runtime.BindStyledParameterWithOptions("simple", "workflowResultId", ctx.Param("workflowResultId"), &workflowResultId, runtime.BindStyledParameterOptions{ParamLocation: runtime.ParamLocationPath, Explode: false, Required: true}) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Invalid format for parameter workflowResultId: %s", err)) + } + + // Invoke the callback with all the unmarshaled arguments + err = w.Handler.GetWorkflowResult(ctx, workflowResultId) + return err +} + +// This is a simple interface which specifies echo.Route addition functions which +// are present on both echo.Echo and echo.Group, since we want to allow using +// either of them for path registration +type EchoRouter interface { + CONNECT(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route + DELETE(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route + GET(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route + HEAD(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route + OPTIONS(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route + PATCH(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route + POST(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route + PUT(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route + TRACE(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route +} + +// RegisterHandlers adds each server route to the EchoRouter. +func RegisterHandlers(router EchoRouter, si ServerInterface) { + RegisterHandlersWithBaseURL(router, si, "") +} + +// Registers handlers, and prepends BaseURL to the paths, so that the paths +// can be served under a prefix. +func RegisterHandlersWithBaseURL(router EchoRouter, si ServerInterface, baseURL string) { + + wrapper := ServerInterfaceWrapper{ + Handler: si, + } + + router.PUT(baseURL+"/admin/nodes/:nodeId/versions/:versionNumber", wrapper.AdminUpdateNodeVersion) + router.GET(baseURL+"/branch", wrapper.GetBranch) + router.GET(baseURL+"/gitcommit", wrapper.GetGitcommit) + router.GET(baseURL+"/nodes", wrapper.ListAllNodes) + router.POST(baseURL+"/nodes/reindex", wrapper.ReindexNodes) + router.GET(baseURL+"/nodes/search", wrapper.SearchNodes) + router.GET(baseURL+"/nodes/:nodeId", wrapper.GetNode) + router.GET(baseURL+"/nodes/:nodeId/install", wrapper.InstallNode) + router.POST(baseURL+"/nodes/:nodeId/reviews", wrapper.PostNodeReview) + router.GET(baseURL+"/nodes/:nodeId/versions", wrapper.ListNodeVersions) + router.GET(baseURL+"/nodes/:nodeId/versions/:versionId", wrapper.GetNodeVersion) + router.GET(baseURL+"/publishers", wrapper.ListPublishers) router.POST(baseURL+"/publishers", wrapper.CreatePublisher) router.GET(baseURL+"/publishers/validate", wrapper.ValidatePublisher) router.DELETE(baseURL+"/publishers/:publisherId", wrapper.DeletePublisher) router.GET(baseURL+"/publishers/:publisherId", wrapper.GetPublisher) router.PUT(baseURL+"/publishers/:publisherId", wrapper.UpdatePublisher) + router.POST(baseURL+"/publishers/:publisherId/ban", wrapper.BanPublisher) router.GET(baseURL+"/publishers/:publisherId/nodes", wrapper.ListNodesForPublisher) router.POST(baseURL+"/publishers/:publisherId/nodes", wrapper.CreateNode) router.DELETE(baseURL+"/publishers/:publisherId/nodes/:nodeId", wrapper.DeleteNode) router.PUT(baseURL+"/publishers/:publisherId/nodes/:nodeId", wrapper.UpdateNode) + router.POST(baseURL+"/publishers/:publisherId/nodes/:nodeId/ban", wrapper.BanPublisherNode) router.GET(baseURL+"/publishers/:publisherId/nodes/:nodeId/permissions", wrapper.GetPermissionOnPublisherNodes) router.POST(baseURL+"/publishers/:publisherId/nodes/:nodeId/versions", wrapper.PublishNodeVersion) router.DELETE(baseURL+"/publishers/:publisherId/nodes/:nodeId/versions/:versionId", wrapper.DeleteNodeVersion) @@ -1096,10 +1577,76 @@ func RegisterHandlersWithBaseURL(router EchoRouter, si ServerInterface, baseURL router.GET(baseURL+"/publishers/:publisherId/tokens", wrapper.ListPersonalAccessTokens) router.POST(baseURL+"/publishers/:publisherId/tokens", wrapper.CreatePersonalAccessToken) router.DELETE(baseURL+"/publishers/:publisherId/tokens/:tokenId", wrapper.DeletePersonalAccessToken) + router.GET(baseURL+"/security-scan", wrapper.SecurityScan) router.POST(baseURL+"/upload-artifact", wrapper.PostUploadArtifact) router.GET(baseURL+"/users", wrapper.GetUser) router.GET(baseURL+"/users/publishers/", wrapper.ListPublishersForUser) + router.GET(baseURL+"/versions", wrapper.ListAllNodeVersions) + router.GET(baseURL+"/workflowresult/:workflowResultId", wrapper.GetWorkflowResult) + +} + +type AdminUpdateNodeVersionRequestObject struct { + NodeId string `json:"nodeId"` + VersionNumber string `json:"versionNumber"` + Body *AdminUpdateNodeVersionJSONRequestBody +} + +type AdminUpdateNodeVersionResponseObject interface { + VisitAdminUpdateNodeVersionResponse(w http.ResponseWriter) error +} + +type AdminUpdateNodeVersion200JSONResponse NodeVersion + +func (response AdminUpdateNodeVersion200JSONResponse) VisitAdminUpdateNodeVersionResponse(w http.ResponseWriter) error { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + + return json.NewEncoder(w).Encode(response) +} + +type AdminUpdateNodeVersion400JSONResponse ErrorResponse +func (response AdminUpdateNodeVersion400JSONResponse) VisitAdminUpdateNodeVersionResponse(w http.ResponseWriter) error { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(400) + + return json.NewEncoder(w).Encode(response) +} + +type AdminUpdateNodeVersion401Response struct { +} + +func (response AdminUpdateNodeVersion401Response) VisitAdminUpdateNodeVersionResponse(w http.ResponseWriter) error { + w.WriteHeader(401) + return nil +} + +type AdminUpdateNodeVersion403JSONResponse ErrorResponse + +func (response AdminUpdateNodeVersion403JSONResponse) VisitAdminUpdateNodeVersionResponse(w http.ResponseWriter) error { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(403) + + return json.NewEncoder(w).Encode(response) +} + +type AdminUpdateNodeVersion404JSONResponse ErrorResponse + +func (response AdminUpdateNodeVersion404JSONResponse) VisitAdminUpdateNodeVersionResponse(w http.ResponseWriter) error { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(404) + + return json.NewEncoder(w).Encode(response) +} + +type AdminUpdateNodeVersion500JSONResponse ErrorResponse + +func (response AdminUpdateNodeVersion500JSONResponse) VisitAdminUpdateNodeVersionResponse(w http.ResponseWriter) error { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(500) + + return json.NewEncoder(w).Encode(response) } type GetBranchRequestObject struct { @@ -1228,6 +1775,94 @@ func (response ListAllNodes500JSONResponse) VisitListAllNodesResponse(w http.Res return json.NewEncoder(w).Encode(response) } +type ReindexNodesRequestObject struct { +} + +type ReindexNodesResponseObject interface { + VisitReindexNodesResponse(w http.ResponseWriter) error +} + +type ReindexNodes200Response struct { +} + +func (response ReindexNodes200Response) VisitReindexNodesResponse(w http.ResponseWriter) error { + w.WriteHeader(200) + return nil +} + +type ReindexNodes400JSONResponse ErrorResponse + +func (response ReindexNodes400JSONResponse) VisitReindexNodesResponse(w http.ResponseWriter) error { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(400) + + return json.NewEncoder(w).Encode(response) +} + +type ReindexNodes500JSONResponse ErrorResponse + +func (response ReindexNodes500JSONResponse) VisitReindexNodesResponse(w http.ResponseWriter) error { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(500) + + return json.NewEncoder(w).Encode(response) +} + +type SearchNodesRequestObject struct { + Params SearchNodesParams +} + +type SearchNodesResponseObject interface { + VisitSearchNodesResponse(w http.ResponseWriter) error +} + +type SearchNodes200JSONResponse struct { + // Limit Maximum number of nodes per page + Limit *int `json:"limit,omitempty"` + Nodes *[]Node `json:"nodes,omitempty"` + + // Page Current page number + Page *int `json:"page,omitempty"` + + // Total Total number of nodes available + Total *int `json:"total,omitempty"` + + // TotalPages Total number of pages available + TotalPages *int `json:"totalPages,omitempty"` +} + +func (response SearchNodes200JSONResponse) VisitSearchNodesResponse(w http.ResponseWriter) error { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + + return json.NewEncoder(w).Encode(response) +} + +type SearchNodes400Response struct { +} + +func (response SearchNodes400Response) VisitSearchNodesResponse(w http.ResponseWriter) error { + w.WriteHeader(400) + return nil +} + +type SearchNodes404Response struct { +} + +func (response SearchNodes404Response) VisitSearchNodesResponse(w http.ResponseWriter) error { + w.WriteHeader(404) + return nil +} + +type SearchNodes500JSONResponse ErrorResponse + +func (response SearchNodes500JSONResponse) VisitSearchNodesResponse(w http.ResponseWriter) error { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(500) + + return json.NewEncoder(w).Encode(response) +} + type GetNodeRequestObject struct { NodeId string `json:"nodeId"` } @@ -1326,8 +1961,53 @@ func (response InstallNode500JSONResponse) VisitInstallNodeResponse(w http.Respo return json.NewEncoder(w).Encode(response) } +type PostNodeReviewRequestObject struct { + NodeId string `json:"nodeId"` + Params PostNodeReviewParams +} + +type PostNodeReviewResponseObject interface { + VisitPostNodeReviewResponse(w http.ResponseWriter) error +} + +type PostNodeReview200JSONResponse Node + +func (response PostNodeReview200JSONResponse) VisitPostNodeReviewResponse(w http.ResponseWriter) error { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + + return json.NewEncoder(w).Encode(response) +} + +type PostNodeReview400Response struct { +} + +func (response PostNodeReview400Response) VisitPostNodeReviewResponse(w http.ResponseWriter) error { + w.WriteHeader(400) + return nil +} + +type PostNodeReview404JSONResponse Error + +func (response PostNodeReview404JSONResponse) VisitPostNodeReviewResponse(w http.ResponseWriter) error { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(404) + + return json.NewEncoder(w).Encode(response) +} + +type PostNodeReview500JSONResponse ErrorResponse + +func (response PostNodeReview500JSONResponse) VisitPostNodeReviewResponse(w http.ResponseWriter) error { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(500) + + return json.NewEncoder(w).Encode(response) +} + type ListNodeVersionsRequestObject struct { NodeId string `json:"nodeId"` + Params ListNodeVersionsParams } type ListNodeVersionsResponseObject interface { @@ -1343,6 +2023,15 @@ func (response ListNodeVersions200JSONResponse) VisitListNodeVersionsResponse(w return json.NewEncoder(w).Encode(response) } +type ListNodeVersions403JSONResponse ErrorResponse + +func (response ListNodeVersions403JSONResponse) VisitListNodeVersionsResponse(w http.ResponseWriter) error { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(403) + + return json.NewEncoder(w).Encode(response) +} + type ListNodeVersions404JSONResponse Error func (response ListNodeVersions404JSONResponse) VisitListNodeVersionsResponse(w http.ResponseWriter) error { @@ -1643,8 +2332,60 @@ func (response UpdatePublisher500JSONResponse) VisitUpdatePublisherResponse(w ht return json.NewEncoder(w).Encode(response) } +type BanPublisherRequestObject struct { + PublisherId string `json:"publisherId"` +} + +type BanPublisherResponseObject interface { + VisitBanPublisherResponse(w http.ResponseWriter) error +} + +type BanPublisher204Response struct { +} + +func (response BanPublisher204Response) VisitBanPublisherResponse(w http.ResponseWriter) error { + w.WriteHeader(204) + return nil +} + +type BanPublisher401Response struct { +} + +func (response BanPublisher401Response) VisitBanPublisherResponse(w http.ResponseWriter) error { + w.WriteHeader(401) + return nil +} + +type BanPublisher403JSONResponse ErrorResponse + +func (response BanPublisher403JSONResponse) VisitBanPublisherResponse(w http.ResponseWriter) error { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(403) + + return json.NewEncoder(w).Encode(response) +} + +type BanPublisher404JSONResponse ErrorResponse + +func (response BanPublisher404JSONResponse) VisitBanPublisherResponse(w http.ResponseWriter) error { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(404) + + return json.NewEncoder(w).Encode(response) +} + +type BanPublisher500JSONResponse ErrorResponse + +func (response BanPublisher500JSONResponse) VisitBanPublisherResponse(w http.ResponseWriter) error { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(500) + + return json.NewEncoder(w).Encode(response) +} + type ListNodesForPublisherRequestObject struct { PublisherId string `json:"publisherId"` + Params ListNodesForPublisherParams } type ListNodesForPublisherResponseObject interface { @@ -1829,6 +2570,58 @@ func (response UpdateNode500JSONResponse) VisitUpdateNodeResponse(w http.Respons return json.NewEncoder(w).Encode(response) } +type BanPublisherNodeRequestObject struct { + PublisherId string `json:"publisherId"` + NodeId string `json:"nodeId"` +} + +type BanPublisherNodeResponseObject interface { + VisitBanPublisherNodeResponse(w http.ResponseWriter) error +} + +type BanPublisherNode204Response struct { +} + +func (response BanPublisherNode204Response) VisitBanPublisherNodeResponse(w http.ResponseWriter) error { + w.WriteHeader(204) + return nil +} + +type BanPublisherNode401Response struct { +} + +func (response BanPublisherNode401Response) VisitBanPublisherNodeResponse(w http.ResponseWriter) error { + w.WriteHeader(401) + return nil +} + +type BanPublisherNode403JSONResponse ErrorResponse + +func (response BanPublisherNode403JSONResponse) VisitBanPublisherNodeResponse(w http.ResponseWriter) error { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(403) + + return json.NewEncoder(w).Encode(response) +} + +type BanPublisherNode404JSONResponse ErrorResponse + +func (response BanPublisherNode404JSONResponse) VisitBanPublisherNodeResponse(w http.ResponseWriter) error { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(404) + + return json.NewEncoder(w).Encode(response) +} + +type BanPublisherNode500JSONResponse ErrorResponse + +func (response BanPublisherNode500JSONResponse) VisitBanPublisherNodeResponse(w http.ResponseWriter) error { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(500) + + return json.NewEncoder(w).Encode(response) +} + type GetPermissionOnPublisherNodesRequestObject struct { PublisherId string `json:"publisherId"` NodeId string `json:"nodeId"` @@ -1936,6 +2729,15 @@ func (response DeleteNodeVersion204Response) VisitDeleteNodeVersionResponse(w ht return nil } +type DeleteNodeVersion403JSONResponse ErrorResponse + +func (response DeleteNodeVersion403JSONResponse) VisitDeleteNodeVersionResponse(w http.ResponseWriter) error { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(403) + + return json.NewEncoder(w).Encode(response) +} + type DeleteNodeVersion404JSONResponse Error func (response DeleteNodeVersion404JSONResponse) VisitDeleteNodeVersionResponse(w http.ResponseWriter) error { @@ -1945,6 +2747,15 @@ func (response DeleteNodeVersion404JSONResponse) VisitDeleteNodeVersionResponse( return json.NewEncoder(w).Encode(response) } +type DeleteNodeVersion500JSONResponse ErrorResponse + +func (response DeleteNodeVersion500JSONResponse) VisitDeleteNodeVersionResponse(w http.ResponseWriter) error { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(500) + + return json.NewEncoder(w).Encode(response) +} + type UpdateNodeVersionRequestObject struct { PublisherId string `json:"publisherId"` NodeId string `json:"nodeId"` @@ -2129,53 +2940,104 @@ func (response CreatePersonalAccessToken403JSONResponse) VisitCreatePersonalAcce return json.NewEncoder(w).Encode(response) } -type CreatePersonalAccessToken500JSONResponse ErrorResponse +type CreatePersonalAccessToken500JSONResponse ErrorResponse + +func (response CreatePersonalAccessToken500JSONResponse) VisitCreatePersonalAccessTokenResponse(w http.ResponseWriter) error { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(500) + + return json.NewEncoder(w).Encode(response) +} + +type DeletePersonalAccessTokenRequestObject struct { + PublisherId string `json:"publisherId"` + TokenId string `json:"tokenId"` +} + +type DeletePersonalAccessTokenResponseObject interface { + VisitDeletePersonalAccessTokenResponse(w http.ResponseWriter) error +} + +type DeletePersonalAccessToken204Response struct { +} + +func (response DeletePersonalAccessToken204Response) VisitDeletePersonalAccessTokenResponse(w http.ResponseWriter) error { + w.WriteHeader(204) + return nil +} + +type DeletePersonalAccessToken403JSONResponse ErrorResponse + +func (response DeletePersonalAccessToken403JSONResponse) VisitDeletePersonalAccessTokenResponse(w http.ResponseWriter) error { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(403) + + return json.NewEncoder(w).Encode(response) +} + +type DeletePersonalAccessToken404JSONResponse ErrorResponse + +func (response DeletePersonalAccessToken404JSONResponse) VisitDeletePersonalAccessTokenResponse(w http.ResponseWriter) error { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(404) + + return json.NewEncoder(w).Encode(response) +} + +type DeletePersonalAccessToken500JSONResponse ErrorResponse -func (response CreatePersonalAccessToken500JSONResponse) VisitCreatePersonalAccessTokenResponse(w http.ResponseWriter) error { +func (response DeletePersonalAccessToken500JSONResponse) VisitDeletePersonalAccessTokenResponse(w http.ResponseWriter) error { w.Header().Set("Content-Type", "application/json") w.WriteHeader(500) return json.NewEncoder(w).Encode(response) } -type DeletePersonalAccessTokenRequestObject struct { - PublisherId string `json:"publisherId"` - TokenId string `json:"tokenId"` +type SecurityScanRequestObject struct { + Params SecurityScanParams } -type DeletePersonalAccessTokenResponseObject interface { - VisitDeletePersonalAccessTokenResponse(w http.ResponseWriter) error +type SecurityScanResponseObject interface { + VisitSecurityScanResponse(w http.ResponseWriter) error } -type DeletePersonalAccessToken204Response struct { +type SecurityScan200Response struct { } -func (response DeletePersonalAccessToken204Response) VisitDeletePersonalAccessTokenResponse(w http.ResponseWriter) error { - w.WriteHeader(204) +func (response SecurityScan200Response) VisitSecurityScanResponse(w http.ResponseWriter) error { + w.WriteHeader(200) return nil } -type DeletePersonalAccessToken403JSONResponse ErrorResponse +type SecurityScan400JSONResponse ErrorResponse -func (response DeletePersonalAccessToken403JSONResponse) VisitDeletePersonalAccessTokenResponse(w http.ResponseWriter) error { +func (response SecurityScan400JSONResponse) VisitSecurityScanResponse(w http.ResponseWriter) error { w.Header().Set("Content-Type", "application/json") - w.WriteHeader(403) + w.WriteHeader(400) return json.NewEncoder(w).Encode(response) } -type DeletePersonalAccessToken404JSONResponse ErrorResponse +type SecurityScan401Response struct { +} -func (response DeletePersonalAccessToken404JSONResponse) VisitDeletePersonalAccessTokenResponse(w http.ResponseWriter) error { +func (response SecurityScan401Response) VisitSecurityScanResponse(w http.ResponseWriter) error { + w.WriteHeader(401) + return nil +} + +type SecurityScan403JSONResponse ErrorResponse + +func (response SecurityScan403JSONResponse) VisitSecurityScanResponse(w http.ResponseWriter) error { w.Header().Set("Content-Type", "application/json") - w.WriteHeader(404) + w.WriteHeader(403) return json.NewEncoder(w).Encode(response) } -type DeletePersonalAccessToken500JSONResponse ErrorResponse +type SecurityScan500JSONResponse ErrorResponse -func (response DeletePersonalAccessToken500JSONResponse) VisitDeletePersonalAccessTokenResponse(w http.ResponseWriter) error { +func (response SecurityScan500JSONResponse) VisitSecurityScanResponse(w http.ResponseWriter) error { w.Header().Set("Content-Type", "application/json") w.WriteHeader(500) @@ -2283,8 +3145,94 @@ func (response ListPublishersForUser500JSONResponse) VisitListPublishersForUserR return json.NewEncoder(w).Encode(response) } +type ListAllNodeVersionsRequestObject struct { + Params ListAllNodeVersionsParams +} + +type ListAllNodeVersionsResponseObject interface { + VisitListAllNodeVersionsResponse(w http.ResponseWriter) error +} + +type ListAllNodeVersions200JSONResponse struct { + // Page Current page number + Page *int `json:"page,omitempty"` + + // PageSize Maximum number of node versions per page. Maximum is 100. + PageSize *int `json:"pageSize,omitempty"` + + // Total Total number of node versions available + Total *int `json:"total,omitempty"` + + // TotalPages Total number of pages available + TotalPages *int `json:"totalPages,omitempty"` + Versions *[]NodeVersion `json:"versions,omitempty"` +} + +func (response ListAllNodeVersions200JSONResponse) VisitListAllNodeVersionsResponse(w http.ResponseWriter) error { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + + return json.NewEncoder(w).Encode(response) +} + +type ListAllNodeVersions403JSONResponse ErrorResponse + +func (response ListAllNodeVersions403JSONResponse) VisitListAllNodeVersionsResponse(w http.ResponseWriter) error { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(403) + + return json.NewEncoder(w).Encode(response) +} + +type ListAllNodeVersions500JSONResponse ErrorResponse + +func (response ListAllNodeVersions500JSONResponse) VisitListAllNodeVersionsResponse(w http.ResponseWriter) error { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(500) + + return json.NewEncoder(w).Encode(response) +} + +type GetWorkflowResultRequestObject struct { + WorkflowResultId string `json:"workflowResultId"` +} + +type GetWorkflowResultResponseObject interface { + VisitGetWorkflowResultResponse(w http.ResponseWriter) error +} + +type GetWorkflowResult200JSONResponse ActionJobResult + +func (response GetWorkflowResult200JSONResponse) VisitGetWorkflowResultResponse(w http.ResponseWriter) error { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + + return json.NewEncoder(w).Encode(response) +} + +type GetWorkflowResult404JSONResponse ErrorResponse + +func (response GetWorkflowResult404JSONResponse) VisitGetWorkflowResultResponse(w http.ResponseWriter) error { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(404) + + return json.NewEncoder(w).Encode(response) +} + +type GetWorkflowResult500JSONResponse ErrorResponse + +func (response GetWorkflowResult500JSONResponse) VisitGetWorkflowResultResponse(w http.ResponseWriter) error { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(500) + + return json.NewEncoder(w).Encode(response) +} + // StrictServerInterface represents all server handlers. type StrictServerInterface interface { + // Admin Update Node Version Status + // (PUT /admin/nodes/{nodeId}/versions/{versionNumber}) + AdminUpdateNodeVersion(ctx context.Context, request AdminUpdateNodeVersionRequestObject) (AdminUpdateNodeVersionResponseObject, error) // Retrieve all distinct branches for a given repo // (GET /branch) GetBranch(ctx context.Context, request GetBranchRequestObject) (GetBranchResponseObject, error) @@ -2294,12 +3242,21 @@ type StrictServerInterface interface { // Retrieves a list of nodes // (GET /nodes) ListAllNodes(ctx context.Context, request ListAllNodesRequestObject) (ListAllNodesResponseObject, error) + // Reindex all nodes for searching. + // (POST /nodes/reindex) + ReindexNodes(ctx context.Context, request ReindexNodesRequestObject) (ReindexNodesResponseObject, error) + // Retrieves a list of nodes + // (GET /nodes/search) + SearchNodes(ctx context.Context, request SearchNodesRequestObject) (SearchNodesResponseObject, error) // Retrieve a specific node by ID // (GET /nodes/{nodeId}) GetNode(ctx context.Context, request GetNodeRequestObject) (GetNodeResponseObject, error) // Returns a node version to be installed. // (GET /nodes/{nodeId}/install) InstallNode(ctx context.Context, request InstallNodeRequestObject) (InstallNodeResponseObject, error) + // Add review to a specific version of a node + // (POST /nodes/{nodeId}/reviews) + PostNodeReview(ctx context.Context, request PostNodeReviewRequestObject) (PostNodeReviewResponseObject, error) // List all versions of a node // (GET /nodes/{nodeId}/versions) ListNodeVersions(ctx context.Context, request ListNodeVersionsRequestObject) (ListNodeVersionsResponseObject, error) @@ -2324,6 +3281,9 @@ type StrictServerInterface interface { // Update a publisher // (PUT /publishers/{publisherId}) UpdatePublisher(ctx context.Context, request UpdatePublisherRequestObject) (UpdatePublisherResponseObject, error) + // Ban a publisher + // (POST /publishers/{publisherId}/ban) + BanPublisher(ctx context.Context, request BanPublisherRequestObject) (BanPublisherResponseObject, error) // Retrieve all nodes // (GET /publishers/{publisherId}/nodes) ListNodesForPublisher(ctx context.Context, request ListNodesForPublisherRequestObject) (ListNodesForPublisherResponseObject, error) @@ -2336,6 +3296,9 @@ type StrictServerInterface interface { // Update a specific node // (PUT /publishers/{publisherId}/nodes/{nodeId}) UpdateNode(ctx context.Context, request UpdateNodeRequestObject) (UpdateNodeResponseObject, error) + // Ban a publisher's Node + // (POST /publishers/{publisherId}/nodes/{nodeId}/ban) + BanPublisherNode(ctx context.Context, request BanPublisherNodeRequestObject) (BanPublisherNodeResponseObject, error) // Retrieve permissions the user has for a given publisher // (GET /publishers/{publisherId}/nodes/{nodeId}/permissions) GetPermissionOnPublisherNodes(ctx context.Context, request GetPermissionOnPublisherNodesRequestObject) (GetPermissionOnPublisherNodesResponseObject, error) @@ -2360,6 +3323,9 @@ type StrictServerInterface interface { // Delete a specific personal access token // (DELETE /publishers/{publisherId}/tokens/{tokenId}) DeletePersonalAccessToken(ctx context.Context, request DeletePersonalAccessTokenRequestObject) (DeletePersonalAccessTokenResponseObject, error) + // Security Scan + // (GET /security-scan) + SecurityScan(ctx context.Context, request SecurityScanRequestObject) (SecurityScanResponseObject, error) // Receive artifacts (output files) from the ComfyUI GitHub Action // (POST /upload-artifact) PostUploadArtifact(ctx context.Context, request PostUploadArtifactRequestObject) (PostUploadArtifactResponseObject, error) @@ -2369,6 +3335,12 @@ type StrictServerInterface interface { // Retrieve all publishers for a given user // (GET /users/publishers/) ListPublishersForUser(ctx context.Context, request ListPublishersForUserRequestObject) (ListPublishersForUserResponseObject, error) + // List all node versions given some filters. + // (GET /versions) + ListAllNodeVersions(ctx context.Context, request ListAllNodeVersionsRequestObject) (ListAllNodeVersionsResponseObject, error) + // Retrieve a specific commit by ID + // (GET /workflowresult/{workflowResultId}) + GetWorkflowResult(ctx context.Context, request GetWorkflowResultRequestObject) (GetWorkflowResultResponseObject, error) } type StrictHandlerFunc = strictecho.StrictEchoHandlerFunc @@ -2383,6 +3355,38 @@ type strictHandler struct { middlewares []StrictMiddlewareFunc } +// AdminUpdateNodeVersion operation middleware +func (sh *strictHandler) AdminUpdateNodeVersion(ctx echo.Context, nodeId string, versionNumber string) error { + var request AdminUpdateNodeVersionRequestObject + + request.NodeId = nodeId + request.VersionNumber = versionNumber + + var body AdminUpdateNodeVersionJSONRequestBody + if err := ctx.Bind(&body); err != nil { + return err + } + request.Body = &body + + handler := func(ctx echo.Context, request interface{}) (interface{}, error) { + return sh.ssi.AdminUpdateNodeVersion(ctx.Request().Context(), request.(AdminUpdateNodeVersionRequestObject)) + } + for _, middleware := range sh.middlewares { + handler = middleware(handler, "AdminUpdateNodeVersion") + } + + response, err := handler(ctx, request) + + if err != nil { + return err + } else if validResponse, ok := response.(AdminUpdateNodeVersionResponseObject); ok { + return validResponse.VisitAdminUpdateNodeVersionResponse(ctx.Response()) + } else if response != nil { + return fmt.Errorf("unexpected response type: %T", response) + } + return nil +} + // GetBranch operation middleware func (sh *strictHandler) GetBranch(ctx echo.Context, params GetBranchParams) error { var request GetBranchRequestObject @@ -2458,6 +3462,54 @@ func (sh *strictHandler) ListAllNodes(ctx echo.Context, params ListAllNodesParam return nil } +// ReindexNodes operation middleware +func (sh *strictHandler) ReindexNodes(ctx echo.Context) error { + var request ReindexNodesRequestObject + + handler := func(ctx echo.Context, request interface{}) (interface{}, error) { + return sh.ssi.ReindexNodes(ctx.Request().Context(), request.(ReindexNodesRequestObject)) + } + for _, middleware := range sh.middlewares { + handler = middleware(handler, "ReindexNodes") + } + + response, err := handler(ctx, request) + + if err != nil { + return err + } else if validResponse, ok := response.(ReindexNodesResponseObject); ok { + return validResponse.VisitReindexNodesResponse(ctx.Response()) + } else if response != nil { + return fmt.Errorf("unexpected response type: %T", response) + } + return nil +} + +// SearchNodes operation middleware +func (sh *strictHandler) SearchNodes(ctx echo.Context, params SearchNodesParams) error { + var request SearchNodesRequestObject + + request.Params = params + + handler := func(ctx echo.Context, request interface{}) (interface{}, error) { + return sh.ssi.SearchNodes(ctx.Request().Context(), request.(SearchNodesRequestObject)) + } + for _, middleware := range sh.middlewares { + handler = middleware(handler, "SearchNodes") + } + + response, err := handler(ctx, request) + + if err != nil { + return err + } else if validResponse, ok := response.(SearchNodesResponseObject); ok { + return validResponse.VisitSearchNodesResponse(ctx.Response()) + } else if response != nil { + return fmt.Errorf("unexpected response type: %T", response) + } + return nil +} + // GetNode operation middleware func (sh *strictHandler) GetNode(ctx echo.Context, nodeId string) error { var request GetNodeRequestObject @@ -2509,11 +3561,38 @@ func (sh *strictHandler) InstallNode(ctx echo.Context, nodeId string, params Ins return nil } +// PostNodeReview operation middleware +func (sh *strictHandler) PostNodeReview(ctx echo.Context, nodeId string, params PostNodeReviewParams) error { + var request PostNodeReviewRequestObject + + request.NodeId = nodeId + request.Params = params + + handler := func(ctx echo.Context, request interface{}) (interface{}, error) { + return sh.ssi.PostNodeReview(ctx.Request().Context(), request.(PostNodeReviewRequestObject)) + } + for _, middleware := range sh.middlewares { + handler = middleware(handler, "PostNodeReview") + } + + response, err := handler(ctx, request) + + if err != nil { + return err + } else if validResponse, ok := response.(PostNodeReviewResponseObject); ok { + return validResponse.VisitPostNodeReviewResponse(ctx.Response()) + } else if response != nil { + return fmt.Errorf("unexpected response type: %T", response) + } + return nil +} + // ListNodeVersions operation middleware -func (sh *strictHandler) ListNodeVersions(ctx echo.Context, nodeId string) error { +func (sh *strictHandler) ListNodeVersions(ctx echo.Context, nodeId string, params ListNodeVersionsParams) error { var request ListNodeVersionsRequestObject request.NodeId = nodeId + request.Params = params handler := func(ctx echo.Context, request interface{}) (interface{}, error) { return sh.ssi.ListNodeVersions(ctx.Request().Context(), request.(ListNodeVersionsRequestObject)) @@ -2718,11 +3797,37 @@ func (sh *strictHandler) UpdatePublisher(ctx echo.Context, publisherId string) e return nil } +// BanPublisher operation middleware +func (sh *strictHandler) BanPublisher(ctx echo.Context, publisherId string) error { + var request BanPublisherRequestObject + + request.PublisherId = publisherId + + handler := func(ctx echo.Context, request interface{}) (interface{}, error) { + return sh.ssi.BanPublisher(ctx.Request().Context(), request.(BanPublisherRequestObject)) + } + for _, middleware := range sh.middlewares { + handler = middleware(handler, "BanPublisher") + } + + response, err := handler(ctx, request) + + if err != nil { + return err + } else if validResponse, ok := response.(BanPublisherResponseObject); ok { + return validResponse.VisitBanPublisherResponse(ctx.Response()) + } else if response != nil { + return fmt.Errorf("unexpected response type: %T", response) + } + return nil +} + // ListNodesForPublisher operation middleware -func (sh *strictHandler) ListNodesForPublisher(ctx echo.Context, publisherId string) error { +func (sh *strictHandler) ListNodesForPublisher(ctx echo.Context, publisherId string, params ListNodesForPublisherParams) error { var request ListNodesForPublisherRequestObject request.PublisherId = publisherId + request.Params = params handler := func(ctx echo.Context, request interface{}) (interface{}, error) { return sh.ssi.ListNodesForPublisher(ctx.Request().Context(), request.(ListNodesForPublisherRequestObject)) @@ -2832,6 +3937,32 @@ func (sh *strictHandler) UpdateNode(ctx echo.Context, publisherId string, nodeId return nil } +// BanPublisherNode operation middleware +func (sh *strictHandler) BanPublisherNode(ctx echo.Context, publisherId string, nodeId string) error { + var request BanPublisherNodeRequestObject + + request.PublisherId = publisherId + request.NodeId = nodeId + + handler := func(ctx echo.Context, request interface{}) (interface{}, error) { + return sh.ssi.BanPublisherNode(ctx.Request().Context(), request.(BanPublisherNodeRequestObject)) + } + for _, middleware := range sh.middlewares { + handler = middleware(handler, "BanPublisherNode") + } + + response, err := handler(ctx, request) + + if err != nil { + return err + } else if validResponse, ok := response.(BanPublisherNodeResponseObject); ok { + return validResponse.VisitBanPublisherNodeResponse(ctx.Response()) + } else if response != nil { + return fmt.Errorf("unexpected response type: %T", response) + } + return nil +} + // GetPermissionOnPublisherNodes operation middleware func (sh *strictHandler) GetPermissionOnPublisherNodes(ctx echo.Context, publisherId string, nodeId string) error { var request GetPermissionOnPublisherNodesRequestObject @@ -3057,6 +4188,31 @@ func (sh *strictHandler) DeletePersonalAccessToken(ctx echo.Context, publisherId return nil } +// SecurityScan operation middleware +func (sh *strictHandler) SecurityScan(ctx echo.Context, params SecurityScanParams) error { + var request SecurityScanRequestObject + + request.Params = params + + handler := func(ctx echo.Context, request interface{}) (interface{}, error) { + return sh.ssi.SecurityScan(ctx.Request().Context(), request.(SecurityScanRequestObject)) + } + for _, middleware := range sh.middlewares { + handler = middleware(handler, "SecurityScan") + } + + response, err := handler(ctx, request) + + if err != nil { + return err + } else if validResponse, ok := response.(SecurityScanResponseObject); ok { + return validResponse.VisitSecurityScanResponse(ctx.Response()) + } else if response != nil { + return fmt.Errorf("unexpected response type: %T", response) + } + return nil +} + // PostUploadArtifact operation middleware func (sh *strictHandler) PostUploadArtifact(ctx echo.Context) error { var request PostUploadArtifactRequestObject @@ -3132,77 +4288,151 @@ func (sh *strictHandler) ListPublishersForUser(ctx echo.Context) error { return nil } +// ListAllNodeVersions operation middleware +func (sh *strictHandler) ListAllNodeVersions(ctx echo.Context, params ListAllNodeVersionsParams) error { + var request ListAllNodeVersionsRequestObject + + request.Params = params + + handler := func(ctx echo.Context, request interface{}) (interface{}, error) { + return sh.ssi.ListAllNodeVersions(ctx.Request().Context(), request.(ListAllNodeVersionsRequestObject)) + } + for _, middleware := range sh.middlewares { + handler = middleware(handler, "ListAllNodeVersions") + } + + response, err := handler(ctx, request) + + if err != nil { + return err + } else if validResponse, ok := response.(ListAllNodeVersionsResponseObject); ok { + return validResponse.VisitListAllNodeVersionsResponse(ctx.Response()) + } else if response != nil { + return fmt.Errorf("unexpected response type: %T", response) + } + return nil +} + +// GetWorkflowResult operation middleware +func (sh *strictHandler) GetWorkflowResult(ctx echo.Context, workflowResultId string) error { + var request GetWorkflowResultRequestObject + + request.WorkflowResultId = workflowResultId + + handler := func(ctx echo.Context, request interface{}) (interface{}, error) { + return sh.ssi.GetWorkflowResult(ctx.Request().Context(), request.(GetWorkflowResultRequestObject)) + } + for _, middleware := range sh.middlewares { + handler = middleware(handler, "GetWorkflowResult") + } + + response, err := handler(ctx, request) + + if err != nil { + return err + } else if validResponse, ok := response.(GetWorkflowResultResponseObject); ok { + return validResponse.VisitGetWorkflowResultResponse(ctx.Response()) + } else if response != nil { + return fmt.Errorf("unexpected response type: %T", response) + } + return nil +} + // Base64 encoded, gzipped, json marshaled Swagger object var swaggerSpec = []string{ - "H4sIAAAAAAAC/+xdfXPTOLf/KhrfZ2ZhbkgKy90/+l8p0O1eWjqUssNCb0axTxKBLRlJbjfL9Lvf0Zst", - "x7Lj9CU0S2eeeejGsl7O+Z1XHUvfo5hlOaNApYh2v0cinkOG9Z97sSSM/sEm70AUqVQ/5ZzlwCUB3QDr", - "BmNe0DFJ1A8JiJiTXP0a7UaHCVBJpgQ4YlMk54B4QZGcE4G47hFNIGV0JpBk0SCSixyi3UhITugsuhqo", - "iWVEjudYzJudv58DUk9c16ZxRzehGapODl/27iIDIfAMwv3Yh707kyRr6emMkr+ReiwkznJ0OQfqdYku", - "sUAZTiAaRFPGMyyj3YhQ+dvzajBCJcyAq9GAJh1DAU30SG7WX9gEYYHw0hyG/caaETnmkLPwWOqJIJLx", - "BaI4gxBpZnkxNj8ud3BwcobUE1QISEKvhth7Rsm3AhCpcDhlvFynwaC/sqIgwb4V5rEkdDYWCyEha470", - "1rVApkXrNPOFZDyejy+AC/3qck8ni/eqAbINWjsSEvMuEOnnt8hbIRnHMxhPSapH/A+HabQb/deoUh8j", - "qztGp6bta9X0ahBdMv51mrLLsWZ7Y7bHuJqja9pc8VX5C5t8gViqjl9xznhTLSUgMUlFiEvqD5wi0wIS", - "RKhZuqI0nrBC6lmA6hcxjuaESqFBw0Gw9EIxmEhFMSIh0yM0GGN/wJzjhfrvVqWxh+IUMEeYJihmNCYC", - "kNfCkURPZrgGQd6ByBkV0CQMOHo1Ju1NsjkMh28F4ZBEu59sF9UL54FZHLMkMDgu5Lxl9BpdQs/ZJU0Z", - "TkQY6rTIJsbGlA0d8ShLYBjEM4lD0nf27g2SrHz1F4FUu2FfhaNmUzSUTnguVU8pliCkrxK6hEtR94Nt", - "ql4mMVhmN+eSYzl363lzuP/q+PQVUgKMCPXXWGnm4PzCYqv6T4jIU2z0+cpl5sUkJWIOfNUKT8qGCnxa", - "rYaHxxeg9Ayyqjc8AYMO3VW5zD6MX0EUiWcajn0VwVWLoHyouF6Xl3iO6QxSFlj7aZFlmC/Ugk0r4xAY", - "rhLhbEfQ+eCAJSR7soWhWIJWSNpwKGo4O6S8DvtyzWKoN55oMzQICXYONAEal5q5rgBTIqRaRU5y5LdF", - "TuWgyaLG0/5qN4GcQ6ymG3BMaULUI4HItEYxRASqXvT4PmEsBUx9ZXTG02bPn94WMi8kekvTxTmyoHJv", - "1Ify0NquXho/t/oN7z1eVbpngKYsTdml9kwgw1SS2DUjdDZER4VQbrjTWs4/ahHiFSA+yxUa3sG3AoRc", - "C9Fq9uVjawUnatI+BC3Why1Ia2X3n3OQc+C1vlYyOrTUE+BC+Q97cQxCvGdfISS37RJWQ0dY3KTq9LrC", - "VrOiYddniPZQxjhULlDA5dCTULaPSiWSiXJChzd3ug+I3HdR0UqfO2x03lm9METa8HtmR895iPYxVXDG", - "SJAsT2sOVViNOyZ2cMoMllvmI6y57wZULRAHWXCqiFmojg3vwkMGceVbxt5oCgOotLK3AqIbeD2O6+WE", - "huhQ/iIQ/qrMFENYoYorFg7R6ZwVaaIYl7JL4DFuwVvKZqzTdpeD/SKQahzsJQPlEXQaJNvE+UnVGnwb", - "1MuHOQLnfixbKIfwZmjHCh7DOGYJlOF0s1GR54zL4LNLmAgiW7z5dvTZqTYweAOeW0IG2cBZ2uJWqidO", - "rhVIgmxo9Kda9mbLmWrcTY4zESIGZJikLckU9QjhJOFKPxhaEKFXcL0Iokcf7b65Vo6rOggRwI/eG8tX", - "0cNYhRWBzIUXbLggw2YNbm47bEe64z7mQ0MlHhchJ01zOFbeWT+CbAUQiNhLMkJXOrtWouZYIKxeQDkn", - "FySFumPl+btE7OU5Zxc9HGknrQJh+0q4y1sHrdKHEBecyMWpknTDoReAOfC9woB1ov/rtQPOH3++jwYm", - "261np59WY82lzKMr1TGhU6OBiVQCEe2zbLpAeyeHkeeNR0+HOzZTSHFOot3o1+HO8Fk0iJSs6NmMJhzT", - "WE9lBjLk2igfQiCcpsg0BYMdjGbkAqgOSBU1bDaS0cMk2o0OQL4wHauxOM5Aatv2qS0Fq0R0SlIJHE10", - "aKtAE30rgC8ixxodKo9tlrZK/0hegKUZNguYYr0xEMWKKJgyushYIUaaRmeHAdadq/5MfkpT5dnOjnZz", - "mHI1NVVwnqcKVITR0Rdh3JBqyLoMOjrdLBC/WvZ7oj2KdGtlhMoxrgbR853nIcblDFEm0ZQVNFHN/scs", - "allWpHJ3UiSAXwA3eT0DXRPOGwxwAhegQZAQIQmNZTsa9NujGZF2s6EPtHhBxQB9YRPz/zYLLgbai/S1", - "rEBYCBYT5T+iSyLn5dBmtCAUD8q59EDj8s6LRibIeK7cWqxW2wZP0/4wiXw0NrREaEy2nKivpEFNZP/Q", - "jN0uGWUHp2YrYO0Z0EC6e+1ZuBePjYiuOQUDKDeJEj9rz2LiFM+a4+cKYzZrK5mKnjTq24bJjQcR0DtP", - "m5ndFqKXKWKtJtSohMZpkei4Tk+oa/RT8k/bDHb6TmEt3dvg6w9XtV/cNnBd2XY528v7x4EYSDKJ02PN", - "nLfTEzwDX4N7BO2lss1DpBaICVUy3l/ZtWp3k7G4bf3uZMvX6FZxap1OWWJI0a3PFXAJ1RraRa76TYRj", - "zoRR+GXQJJoa+w0Rci9Nj/VoKzT2iSe0Xu5S6JHvRnKPy9HMSEZXFJyWQtsybkqMCeorsLcrKmb0BteO", - "8N8kKzKPhmZV3lqa21QlEHpJnN55C4hZHtyA3C84Byp9dRycgxbSgL+ufm6sBl9gkuJJCu1dlXLe3Z+a", - "Vnd/vfRCm4wYiQ8K8gVOSYIIzQs5cFqFmF9b9cRxQEH0xk8XT+t7uoEVrqN4lM6oU8HtZH2KjBY4r/TP", - "6Lv65zC5WqmIlDKwu+6qZ4xEDjGZkrjcSGh4isdm32NJ5Whh1umFUpbNHDrjkNu2fKtFrMkF9bujgQHJ", - "r5uDwGvGJyRJgHrw3MzIetl0K7C/DEs0WaDDl70EYESokDhNuwTBypezi5V9t+/q9Q8QkHIvyuz7I+0C", - "lDOzKYWmzByabsJyc40yhNsRtYbNPl1aiD9szdlHh1PEMiIlJAOfHN4OndtWafOUq31uzxRjqSAR7Ub/", - "9/lz8t+fPw+9f/5zB25y71qNNp2hcFJuIIlC7y5NizRdDD0TtSlxqpk+UcRzUzI2wUnJGJN7Hf60Sm54", - "X7WcjQm0pDlmSYYm4FSQzceuVnf2bT8CaQYOHrjFfbXkvb1mr6JqKV3Y4MYb6z+p4MqntbgTXG6p0dVE", - "UhRyxDF+oat2sRAs4dOJwtF3+1fdFQ26lB9Kk3AneBx0ld34BhY9UgEBRmdnhy8ft5jbclX3yrntsFcv", - "28tml32ri0qeNicSjhHb5496EFolJFU+p1M7n1TNNqFDazWbqzSoV/hXzXLTvs4LnOgyQxBy4KJ74/ho", - "l+yeQ6eW2PPA4rH9/GoQ5UwE8LGvq4Iqlhn1A0K+YMni1lbsQaJeRa403FUDlE/vauDA5r8uTbG1UTWX", - "+z6i8LkhzXKlhKmnJ/9A8mM98fsiJHb/Xxt7f+f/07kymZUMGfArNQuXlQy1iVBd4440f7AprQomAvbn", - "EH8tyyGqOihX66aLI1IOOFkgib9CIND/YMfw5XNluB8YSTLkptsWQLu2G3VAlorLxF6Z4W2mhHkBfmlJ", - "SUH3ygBNcSoAMTkHfklqZYMdpbxNYJ253i3JlDHm5bbVfQjCMyIEoTOPDqYoTrP0fsakDsmKhbhNGPz8", - "fg8J/F7+baOBBFIwAlkXo5f69w4hCnjkXt83FInnLWVnev1myiHLs8GkRjWde+ku91Xnhs0+vNq9obbA", - "8UeBZGfTLo9Lvj5Ab91IrVJey7sGyx53EcCY+TBmgzD74d78xqFdaBL/a7z5B2lc1xAYGetjCLo8ikbl", - "TTjvLV4zvh1W4wZ1I6tT3/cuazPcRuTWsjptVRnduZze1RT3zKpU1RSbSw91VnBsQ1Jo2NuObHVuJi6E", - "ZNlyNtzftexW4rXype4o8c7FZ3B3G6LBWjRdi9QWZD4UJm1XZFvbWQubhvagY8ug/SNNzIaKBP9lkcqD", - "Ntmu8GiVNlnDsI5y4Dot3VUqdADypGz2lpZRU8tHB9thdG+y4xFj+ioh/pf7a+1UeJvnHvUfds/XSOd5", - "hKt/Ee1/jnMLaYRwQV04jrNdr1vDtM1Wuy4W1J4a1ydPodpe87w0d4zL2BzjMi4PgOk+984qy9rAbZ2d", - "B6X4dgPcJu2uSQ9BZhTCh1npYyT1Y3eSVZHbc6xgucrU7mSv/lK/aVrhsuzHoXkr4u+Hgot+nofVbDaq", - "v1alWz/lulwnujrm3ypNG+7otopIA2mED+4IWlrJ5SObVnh89zuIITh+aFZ4ruUEu5W4dTy+RhFmGW4v", - "RUbGwWY0NUcWVsfZYZp4h80hIbEslr+Ta4zdrAuq4vmfG7Z3kyIIHmH4A9IGHeXXpTxuQQJh+JBBWKm5", - "tjaJEFZtanW+bqt9B7C+kb9JbuFe78g+ZAN+umyAjo9WfK/RPFtWbH9hQejE3DXrDILnrv5cn9jbNW+v", - "3ah/txJkqJW2kJxp2KAjTPEMMrXelV+2BFC3jSV3IeG561RSy4HI+hAtuEwXZYVE+Dzk6+WADIe3qPbi", - "IfdznY9tQojplvMeVnX0Xf/b72uADSuGcOxr53v7CRtDwJ+88MMQ4V9W+XFN0TH5+ieYSzLFsezYfWJC", - "nunGe67tbe3rmEMix62nnU+K+CvI8Yozge239aYxupwDN8fdM3NOvz2wlJuDosM3ZumDE8cpm4nxLBYt", - "h1j7t+bYoxWRemWIYDZEM7E7Gpk5PFHTGqlHPa6uu9aFcvZQTtfmOjfJ+dd/2e7sV2Pm+3317HP08ePH", - "j0+Ojp68fPn+9993j452T0//+hyhR892nv725OnOk6c773d2dvX//nocnEeR4PYbzfaLBPvHGjVe3+QV", - "dV/YZLz2VXHBO+FEz1vgHMV5EbyQxwBYX6tW4TJ4N4FiIEYClLkyeWRubkjB5ZmXAh3sn2r8PhKPFYR9", - "8WiBsGkyaHvwLHh0f/CSv3erL/hruyqynfwtZNv09Xcr7rBrO9g36A77+8qakiUsSwJpfHn3IdY1Sl3y", - "G9pkUNO5y3Ov0c4TvuvtWt8k9dV56dzKmOHUc28QhxjIBZitaWfu6sfxdZz0yF26/0ZHu+oZlIML9MgX", - "vsdoyllmzlW2ZuWAyN+LCTKn5JpjXwvRdazHAUh9FcEd7kDYazEaxH77v+t+u9Q8HfP1+vt1ByBbrmiM", - "cZq6r6GHHvH8aKHn+SivGb8Nsj4ck3Lfj0mpJXQL0ZHL1RhV3ZuIUN8ioq+EUDYS52Sofckh47Po6vzq", - "/wMAAP//lLylj0N5AAA=", + "H4sIAAAAAAAC/+w9f2/btrZfhdC7wFY8x053+4aH/Jema5e7JQ2SpsPumudLS7TNViI1kkrqFfnuD/wl", + "URIpy3Hi2quBi9vMpA4PyfObhzxfophmOSWICB4dfYl4PEcZVH8exwJT8i86uUS8SIX8KWc0R0xgpDpA", + "1WH8kU7GOJE/JIjHDOfy1+goOk0QEXiKEQN0CsQcgY90AsQcc8AURDBBKSUzDgSNBpFY5Cg6irhgmMyi", + "+4EFzwrSEzwryCrgCzGnrA333RwB3WbhxjTLsPDCuJ2NbxnMAlBuEYMzBN5fHp+BgqMETBZ2HSpomAg0", + "Q0yCi2k2XagJT1M4436oqpOaq+o0BD8NZ0Pwn4ODlN4dSGT+48NUz2E8h3zuBytblk/YgPHthwRy+qo3", + "iAxxDmfID8c09gYmcBaAdE3wZyCbuYBZDu7miDggwR3kIIMJigbRlLIMCr0hP77w70+RwPEtYlzBbw53", + "cv3qGJhWtd0+hBFJOrBFJFHIugwDOYCNaQz7oTvDYsxQTv1jyRaOBWULQGCGfMj6tvma4D8LBHDFfVPK", + "SmQ157noFQX2LoSUGoLh2QyxccFRgBFlC7ibU2C6osQONfQBzWA8xwSNuYBanP2DoWl0FP3XqJJyIyPi", + "Rme685Xqez+IpGiDApPZmC+4QB6mfmt7AN0juMs5gp86BINs7i8VcjYmRTYJLVFepClg6M8CcQFMRx9K", + "CzGnJEy9F4t3lMXzpQScL4Tstz4gLiDr4lzV/ojcIGmiWEoUv1H2aZrSu8uCXOkP1KdUCvLxFKdoGYAr", + "3fe17Ho/iO4MvLHisdZEz2E1Pdu1vVj35S908hHFQgL+iTGtveo6OUEC4pT7aFf+AVOge6AEYKJXTW4S", + "nNBCKCyQhAsoA3NMBFfMzRCn6a0keyzkYmOBMjVCa0/ND5AxuFD8GBLyxyBOEWQAkgTElMSYI+D0sEui", + "kBmusCCXiOeUcNReGGTXqy00KiTbw0jGwgwl0dEfBkT1wY0Hi5pQaSER58U4hjmMsVh4qJ4KmIKTi2tA", + "tZIy4swr6hLMPy2FJTsB26kP1FlejPWPPpZ8c3ENZOsQoNkQnL8/fXV6DN4hnkLwy/8eejUIwQLDdBzn", + "hceC041qxvAW4hROUgQmaEoZKhleiQE+7AIuZxmGrtZgHfBeMW6hSyH+EOBWUy0XC13blaGMssVSMtDd", + "+uw/5WHBLimANjWg6a0p4npSEFGAXzEpPoMfDoeHL7waBOfjKUPorwCV5TgHuh3QQuSF1/CTylUpjzFH", + "rJR9NVBSc+hGuZRG3cKZO3HLtz55ck4TjxipvIa2LQoFmlG2CJjtptVuK6FJgLHdTz3DJPSOpBQmAfdA", + "WwBylLJjYEhHN+LYt+HXl78CQctPv+NA9hv2tRSV/dayFpdNP4UCceFSYZe2lZv03nSVH+MYGenvoSso", + "5nY+v56e/HR+9ROQGh1g4s6xsou9+PkZVsJPMM9TqK3ppdPMi0mK+VwbdV0zvCg7Sm2keK/b2zT86UfA", + "2IdKsdlp9tn4JYvSz7aSm+UaVfKvsTZHgiagKLixWJYuqTAuc1/zJMT0V+VkECkyqfirX49jgW+lr1T9", + "9AqlSCgTt/rtJSQEJY6BUOHgUmzbPphDMkMp9WzxVZFlUIsP3Us7rpp4MbeC2OskMwQFSo5FgG6hQMoQ", + "U7a2XGJrukvv2HxcM7LlFwfKcvfKrxyRBJHYK5WPQYq5kLOQYt7tC6ypZf0hu9H9zc0E5QxJQesLF5EE", + "yyYO8LS2YgBLCrMfOpQ1oTRFkLgy95p5KPWPt0pLgbckXdwAwzv2i/pQDgmHpega3GUIq8VkDEEeUum6", + "rfTiDc9pGvPyWaeBUC5qKfEHYErTlN4powFlkAgc226YzIbgrOACTEpdYREJ8HmIa+tTbzBvrbHGw7WW", + "OivXmgxHe1ouEEkkZp6m1ymczZbLgetcMtSlduFXEgrKsLDNxoGayIV2udiIi4CxEeaY3+ZIzBGrwVrK", + "K77tuUCMS9fzOI4R5+/oJ+QTfWEhVWMwv8QSEuhD5VXN3vJ7zUNwDDJp2pfes8dbVUhIK4kIKdUSUHA0", + "XD+u9gaLExsAXRpW85snl0a0DoEyER0DReE8BCeQSBaEgOMsT2u+uF/X2k3s2Ck9WG42H0C1+3ZA2QMw", + "JApG5GIWErDeO/+QXrpybaje1OQnoNIeexQiWsM+trteIjQEp+I7DuAnqekpgComKrdwCK7mtEgTuXEp", + "vUMshgF6S+mMdlp55WDfcSA7h5zNCWKdOt10sRZ1NQdXjfeyds+QNVSbSt5SeFtN0oLFaBzTBJVh7wfq", + "0hINR5MWeU6Z8AK9QxOORSCCFCZbM8cW8a5BLGYHvPvHaBrwXGSLFQgq4u7bvxY8G7XvtZDXsnP3crRV", + "d6OpVNyN3zus7fr47aBgFvQ9VBOAScKk1NILjblanod5wD1ghH1LJbKXAfCtrhuObk1fer9j6RZ7oviO", + "s2ydZBMGX1+jGUAKcB+lpugwHhc+61vtcCzN7n4LshOEgPlxkmGy1Isx7DqHHED5AcgZvsUpqpt7jiOD", + "+XGeM3rbw0OyooADaD7xg3wCom0fwjhCodV4JSDTVnur6bWy1XwtJ1TaOcIrNqS0R3HBsFhcSTmmSeQl", + "ggyx40Jzy0T912tLuf/67V000MkTanlUazXZuRB5dH+vgspTrZiwkBwZnajT/OOL08hxrKLnw0NzIklg", + "jqOj6J/Dw+EP0SCSzKqwGan9HkkniY++yH9Ok/uRgcBHX8xf5yrec68IvvAYQ8oMU6A4iCGxWw2gcr/K", + "6G5Uno5ScppIlS8/0X6LG9CQCDKYIaHshD++RJKGFdKRJZRI4xq5RyuCFcisHvQqUS+g2hRXgnejOyMu", + "XtJExb9iKm12tUAwz1PJB5iS0UfjNVeg6mJjqzzzNifVD7Dkqqgf9BGZQvuHw8OVpt87FnvfNIqj9/Y8", + "WFFNAnihHIJpkabKrnvxiKjUzwI9yLyEiT0zHwBMbmGKE4CJ9F0SKOBQI/Tcp9T0CQD+CyW60z83h/Vr", + "yiY4SRDRI7/Y3Mh28wgVYEoLoub+P5vcsVMipNuTAo7YLWL6aLgmq5XAcaX0HzeS07kOnFqhBbTUApJa", + "gZ2WYUsbQ/7DzpdHN3KE0YRBEiu5P0PC515LP5YDmKZAd0XaUoBghm8RUeHzthB9g8RLDbglN/3pOtIg", + "m+JUIAYmKhCvpOKfBWKLSizKjmOT0RMWiQmaQpXUp5POIKFkkdGCj5RCuj71iJebNWVHXXTadVovYt+m", + "k2MCVG/pz5RjVOzS3Licemm6N+2VxHWJBMNIqs5UnTkLTGIRpgb19WiGhclt60NarCB8AD7Sif5/k2zF", + "ByqS4drUHEDOaYyVnL3DYl4OrUfzkuKbEpce1NhM9FOUiUQ8V+JTzjZEnrq/MgA6FX6Ps+eKGyQiJ6d6", + "7DBnlACudH7XyhgQT7bOyljYD881i66IgiYoi0RJPytjMbGCZ8Xxc0lj5oxZUMAM1YeGybW/6JE7z9vn", + "0IFFLw+0lZiQo2ISp0WiYosKoa7Rr/BfIQwO+6Kwkuxt7etXF7UfbQp3Xdh2qeNm7rcnDieogKk2vt9O", + "L+AMuRLcWdBeIls3AjlBiInk8f7CLijdddT8seW75S1XohvBqWS6csmWy3NJuJgoCW2jp+pLAGNGuRb4", + "ZfyNtyX2r5iL4zQ9V6MtkdgXDtM6R5Bcjfw0nHtejqZH0rKiYKRk2sC4KdYq6MEM++CRjVQZT+xRX0s0", + "Vqdcj8uges4tWjmDn3FWZM7O6Rk582in8pTk14vPVZKTh7lzb9bmScEYIsJVAl4clGgIJaM1Z1Nmz4VB", + "ldKlG55EqxteL2kU4kzHSW2KD8d/HFhZZrzKoHQ630ZXqiXupKSqr0LlJGnZc1NJvRFDmCTos6Jwqs+y", + "61LrUnewUsvHQ015qT6QElbH62pxg+HXDBwMt3Tj9IJJ/aEZTCoqjiCL55jMhp37p7s9vfK6UuPsdVdz", + "5F/Q4o6yRI6ot6Kab2BIs2Mr2fF7DbnXkHsN+XU0pD2qWSpjJeObyzwSMgQ8RzGe4rhMjGtFcM51WuGT", + "HMHcPPHBgW8XVIjWXmj6dmLtatpkJ2i/SZZgsgCnr3oxwAgTLmCadjGC4S+rAyu/23yr5j8ACJd5ivr2", + "AFCueYlZ8BTzVIPx880DLjM82mlnI/W7MRF32FoQDpxOAc2wECgZuMvhZG/alLtQBKtKI3dUMRSSJKKj", + "6P8+fEj++8OHofPPP54gfLXGKeN5SSdlcuHXdRgaqo8X8VxfYp3ApNwYnQEz/GaF3La6UtbdcZMhJMdN", + "kBVBJitmubhj6BajOx52jS8oVxr8UnV8wlSK+kJUxiEXkJl4pnP5x7lY4vU/BOyVe1EZmF/DkngVvvPc", + "0GBBE1Y6/jZF/ymYJcgkt9t89O6crSdAk7hOUuYenQXLiy/L2cVmMjnmQTv+7egC/tS5R226FwVXxn71", + "bW+HspUN1DjnXpdHVkXEg0KLCn41Po4NLIFyizattRRbmBjERllxu1lQbZDcHbsxXp5rZLYsSyCsu6pe", + "l/Ppk//C18xcAxx8f06lPL++Pn31LGCOl7PaKue3w57tr7lKRb3XTr381S7F1GCSKpTdqY4uqm6bkN+1", + "m+HLpLdz77bCcsuyLrecdGpnGg6xONt+cz8IWPcn6kZZtWUPz0DuSRJ9cn+fP9XAnisa6naSuVe31bm/", + "O5D6u2uJt5r4pZhFdxUPhVioLnFHan+gvl3nDRSezFH8qby0Ul2Fs/ck1RWWlCGYLICAn5AnEPjejOHy", + "59JwoGckQYFFNxRgs303aoA07hfy4/IEqH1kxArkXgAqV9B+MgBTmHIEqJgjdodrV047roG3CevaQjdL", + "JpUxK9PNtiFIl2HOMZk566DvRaot3c6YlaVkuYUwxAzu+V8PDvxS/m28gUS9k9BWcvr9hA4m8ljkDuw1", + "WeJF4HKgmr9G2ad5Nhj0rNDZ6XsUeptd8gpbQyHH8WsRyeGmTR57OLMnvVU9tUp4NU8VmxZ34aExfc1n", + "g2T21a35jZP2LtzkW8Ga33PjqorAXKWDK9v0NYtiNIEkfCz2EpKtNCj0oxPgqkX7+0uju6ZzXkJSo+Fu", + "am3e7/AfS/HXlG2GbnczvXON3MvlR1N8K++T75p0r0U+Q5mN3fHO3hmJW2Z5VXkEmwuhdmZB7kLgtP+j", + "CTsdv4wLLmjWlcrQrTpqKcDdkZQnZ5/B06UHe/O5VT5vKBCzT+7drehPPW/KqxrCjvmOkfbXVDEbSrT/", + "m3nze2myWyGEZdJkBcXaP6Dwd1Sw+8DECoEJysAOZPY1AhTfcWAk+Qo8kSOmjjO7cmrfIHFRdntbZxO+", + "q3yyzkl5DMlPCRbeCEivi5Jl0pWz+vusqxWOgZyFq7936j6/sl74uSPzPHBTQ3+9au7rLluydbYgpjRS", + "n9id7PvAaj726fixfjp+XD46312mzRgQtYFDwG4e9Gbm8zXX7oHrwfGMIH8NElUtRzXbAiRFbsqPoObt", + "JZMB1ev10IZ9ge5KOJaadyImtU/U62eNG8lmIl0PypDuJ1yb9wuWx8F2StJ2Plj8FJZ/+bguqfjyexNq", + "e/Z1Y23ep2M3ckHh6z9a68FgJefY7qbdy2cPuMAw8L/8bRxvSlJdbasqIwRJ4hT5sc9O19+gaI3dzql9", + "2Ovgfz/WfZrQobd01P7B7/2D3/sHv1vBRb9ok7NzZVvjsvuqhs468ZWtzoHdR0S+uYiI8hGX3HVs1/Tj", + "W03G/a5UeioVrph/5K139209X2XmvLt6o37n07uhhtt8fKbIBpxBAmcok/NdeivUQ3W7mK7uY56nDqcF", + "ClGqh+PRXbooM6f8dSgfFgfTO7xDOVn7+NdDLqr6KKabz3to1dEX9W+/m3QbFgx+39fg+/hBK72A33hC", + "mF6Ev1lG2ANZx454wGOdweG9731RpKnRzKrydv01IuXjxZQkRSyABQgkQO871LbMHwxxViNHPsPkeNZV", + "t2UQfT6Y0QPzo8AZGr4q9JDh16QyWD6JvvrLaY3HImNIAq+l74Muu6vBSmazJAsUzSq20Ud9B5AJPIWx", + "6H5i8Fp1PrZ9H+tIWO+X3xDTbfWCUb6qp/B2Nr5lMAtAuUWq5gvMaEGUu/X+8vhMur+JvZfPCteic97I", + "1tWOxsHS0ZMi/oTEeEkpUzMB3RnczRHTtcOpLnpuKm8xXd8WeSvZqgpA45TO+HgW80DtXV1nSdffNTWC", + "gPxkCNBsCGb8aDTSOBxItEayKTwWK8h4miq56xtHNdUWUX0mlzIAM8NiPId87l1I054hzr3vo7+rKobZ", + "PuFRVLF1LwhVtr1egMxgrx/Jkm0fot9///33g7Ozg1ev3v3889HZ2dHV1b8/ROD7Hw6f/3jw/PDg+eG7", + "w8Mj9b9/P/PiUSTQPatvvvWeQPdt4dbniCQdc0C2/LyZx0c60Q/DXhP8WbVwAbO8VoMeE/HjCy+Ff6ST", + "8WpFmD/SiQ9pCUgwPJshNrYVvj1PIXPEwN2cAtMVJV0wMxjPMUFjLqBYGoM5052vVN/7QUQ9lPu2WQqu", + "IQV8SGguHSsuLZnPW81eUhQEHElzQJ+ASRhAzGFZoYqDNydXikm/588kn7oyIMCnussg1PCDt/I1gp86", + "ZKJsfpBAzNnYVCfww5UmllEMrTIGDnYLMackzCEKlOpTHhj22Kh8ISiL50vByk6rwFVlF70lIDkWlC2A", + "eZan/WFBVuStAAZcQNYl1lT7IwqFfuWJ29W27wdllcKeOtF298ZT3OQstQmluCrXVrF5Xb3U1UBLtdSV", + "ehPh2lo7ktil/Bb9eoTfwBo25WLebKDKct20cpTpA2JVbh44YChG+NYIa2sv1kssdFTvYNXzx2uU0VMY", + "lINz8L0rOp+BKaOZrmFpLJ83WPxcTICuSGgMXt71FOMbJK65eWLviU6+FXzPYr/9ZdX3JtoVT16vnify", + "BgnP+6DKOIJpal+wGjqL50aper5p+Zqyx1jW/dOW2/60Ze0g0UjB4Blir3e6TZ3KJU91N2IiZe5OZ7wS", + "fc5TlY2sI5Ibf6pb6oVFKn+Q7Bd9m0VzHzfFYL3KWiX6PYuFVYHDcmGA7Yk5eH54OFy3gpcTm9xgJa9B", + "5PLmI7wI38PY2No34rf3sfY6gWihy2mGTLFpPgznIlmTVz+aOfpi/1vXbF7ycvtvtc69jpWa8LfmKb5W", + "ter2npha0DVTd4OnPB2lqLf7lXQTXdMv8N1rq1QC0ERSsDQ6iuZC5PxoNII5Hqro4ZCyWXR/c///AQAA", + "//8uUHTnXaoAAA==", } // GetSwagger returns the content of the embedded swagger specification file diff --git a/ent/ciworkflowresult.go b/ent/ciworkflowresult.go index 597504c..fdbdcec 100644 --- a/ent/ciworkflowresult.go +++ b/ent/ciworkflowresult.go @@ -3,10 +3,11 @@ package ent import ( + "encoding/json" "fmt" "registry-backend/ent/ciworkflowresult" "registry-backend/ent/gitcommit" - "registry-backend/ent/storagefile" + "registry-backend/ent/schema" "strings" "time" @@ -26,26 +27,39 @@ type CIWorkflowResult struct { UpdateTime time.Time `json:"update_time,omitempty"` // OperatingSystem holds the value of the "operating_system" field. OperatingSystem string `json:"operating_system,omitempty"` - // GpuType holds the value of the "gpu_type" field. - GpuType string `json:"gpu_type,omitempty"` - // PytorchVersion holds the value of the "pytorch_version" field. - PytorchVersion string `json:"pytorch_version,omitempty"` // WorkflowName holds the value of the "workflow_name" field. WorkflowName string `json:"workflow_name,omitempty"` // RunID holds the value of the "run_id" field. RunID string `json:"run_id,omitempty"` + // JobID holds the value of the "job_id" field. + JobID string `json:"job_id,omitempty"` // Status holds the value of the "status" field. - Status string `json:"status,omitempty"` + Status schema.WorkflowRunStatusType `json:"status,omitempty"` // StartTime holds the value of the "start_time" field. StartTime int64 `json:"start_time,omitempty"` // EndTime holds the value of the "end_time" field. EndTime int64 `json:"end_time,omitempty"` + // PythonVersion holds the value of the "python_version" field. + PythonVersion string `json:"python_version,omitempty"` + // PytorchVersion holds the value of the "pytorch_version" field. + PytorchVersion string `json:"pytorch_version,omitempty"` + // CudaVersion holds the value of the "cuda_version" field. + CudaVersion string `json:"cuda_version,omitempty"` + // ComfyRunFlags holds the value of the "comfy_run_flags" field. + ComfyRunFlags string `json:"comfy_run_flags,omitempty"` + // Average amount of VRAM used by the workflow in Megabytes + AvgVram int `json:"avg_vram,omitempty"` + // Peak amount of VRAM used by the workflow in Megabytes + PeakVram int `json:"peak_vram,omitempty"` + // User who triggered the job + JobTriggerUser string `json:"job_trigger_user,omitempty"` + // Stores miscellaneous metadata for each workflow run. + Metadata map[string]interface{} `json:"metadata,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the CIWorkflowResultQuery when eager-loading is set. - Edges CIWorkflowResultEdges `json:"edges"` - ci_workflow_result_storage_file *uuid.UUID - git_commit_results *uuid.UUID - selectValues sql.SelectValues + Edges CIWorkflowResultEdges `json:"edges"` + git_commit_results *uuid.UUID + selectValues sql.SelectValues } // CIWorkflowResultEdges holds the relations/edges for other nodes in the graph. @@ -53,7 +67,7 @@ type CIWorkflowResultEdges struct { // Gitcommit holds the value of the gitcommit edge. Gitcommit *GitCommit `json:"gitcommit,omitempty"` // StorageFile holds the value of the storage_file edge. - StorageFile *StorageFile `json:"storage_file,omitempty"` + StorageFile []*StorageFile `json:"storage_file,omitempty"` // loadedTypes holds the information for reporting if a // type was loaded (or requested) in eager-loading or not. loadedTypes [2]bool @@ -71,12 +85,10 @@ func (e CIWorkflowResultEdges) GitcommitOrErr() (*GitCommit, error) { } // StorageFileOrErr returns the StorageFile value or an error if the edge -// was not loaded in eager-loading, or loaded but was not found. -func (e CIWorkflowResultEdges) StorageFileOrErr() (*StorageFile, error) { - if e.StorageFile != nil { +// was not loaded in eager-loading. +func (e CIWorkflowResultEdges) StorageFileOrErr() ([]*StorageFile, error) { + if e.loadedTypes[1] { return e.StorageFile, nil - } else if e.loadedTypes[1] { - return nil, &NotFoundError{label: storagefile.Label} } return nil, &NotLoadedError{edge: "storage_file"} } @@ -86,17 +98,17 @@ func (*CIWorkflowResult) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { - case ciworkflowresult.FieldStartTime, ciworkflowresult.FieldEndTime: + case ciworkflowresult.FieldMetadata: + values[i] = new([]byte) + case ciworkflowresult.FieldStartTime, ciworkflowresult.FieldEndTime, ciworkflowresult.FieldAvgVram, ciworkflowresult.FieldPeakVram: values[i] = new(sql.NullInt64) - case ciworkflowresult.FieldOperatingSystem, ciworkflowresult.FieldGpuType, ciworkflowresult.FieldPytorchVersion, ciworkflowresult.FieldWorkflowName, ciworkflowresult.FieldRunID, ciworkflowresult.FieldStatus: + case ciworkflowresult.FieldOperatingSystem, ciworkflowresult.FieldWorkflowName, ciworkflowresult.FieldRunID, ciworkflowresult.FieldJobID, ciworkflowresult.FieldStatus, ciworkflowresult.FieldPythonVersion, ciworkflowresult.FieldPytorchVersion, ciworkflowresult.FieldCudaVersion, ciworkflowresult.FieldComfyRunFlags, ciworkflowresult.FieldJobTriggerUser: values[i] = new(sql.NullString) case ciworkflowresult.FieldCreateTime, ciworkflowresult.FieldUpdateTime: values[i] = new(sql.NullTime) case ciworkflowresult.FieldID: values[i] = new(uuid.UUID) - case ciworkflowresult.ForeignKeys[0]: // ci_workflow_result_storage_file - values[i] = &sql.NullScanner{S: new(uuid.UUID)} - case ciworkflowresult.ForeignKeys[1]: // git_commit_results + case ciworkflowresult.ForeignKeys[0]: // git_commit_results values[i] = &sql.NullScanner{S: new(uuid.UUID)} default: values[i] = new(sql.UnknownType) @@ -137,18 +149,6 @@ func (cwr *CIWorkflowResult) assignValues(columns []string, values []any) error } else if value.Valid { cwr.OperatingSystem = value.String } - case ciworkflowresult.FieldGpuType: - if value, ok := values[i].(*sql.NullString); !ok { - return fmt.Errorf("unexpected type %T for field gpu_type", values[i]) - } else if value.Valid { - cwr.GpuType = value.String - } - case ciworkflowresult.FieldPytorchVersion: - if value, ok := values[i].(*sql.NullString); !ok { - return fmt.Errorf("unexpected type %T for field pytorch_version", values[i]) - } else if value.Valid { - cwr.PytorchVersion = value.String - } case ciworkflowresult.FieldWorkflowName: if value, ok := values[i].(*sql.NullString); !ok { return fmt.Errorf("unexpected type %T for field workflow_name", values[i]) @@ -161,11 +161,17 @@ func (cwr *CIWorkflowResult) assignValues(columns []string, values []any) error } else if value.Valid { cwr.RunID = value.String } + case ciworkflowresult.FieldJobID: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field job_id", values[i]) + } else if value.Valid { + cwr.JobID = value.String + } case ciworkflowresult.FieldStatus: if value, ok := values[i].(*sql.NullString); !ok { return fmt.Errorf("unexpected type %T for field status", values[i]) } else if value.Valid { - cwr.Status = value.String + cwr.Status = schema.WorkflowRunStatusType(value.String) } case ciworkflowresult.FieldStartTime: if value, ok := values[i].(*sql.NullInt64); !ok { @@ -179,14 +185,57 @@ func (cwr *CIWorkflowResult) assignValues(columns []string, values []any) error } else if value.Valid { cwr.EndTime = value.Int64 } - case ciworkflowresult.ForeignKeys[0]: - if value, ok := values[i].(*sql.NullScanner); !ok { - return fmt.Errorf("unexpected type %T for field ci_workflow_result_storage_file", values[i]) + case ciworkflowresult.FieldPythonVersion: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field python_version", values[i]) + } else if value.Valid { + cwr.PythonVersion = value.String + } + case ciworkflowresult.FieldPytorchVersion: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field pytorch_version", values[i]) + } else if value.Valid { + cwr.PytorchVersion = value.String + } + case ciworkflowresult.FieldCudaVersion: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field cuda_version", values[i]) } else if value.Valid { - cwr.ci_workflow_result_storage_file = new(uuid.UUID) - *cwr.ci_workflow_result_storage_file = *value.S.(*uuid.UUID) + cwr.CudaVersion = value.String } - case ciworkflowresult.ForeignKeys[1]: + case ciworkflowresult.FieldComfyRunFlags: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field comfy_run_flags", values[i]) + } else if value.Valid { + cwr.ComfyRunFlags = value.String + } + case ciworkflowresult.FieldAvgVram: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field avg_vram", values[i]) + } else if value.Valid { + cwr.AvgVram = int(value.Int64) + } + case ciworkflowresult.FieldPeakVram: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field peak_vram", values[i]) + } else if value.Valid { + cwr.PeakVram = int(value.Int64) + } + case ciworkflowresult.FieldJobTriggerUser: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field job_trigger_user", values[i]) + } else if value.Valid { + cwr.JobTriggerUser = value.String + } + case ciworkflowresult.FieldMetadata: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field metadata", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &cwr.Metadata); err != nil { + return fmt.Errorf("unmarshal field metadata: %w", err) + } + } + case ciworkflowresult.ForeignKeys[0]: if value, ok := values[i].(*sql.NullScanner); !ok { return fmt.Errorf("unexpected type %T for field git_commit_results", values[i]) } else if value.Valid { @@ -248,26 +297,47 @@ func (cwr *CIWorkflowResult) String() string { builder.WriteString("operating_system=") builder.WriteString(cwr.OperatingSystem) builder.WriteString(", ") - builder.WriteString("gpu_type=") - builder.WriteString(cwr.GpuType) - builder.WriteString(", ") - builder.WriteString("pytorch_version=") - builder.WriteString(cwr.PytorchVersion) - builder.WriteString(", ") builder.WriteString("workflow_name=") builder.WriteString(cwr.WorkflowName) builder.WriteString(", ") builder.WriteString("run_id=") builder.WriteString(cwr.RunID) builder.WriteString(", ") + builder.WriteString("job_id=") + builder.WriteString(cwr.JobID) + builder.WriteString(", ") builder.WriteString("status=") - builder.WriteString(cwr.Status) + builder.WriteString(fmt.Sprintf("%v", cwr.Status)) builder.WriteString(", ") builder.WriteString("start_time=") builder.WriteString(fmt.Sprintf("%v", cwr.StartTime)) builder.WriteString(", ") builder.WriteString("end_time=") builder.WriteString(fmt.Sprintf("%v", cwr.EndTime)) + builder.WriteString(", ") + builder.WriteString("python_version=") + builder.WriteString(cwr.PythonVersion) + builder.WriteString(", ") + builder.WriteString("pytorch_version=") + builder.WriteString(cwr.PytorchVersion) + builder.WriteString(", ") + builder.WriteString("cuda_version=") + builder.WriteString(cwr.CudaVersion) + builder.WriteString(", ") + builder.WriteString("comfy_run_flags=") + builder.WriteString(cwr.ComfyRunFlags) + builder.WriteString(", ") + builder.WriteString("avg_vram=") + builder.WriteString(fmt.Sprintf("%v", cwr.AvgVram)) + builder.WriteString(", ") + builder.WriteString("peak_vram=") + builder.WriteString(fmt.Sprintf("%v", cwr.PeakVram)) + builder.WriteString(", ") + builder.WriteString("job_trigger_user=") + builder.WriteString(cwr.JobTriggerUser) + builder.WriteString(", ") + builder.WriteString("metadata=") + builder.WriteString(fmt.Sprintf("%v", cwr.Metadata)) builder.WriteByte(')') return builder.String() } diff --git a/ent/ciworkflowresult/ciworkflowresult.go b/ent/ciworkflowresult/ciworkflowresult.go index de32cde..dfb5e00 100644 --- a/ent/ciworkflowresult/ciworkflowresult.go +++ b/ent/ciworkflowresult/ciworkflowresult.go @@ -3,6 +3,7 @@ package ciworkflowresult import ( + "registry-backend/ent/schema" "time" "entgo.io/ent/dialect/sql" @@ -21,20 +22,34 @@ const ( FieldUpdateTime = "update_time" // FieldOperatingSystem holds the string denoting the operating_system field in the database. FieldOperatingSystem = "operating_system" - // FieldGpuType holds the string denoting the gpu_type field in the database. - FieldGpuType = "gpu_type" - // FieldPytorchVersion holds the string denoting the pytorch_version field in the database. - FieldPytorchVersion = "pytorch_version" // FieldWorkflowName holds the string denoting the workflow_name field in the database. FieldWorkflowName = "workflow_name" // FieldRunID holds the string denoting the run_id field in the database. FieldRunID = "run_id" + // FieldJobID holds the string denoting the job_id field in the database. + FieldJobID = "job_id" // FieldStatus holds the string denoting the status field in the database. FieldStatus = "status" // FieldStartTime holds the string denoting the start_time field in the database. FieldStartTime = "start_time" // FieldEndTime holds the string denoting the end_time field in the database. FieldEndTime = "end_time" + // FieldPythonVersion holds the string denoting the python_version field in the database. + FieldPythonVersion = "python_version" + // FieldPytorchVersion holds the string denoting the pytorch_version field in the database. + FieldPytorchVersion = "pytorch_version" + // FieldCudaVersion holds the string denoting the cuda_version field in the database. + FieldCudaVersion = "cuda_version" + // FieldComfyRunFlags holds the string denoting the comfy_run_flags field in the database. + FieldComfyRunFlags = "comfy_run_flags" + // FieldAvgVram holds the string denoting the avg_vram field in the database. + FieldAvgVram = "avg_vram" + // FieldPeakVram holds the string denoting the peak_vram field in the database. + FieldPeakVram = "peak_vram" + // FieldJobTriggerUser holds the string denoting the job_trigger_user field in the database. + FieldJobTriggerUser = "job_trigger_user" + // FieldMetadata holds the string denoting the metadata field in the database. + FieldMetadata = "metadata" // EdgeGitcommit holds the string denoting the gitcommit edge name in mutations. EdgeGitcommit = "gitcommit" // EdgeStorageFile holds the string denoting the storage_file edge name in mutations. @@ -49,7 +64,7 @@ const ( // GitcommitColumn is the table column denoting the gitcommit relation/edge. GitcommitColumn = "git_commit_results" // StorageFileTable is the table that holds the storage_file relation/edge. - StorageFileTable = "ci_workflow_results" + StorageFileTable = "storage_files" // StorageFileInverseTable is the table name for the StorageFile entity. // It exists in this package in order to avoid circular dependency with the "storagefile" package. StorageFileInverseTable = "storage_files" @@ -63,19 +78,25 @@ var Columns = []string{ FieldCreateTime, FieldUpdateTime, FieldOperatingSystem, - FieldGpuType, - FieldPytorchVersion, FieldWorkflowName, FieldRunID, + FieldJobID, FieldStatus, FieldStartTime, FieldEndTime, + FieldPythonVersion, + FieldPytorchVersion, + FieldCudaVersion, + FieldComfyRunFlags, + FieldAvgVram, + FieldPeakVram, + FieldJobTriggerUser, + FieldMetadata, } // ForeignKeys holds the SQL foreign-keys that are owned by the "ci_workflow_results" // table and are not defined as standalone fields in the schema. var ForeignKeys = []string{ - "ci_workflow_result_storage_file", "git_commit_results", } @@ -101,6 +122,8 @@ var ( DefaultUpdateTime func() time.Time // UpdateDefaultUpdateTime holds the default value on update for the "update_time" field. UpdateDefaultUpdateTime func() time.Time + // DefaultStatus holds the default value on creation for the "status" field. + DefaultStatus schema.WorkflowRunStatusType // DefaultID holds the default value on creation for the "id" field. DefaultID func() uuid.UUID ) @@ -128,16 +151,6 @@ func ByOperatingSystem(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldOperatingSystem, opts...).ToFunc() } -// ByGpuType orders the results by the gpu_type field. -func ByGpuType(opts ...sql.OrderTermOption) OrderOption { - return sql.OrderByField(FieldGpuType, opts...).ToFunc() -} - -// ByPytorchVersion orders the results by the pytorch_version field. -func ByPytorchVersion(opts ...sql.OrderTermOption) OrderOption { - return sql.OrderByField(FieldPytorchVersion, opts...).ToFunc() -} - // ByWorkflowName orders the results by the workflow_name field. func ByWorkflowName(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldWorkflowName, opts...).ToFunc() @@ -148,6 +161,11 @@ func ByRunID(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldRunID, opts...).ToFunc() } +// ByJobID orders the results by the job_id field. +func ByJobID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldJobID, opts...).ToFunc() +} + // ByStatus orders the results by the status field. func ByStatus(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldStatus, opts...).ToFunc() @@ -163,6 +181,41 @@ func ByEndTime(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldEndTime, opts...).ToFunc() } +// ByPythonVersion orders the results by the python_version field. +func ByPythonVersion(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPythonVersion, opts...).ToFunc() +} + +// ByPytorchVersion orders the results by the pytorch_version field. +func ByPytorchVersion(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPytorchVersion, opts...).ToFunc() +} + +// ByCudaVersion orders the results by the cuda_version field. +func ByCudaVersion(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCudaVersion, opts...).ToFunc() +} + +// ByComfyRunFlags orders the results by the comfy_run_flags field. +func ByComfyRunFlags(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldComfyRunFlags, opts...).ToFunc() +} + +// ByAvgVram orders the results by the avg_vram field. +func ByAvgVram(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAvgVram, opts...).ToFunc() +} + +// ByPeakVram orders the results by the peak_vram field. +func ByPeakVram(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPeakVram, opts...).ToFunc() +} + +// ByJobTriggerUser orders the results by the job_trigger_user field. +func ByJobTriggerUser(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldJobTriggerUser, opts...).ToFunc() +} + // ByGitcommitField orders the results by gitcommit field. func ByGitcommitField(field string, opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { @@ -170,10 +223,17 @@ func ByGitcommitField(field string, opts ...sql.OrderTermOption) OrderOption { } } -// ByStorageFileField orders the results by storage_file field. -func ByStorageFileField(field string, opts ...sql.OrderTermOption) OrderOption { +// ByStorageFileCount orders the results by storage_file count. +func ByStorageFileCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newStorageFileStep(), opts...) + } +} + +// ByStorageFile orders the results by storage_file terms. +func ByStorageFile(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { return func(s *sql.Selector) { - sqlgraph.OrderByNeighborTerms(s, newStorageFileStep(), sql.OrderByField(field, opts...)) + sqlgraph.OrderByNeighborTerms(s, newStorageFileStep(), append([]sql.OrderTerm{term}, terms...)...) } } func newGitcommitStep() *sqlgraph.Step { @@ -187,6 +247,6 @@ func newStorageFileStep() *sqlgraph.Step { return sqlgraph.NewStep( sqlgraph.From(Table, FieldID), sqlgraph.To(StorageFileInverseTable, FieldID), - sqlgraph.Edge(sqlgraph.M2O, false, StorageFileTable, StorageFileColumn), + sqlgraph.Edge(sqlgraph.O2M, false, StorageFileTable, StorageFileColumn), ) } diff --git a/ent/ciworkflowresult/where.go b/ent/ciworkflowresult/where.go index 7e4703e..10133bd 100644 --- a/ent/ciworkflowresult/where.go +++ b/ent/ciworkflowresult/where.go @@ -4,6 +4,7 @@ package ciworkflowresult import ( "registry-backend/ent/predicate" + "registry-backend/ent/schema" "time" "entgo.io/ent/dialect/sql" @@ -71,16 +72,6 @@ func OperatingSystem(v string) predicate.CIWorkflowResult { return predicate.CIWorkflowResult(sql.FieldEQ(FieldOperatingSystem, v)) } -// GpuType applies equality check predicate on the "gpu_type" field. It's identical to GpuTypeEQ. -func GpuType(v string) predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldEQ(FieldGpuType, v)) -} - -// PytorchVersion applies equality check predicate on the "pytorch_version" field. It's identical to PytorchVersionEQ. -func PytorchVersion(v string) predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldEQ(FieldPytorchVersion, v)) -} - // WorkflowName applies equality check predicate on the "workflow_name" field. It's identical to WorkflowNameEQ. func WorkflowName(v string) predicate.CIWorkflowResult { return predicate.CIWorkflowResult(sql.FieldEQ(FieldWorkflowName, v)) @@ -91,9 +82,15 @@ func RunID(v string) predicate.CIWorkflowResult { return predicate.CIWorkflowResult(sql.FieldEQ(FieldRunID, v)) } +// JobID applies equality check predicate on the "job_id" field. It's identical to JobIDEQ. +func JobID(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldEQ(FieldJobID, v)) +} + // Status applies equality check predicate on the "status" field. It's identical to StatusEQ. -func Status(v string) predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldEQ(FieldStatus, v)) +func Status(v schema.WorkflowRunStatusType) predicate.CIWorkflowResult { + vc := string(v) + return predicate.CIWorkflowResult(sql.FieldEQ(FieldStatus, vc)) } // StartTime applies equality check predicate on the "start_time" field. It's identical to StartTimeEQ. @@ -106,6 +103,41 @@ func EndTime(v int64) predicate.CIWorkflowResult { return predicate.CIWorkflowResult(sql.FieldEQ(FieldEndTime, v)) } +// PythonVersion applies equality check predicate on the "python_version" field. It's identical to PythonVersionEQ. +func PythonVersion(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldEQ(FieldPythonVersion, v)) +} + +// PytorchVersion applies equality check predicate on the "pytorch_version" field. It's identical to PytorchVersionEQ. +func PytorchVersion(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldEQ(FieldPytorchVersion, v)) +} + +// CudaVersion applies equality check predicate on the "cuda_version" field. It's identical to CudaVersionEQ. +func CudaVersion(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldEQ(FieldCudaVersion, v)) +} + +// ComfyRunFlags applies equality check predicate on the "comfy_run_flags" field. It's identical to ComfyRunFlagsEQ. +func ComfyRunFlags(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldEQ(FieldComfyRunFlags, v)) +} + +// AvgVram applies equality check predicate on the "avg_vram" field. It's identical to AvgVramEQ. +func AvgVram(v int) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldEQ(FieldAvgVram, v)) +} + +// PeakVram applies equality check predicate on the "peak_vram" field. It's identical to PeakVramEQ. +func PeakVram(v int) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldEQ(FieldPeakVram, v)) +} + +// JobTriggerUser applies equality check predicate on the "job_trigger_user" field. It's identical to JobTriggerUserEQ. +func JobTriggerUser(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldEQ(FieldJobTriggerUser, v)) +} + // CreateTimeEQ applies the EQ predicate on the "create_time" field. func CreateTimeEQ(v time.Time) predicate.CIWorkflowResult { return predicate.CIWorkflowResult(sql.FieldEQ(FieldCreateTime, v)) @@ -251,156 +283,6 @@ func OperatingSystemContainsFold(v string) predicate.CIWorkflowResult { return predicate.CIWorkflowResult(sql.FieldContainsFold(FieldOperatingSystem, v)) } -// GpuTypeEQ applies the EQ predicate on the "gpu_type" field. -func GpuTypeEQ(v string) predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldEQ(FieldGpuType, v)) -} - -// GpuTypeNEQ applies the NEQ predicate on the "gpu_type" field. -func GpuTypeNEQ(v string) predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldNEQ(FieldGpuType, v)) -} - -// GpuTypeIn applies the In predicate on the "gpu_type" field. -func GpuTypeIn(vs ...string) predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldIn(FieldGpuType, vs...)) -} - -// GpuTypeNotIn applies the NotIn predicate on the "gpu_type" field. -func GpuTypeNotIn(vs ...string) predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldNotIn(FieldGpuType, vs...)) -} - -// GpuTypeGT applies the GT predicate on the "gpu_type" field. -func GpuTypeGT(v string) predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldGT(FieldGpuType, v)) -} - -// GpuTypeGTE applies the GTE predicate on the "gpu_type" field. -func GpuTypeGTE(v string) predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldGTE(FieldGpuType, v)) -} - -// GpuTypeLT applies the LT predicate on the "gpu_type" field. -func GpuTypeLT(v string) predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldLT(FieldGpuType, v)) -} - -// GpuTypeLTE applies the LTE predicate on the "gpu_type" field. -func GpuTypeLTE(v string) predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldLTE(FieldGpuType, v)) -} - -// GpuTypeContains applies the Contains predicate on the "gpu_type" field. -func GpuTypeContains(v string) predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldContains(FieldGpuType, v)) -} - -// GpuTypeHasPrefix applies the HasPrefix predicate on the "gpu_type" field. -func GpuTypeHasPrefix(v string) predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldHasPrefix(FieldGpuType, v)) -} - -// GpuTypeHasSuffix applies the HasSuffix predicate on the "gpu_type" field. -func GpuTypeHasSuffix(v string) predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldHasSuffix(FieldGpuType, v)) -} - -// GpuTypeIsNil applies the IsNil predicate on the "gpu_type" field. -func GpuTypeIsNil() predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldIsNull(FieldGpuType)) -} - -// GpuTypeNotNil applies the NotNil predicate on the "gpu_type" field. -func GpuTypeNotNil() predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldNotNull(FieldGpuType)) -} - -// GpuTypeEqualFold applies the EqualFold predicate on the "gpu_type" field. -func GpuTypeEqualFold(v string) predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldEqualFold(FieldGpuType, v)) -} - -// GpuTypeContainsFold applies the ContainsFold predicate on the "gpu_type" field. -func GpuTypeContainsFold(v string) predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldContainsFold(FieldGpuType, v)) -} - -// PytorchVersionEQ applies the EQ predicate on the "pytorch_version" field. -func PytorchVersionEQ(v string) predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldEQ(FieldPytorchVersion, v)) -} - -// PytorchVersionNEQ applies the NEQ predicate on the "pytorch_version" field. -func PytorchVersionNEQ(v string) predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldNEQ(FieldPytorchVersion, v)) -} - -// PytorchVersionIn applies the In predicate on the "pytorch_version" field. -func PytorchVersionIn(vs ...string) predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldIn(FieldPytorchVersion, vs...)) -} - -// PytorchVersionNotIn applies the NotIn predicate on the "pytorch_version" field. -func PytorchVersionNotIn(vs ...string) predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldNotIn(FieldPytorchVersion, vs...)) -} - -// PytorchVersionGT applies the GT predicate on the "pytorch_version" field. -func PytorchVersionGT(v string) predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldGT(FieldPytorchVersion, v)) -} - -// PytorchVersionGTE applies the GTE predicate on the "pytorch_version" field. -func PytorchVersionGTE(v string) predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldGTE(FieldPytorchVersion, v)) -} - -// PytorchVersionLT applies the LT predicate on the "pytorch_version" field. -func PytorchVersionLT(v string) predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldLT(FieldPytorchVersion, v)) -} - -// PytorchVersionLTE applies the LTE predicate on the "pytorch_version" field. -func PytorchVersionLTE(v string) predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldLTE(FieldPytorchVersion, v)) -} - -// PytorchVersionContains applies the Contains predicate on the "pytorch_version" field. -func PytorchVersionContains(v string) predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldContains(FieldPytorchVersion, v)) -} - -// PytorchVersionHasPrefix applies the HasPrefix predicate on the "pytorch_version" field. -func PytorchVersionHasPrefix(v string) predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldHasPrefix(FieldPytorchVersion, v)) -} - -// PytorchVersionHasSuffix applies the HasSuffix predicate on the "pytorch_version" field. -func PytorchVersionHasSuffix(v string) predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldHasSuffix(FieldPytorchVersion, v)) -} - -// PytorchVersionIsNil applies the IsNil predicate on the "pytorch_version" field. -func PytorchVersionIsNil() predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldIsNull(FieldPytorchVersion)) -} - -// PytorchVersionNotNil applies the NotNil predicate on the "pytorch_version" field. -func PytorchVersionNotNil() predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldNotNull(FieldPytorchVersion)) -} - -// PytorchVersionEqualFold applies the EqualFold predicate on the "pytorch_version" field. -func PytorchVersionEqualFold(v string) predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldEqualFold(FieldPytorchVersion, v)) -} - -// PytorchVersionContainsFold applies the ContainsFold predicate on the "pytorch_version" field. -func PytorchVersionContainsFold(v string) predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldContainsFold(FieldPytorchVersion, v)) -} - // WorkflowNameEQ applies the EQ predicate on the "workflow_name" field. func WorkflowNameEQ(v string) predicate.CIWorkflowResult { return predicate.CIWorkflowResult(sql.FieldEQ(FieldWorkflowName, v)) @@ -551,79 +433,163 @@ func RunIDContainsFold(v string) predicate.CIWorkflowResult { return predicate.CIWorkflowResult(sql.FieldContainsFold(FieldRunID, v)) } +// JobIDEQ applies the EQ predicate on the "job_id" field. +func JobIDEQ(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldEQ(FieldJobID, v)) +} + +// JobIDNEQ applies the NEQ predicate on the "job_id" field. +func JobIDNEQ(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldNEQ(FieldJobID, v)) +} + +// JobIDIn applies the In predicate on the "job_id" field. +func JobIDIn(vs ...string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldIn(FieldJobID, vs...)) +} + +// JobIDNotIn applies the NotIn predicate on the "job_id" field. +func JobIDNotIn(vs ...string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldNotIn(FieldJobID, vs...)) +} + +// JobIDGT applies the GT predicate on the "job_id" field. +func JobIDGT(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldGT(FieldJobID, v)) +} + +// JobIDGTE applies the GTE predicate on the "job_id" field. +func JobIDGTE(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldGTE(FieldJobID, v)) +} + +// JobIDLT applies the LT predicate on the "job_id" field. +func JobIDLT(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldLT(FieldJobID, v)) +} + +// JobIDLTE applies the LTE predicate on the "job_id" field. +func JobIDLTE(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldLTE(FieldJobID, v)) +} + +// JobIDContains applies the Contains predicate on the "job_id" field. +func JobIDContains(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldContains(FieldJobID, v)) +} + +// JobIDHasPrefix applies the HasPrefix predicate on the "job_id" field. +func JobIDHasPrefix(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldHasPrefix(FieldJobID, v)) +} + +// JobIDHasSuffix applies the HasSuffix predicate on the "job_id" field. +func JobIDHasSuffix(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldHasSuffix(FieldJobID, v)) +} + +// JobIDIsNil applies the IsNil predicate on the "job_id" field. +func JobIDIsNil() predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldIsNull(FieldJobID)) +} + +// JobIDNotNil applies the NotNil predicate on the "job_id" field. +func JobIDNotNil() predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldNotNull(FieldJobID)) +} + +// JobIDEqualFold applies the EqualFold predicate on the "job_id" field. +func JobIDEqualFold(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldEqualFold(FieldJobID, v)) +} + +// JobIDContainsFold applies the ContainsFold predicate on the "job_id" field. +func JobIDContainsFold(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldContainsFold(FieldJobID, v)) +} + // StatusEQ applies the EQ predicate on the "status" field. -func StatusEQ(v string) predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldEQ(FieldStatus, v)) +func StatusEQ(v schema.WorkflowRunStatusType) predicate.CIWorkflowResult { + vc := string(v) + return predicate.CIWorkflowResult(sql.FieldEQ(FieldStatus, vc)) } // StatusNEQ applies the NEQ predicate on the "status" field. -func StatusNEQ(v string) predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldNEQ(FieldStatus, v)) +func StatusNEQ(v schema.WorkflowRunStatusType) predicate.CIWorkflowResult { + vc := string(v) + return predicate.CIWorkflowResult(sql.FieldNEQ(FieldStatus, vc)) } // StatusIn applies the In predicate on the "status" field. -func StatusIn(vs ...string) predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldIn(FieldStatus, vs...)) +func StatusIn(vs ...schema.WorkflowRunStatusType) predicate.CIWorkflowResult { + v := make([]any, len(vs)) + for i := range v { + v[i] = string(vs[i]) + } + return predicate.CIWorkflowResult(sql.FieldIn(FieldStatus, v...)) } // StatusNotIn applies the NotIn predicate on the "status" field. -func StatusNotIn(vs ...string) predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldNotIn(FieldStatus, vs...)) +func StatusNotIn(vs ...schema.WorkflowRunStatusType) predicate.CIWorkflowResult { + v := make([]any, len(vs)) + for i := range v { + v[i] = string(vs[i]) + } + return predicate.CIWorkflowResult(sql.FieldNotIn(FieldStatus, v...)) } // StatusGT applies the GT predicate on the "status" field. -func StatusGT(v string) predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldGT(FieldStatus, v)) +func StatusGT(v schema.WorkflowRunStatusType) predicate.CIWorkflowResult { + vc := string(v) + return predicate.CIWorkflowResult(sql.FieldGT(FieldStatus, vc)) } // StatusGTE applies the GTE predicate on the "status" field. -func StatusGTE(v string) predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldGTE(FieldStatus, v)) +func StatusGTE(v schema.WorkflowRunStatusType) predicate.CIWorkflowResult { + vc := string(v) + return predicate.CIWorkflowResult(sql.FieldGTE(FieldStatus, vc)) } // StatusLT applies the LT predicate on the "status" field. -func StatusLT(v string) predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldLT(FieldStatus, v)) +func StatusLT(v schema.WorkflowRunStatusType) predicate.CIWorkflowResult { + vc := string(v) + return predicate.CIWorkflowResult(sql.FieldLT(FieldStatus, vc)) } // StatusLTE applies the LTE predicate on the "status" field. -func StatusLTE(v string) predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldLTE(FieldStatus, v)) +func StatusLTE(v schema.WorkflowRunStatusType) predicate.CIWorkflowResult { + vc := string(v) + return predicate.CIWorkflowResult(sql.FieldLTE(FieldStatus, vc)) } // StatusContains applies the Contains predicate on the "status" field. -func StatusContains(v string) predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldContains(FieldStatus, v)) +func StatusContains(v schema.WorkflowRunStatusType) predicate.CIWorkflowResult { + vc := string(v) + return predicate.CIWorkflowResult(sql.FieldContains(FieldStatus, vc)) } // StatusHasPrefix applies the HasPrefix predicate on the "status" field. -func StatusHasPrefix(v string) predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldHasPrefix(FieldStatus, v)) +func StatusHasPrefix(v schema.WorkflowRunStatusType) predicate.CIWorkflowResult { + vc := string(v) + return predicate.CIWorkflowResult(sql.FieldHasPrefix(FieldStatus, vc)) } // StatusHasSuffix applies the HasSuffix predicate on the "status" field. -func StatusHasSuffix(v string) predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldHasSuffix(FieldStatus, v)) -} - -// StatusIsNil applies the IsNil predicate on the "status" field. -func StatusIsNil() predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldIsNull(FieldStatus)) -} - -// StatusNotNil applies the NotNil predicate on the "status" field. -func StatusNotNil() predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldNotNull(FieldStatus)) +func StatusHasSuffix(v schema.WorkflowRunStatusType) predicate.CIWorkflowResult { + vc := string(v) + return predicate.CIWorkflowResult(sql.FieldHasSuffix(FieldStatus, vc)) } // StatusEqualFold applies the EqualFold predicate on the "status" field. -func StatusEqualFold(v string) predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldEqualFold(FieldStatus, v)) +func StatusEqualFold(v schema.WorkflowRunStatusType) predicate.CIWorkflowResult { + vc := string(v) + return predicate.CIWorkflowResult(sql.FieldEqualFold(FieldStatus, vc)) } // StatusContainsFold applies the ContainsFold predicate on the "status" field. -func StatusContainsFold(v string) predicate.CIWorkflowResult { - return predicate.CIWorkflowResult(sql.FieldContainsFold(FieldStatus, v)) +func StatusContainsFold(v schema.WorkflowRunStatusType) predicate.CIWorkflowResult { + vc := string(v) + return predicate.CIWorkflowResult(sql.FieldContainsFold(FieldStatus, vc)) } // StartTimeEQ applies the EQ predicate on the "start_time" field. @@ -726,6 +692,491 @@ func EndTimeNotNil() predicate.CIWorkflowResult { return predicate.CIWorkflowResult(sql.FieldNotNull(FieldEndTime)) } +// PythonVersionEQ applies the EQ predicate on the "python_version" field. +func PythonVersionEQ(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldEQ(FieldPythonVersion, v)) +} + +// PythonVersionNEQ applies the NEQ predicate on the "python_version" field. +func PythonVersionNEQ(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldNEQ(FieldPythonVersion, v)) +} + +// PythonVersionIn applies the In predicate on the "python_version" field. +func PythonVersionIn(vs ...string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldIn(FieldPythonVersion, vs...)) +} + +// PythonVersionNotIn applies the NotIn predicate on the "python_version" field. +func PythonVersionNotIn(vs ...string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldNotIn(FieldPythonVersion, vs...)) +} + +// PythonVersionGT applies the GT predicate on the "python_version" field. +func PythonVersionGT(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldGT(FieldPythonVersion, v)) +} + +// PythonVersionGTE applies the GTE predicate on the "python_version" field. +func PythonVersionGTE(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldGTE(FieldPythonVersion, v)) +} + +// PythonVersionLT applies the LT predicate on the "python_version" field. +func PythonVersionLT(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldLT(FieldPythonVersion, v)) +} + +// PythonVersionLTE applies the LTE predicate on the "python_version" field. +func PythonVersionLTE(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldLTE(FieldPythonVersion, v)) +} + +// PythonVersionContains applies the Contains predicate on the "python_version" field. +func PythonVersionContains(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldContains(FieldPythonVersion, v)) +} + +// PythonVersionHasPrefix applies the HasPrefix predicate on the "python_version" field. +func PythonVersionHasPrefix(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldHasPrefix(FieldPythonVersion, v)) +} + +// PythonVersionHasSuffix applies the HasSuffix predicate on the "python_version" field. +func PythonVersionHasSuffix(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldHasSuffix(FieldPythonVersion, v)) +} + +// PythonVersionIsNil applies the IsNil predicate on the "python_version" field. +func PythonVersionIsNil() predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldIsNull(FieldPythonVersion)) +} + +// PythonVersionNotNil applies the NotNil predicate on the "python_version" field. +func PythonVersionNotNil() predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldNotNull(FieldPythonVersion)) +} + +// PythonVersionEqualFold applies the EqualFold predicate on the "python_version" field. +func PythonVersionEqualFold(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldEqualFold(FieldPythonVersion, v)) +} + +// PythonVersionContainsFold applies the ContainsFold predicate on the "python_version" field. +func PythonVersionContainsFold(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldContainsFold(FieldPythonVersion, v)) +} + +// PytorchVersionEQ applies the EQ predicate on the "pytorch_version" field. +func PytorchVersionEQ(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldEQ(FieldPytorchVersion, v)) +} + +// PytorchVersionNEQ applies the NEQ predicate on the "pytorch_version" field. +func PytorchVersionNEQ(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldNEQ(FieldPytorchVersion, v)) +} + +// PytorchVersionIn applies the In predicate on the "pytorch_version" field. +func PytorchVersionIn(vs ...string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldIn(FieldPytorchVersion, vs...)) +} + +// PytorchVersionNotIn applies the NotIn predicate on the "pytorch_version" field. +func PytorchVersionNotIn(vs ...string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldNotIn(FieldPytorchVersion, vs...)) +} + +// PytorchVersionGT applies the GT predicate on the "pytorch_version" field. +func PytorchVersionGT(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldGT(FieldPytorchVersion, v)) +} + +// PytorchVersionGTE applies the GTE predicate on the "pytorch_version" field. +func PytorchVersionGTE(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldGTE(FieldPytorchVersion, v)) +} + +// PytorchVersionLT applies the LT predicate on the "pytorch_version" field. +func PytorchVersionLT(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldLT(FieldPytorchVersion, v)) +} + +// PytorchVersionLTE applies the LTE predicate on the "pytorch_version" field. +func PytorchVersionLTE(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldLTE(FieldPytorchVersion, v)) +} + +// PytorchVersionContains applies the Contains predicate on the "pytorch_version" field. +func PytorchVersionContains(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldContains(FieldPytorchVersion, v)) +} + +// PytorchVersionHasPrefix applies the HasPrefix predicate on the "pytorch_version" field. +func PytorchVersionHasPrefix(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldHasPrefix(FieldPytorchVersion, v)) +} + +// PytorchVersionHasSuffix applies the HasSuffix predicate on the "pytorch_version" field. +func PytorchVersionHasSuffix(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldHasSuffix(FieldPytorchVersion, v)) +} + +// PytorchVersionIsNil applies the IsNil predicate on the "pytorch_version" field. +func PytorchVersionIsNil() predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldIsNull(FieldPytorchVersion)) +} + +// PytorchVersionNotNil applies the NotNil predicate on the "pytorch_version" field. +func PytorchVersionNotNil() predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldNotNull(FieldPytorchVersion)) +} + +// PytorchVersionEqualFold applies the EqualFold predicate on the "pytorch_version" field. +func PytorchVersionEqualFold(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldEqualFold(FieldPytorchVersion, v)) +} + +// PytorchVersionContainsFold applies the ContainsFold predicate on the "pytorch_version" field. +func PytorchVersionContainsFold(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldContainsFold(FieldPytorchVersion, v)) +} + +// CudaVersionEQ applies the EQ predicate on the "cuda_version" field. +func CudaVersionEQ(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldEQ(FieldCudaVersion, v)) +} + +// CudaVersionNEQ applies the NEQ predicate on the "cuda_version" field. +func CudaVersionNEQ(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldNEQ(FieldCudaVersion, v)) +} + +// CudaVersionIn applies the In predicate on the "cuda_version" field. +func CudaVersionIn(vs ...string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldIn(FieldCudaVersion, vs...)) +} + +// CudaVersionNotIn applies the NotIn predicate on the "cuda_version" field. +func CudaVersionNotIn(vs ...string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldNotIn(FieldCudaVersion, vs...)) +} + +// CudaVersionGT applies the GT predicate on the "cuda_version" field. +func CudaVersionGT(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldGT(FieldCudaVersion, v)) +} + +// CudaVersionGTE applies the GTE predicate on the "cuda_version" field. +func CudaVersionGTE(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldGTE(FieldCudaVersion, v)) +} + +// CudaVersionLT applies the LT predicate on the "cuda_version" field. +func CudaVersionLT(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldLT(FieldCudaVersion, v)) +} + +// CudaVersionLTE applies the LTE predicate on the "cuda_version" field. +func CudaVersionLTE(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldLTE(FieldCudaVersion, v)) +} + +// CudaVersionContains applies the Contains predicate on the "cuda_version" field. +func CudaVersionContains(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldContains(FieldCudaVersion, v)) +} + +// CudaVersionHasPrefix applies the HasPrefix predicate on the "cuda_version" field. +func CudaVersionHasPrefix(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldHasPrefix(FieldCudaVersion, v)) +} + +// CudaVersionHasSuffix applies the HasSuffix predicate on the "cuda_version" field. +func CudaVersionHasSuffix(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldHasSuffix(FieldCudaVersion, v)) +} + +// CudaVersionIsNil applies the IsNil predicate on the "cuda_version" field. +func CudaVersionIsNil() predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldIsNull(FieldCudaVersion)) +} + +// CudaVersionNotNil applies the NotNil predicate on the "cuda_version" field. +func CudaVersionNotNil() predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldNotNull(FieldCudaVersion)) +} + +// CudaVersionEqualFold applies the EqualFold predicate on the "cuda_version" field. +func CudaVersionEqualFold(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldEqualFold(FieldCudaVersion, v)) +} + +// CudaVersionContainsFold applies the ContainsFold predicate on the "cuda_version" field. +func CudaVersionContainsFold(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldContainsFold(FieldCudaVersion, v)) +} + +// ComfyRunFlagsEQ applies the EQ predicate on the "comfy_run_flags" field. +func ComfyRunFlagsEQ(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldEQ(FieldComfyRunFlags, v)) +} + +// ComfyRunFlagsNEQ applies the NEQ predicate on the "comfy_run_flags" field. +func ComfyRunFlagsNEQ(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldNEQ(FieldComfyRunFlags, v)) +} + +// ComfyRunFlagsIn applies the In predicate on the "comfy_run_flags" field. +func ComfyRunFlagsIn(vs ...string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldIn(FieldComfyRunFlags, vs...)) +} + +// ComfyRunFlagsNotIn applies the NotIn predicate on the "comfy_run_flags" field. +func ComfyRunFlagsNotIn(vs ...string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldNotIn(FieldComfyRunFlags, vs...)) +} + +// ComfyRunFlagsGT applies the GT predicate on the "comfy_run_flags" field. +func ComfyRunFlagsGT(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldGT(FieldComfyRunFlags, v)) +} + +// ComfyRunFlagsGTE applies the GTE predicate on the "comfy_run_flags" field. +func ComfyRunFlagsGTE(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldGTE(FieldComfyRunFlags, v)) +} + +// ComfyRunFlagsLT applies the LT predicate on the "comfy_run_flags" field. +func ComfyRunFlagsLT(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldLT(FieldComfyRunFlags, v)) +} + +// ComfyRunFlagsLTE applies the LTE predicate on the "comfy_run_flags" field. +func ComfyRunFlagsLTE(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldLTE(FieldComfyRunFlags, v)) +} + +// ComfyRunFlagsContains applies the Contains predicate on the "comfy_run_flags" field. +func ComfyRunFlagsContains(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldContains(FieldComfyRunFlags, v)) +} + +// ComfyRunFlagsHasPrefix applies the HasPrefix predicate on the "comfy_run_flags" field. +func ComfyRunFlagsHasPrefix(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldHasPrefix(FieldComfyRunFlags, v)) +} + +// ComfyRunFlagsHasSuffix applies the HasSuffix predicate on the "comfy_run_flags" field. +func ComfyRunFlagsHasSuffix(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldHasSuffix(FieldComfyRunFlags, v)) +} + +// ComfyRunFlagsIsNil applies the IsNil predicate on the "comfy_run_flags" field. +func ComfyRunFlagsIsNil() predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldIsNull(FieldComfyRunFlags)) +} + +// ComfyRunFlagsNotNil applies the NotNil predicate on the "comfy_run_flags" field. +func ComfyRunFlagsNotNil() predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldNotNull(FieldComfyRunFlags)) +} + +// ComfyRunFlagsEqualFold applies the EqualFold predicate on the "comfy_run_flags" field. +func ComfyRunFlagsEqualFold(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldEqualFold(FieldComfyRunFlags, v)) +} + +// ComfyRunFlagsContainsFold applies the ContainsFold predicate on the "comfy_run_flags" field. +func ComfyRunFlagsContainsFold(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldContainsFold(FieldComfyRunFlags, v)) +} + +// AvgVramEQ applies the EQ predicate on the "avg_vram" field. +func AvgVramEQ(v int) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldEQ(FieldAvgVram, v)) +} + +// AvgVramNEQ applies the NEQ predicate on the "avg_vram" field. +func AvgVramNEQ(v int) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldNEQ(FieldAvgVram, v)) +} + +// AvgVramIn applies the In predicate on the "avg_vram" field. +func AvgVramIn(vs ...int) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldIn(FieldAvgVram, vs...)) +} + +// AvgVramNotIn applies the NotIn predicate on the "avg_vram" field. +func AvgVramNotIn(vs ...int) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldNotIn(FieldAvgVram, vs...)) +} + +// AvgVramGT applies the GT predicate on the "avg_vram" field. +func AvgVramGT(v int) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldGT(FieldAvgVram, v)) +} + +// AvgVramGTE applies the GTE predicate on the "avg_vram" field. +func AvgVramGTE(v int) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldGTE(FieldAvgVram, v)) +} + +// AvgVramLT applies the LT predicate on the "avg_vram" field. +func AvgVramLT(v int) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldLT(FieldAvgVram, v)) +} + +// AvgVramLTE applies the LTE predicate on the "avg_vram" field. +func AvgVramLTE(v int) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldLTE(FieldAvgVram, v)) +} + +// AvgVramIsNil applies the IsNil predicate on the "avg_vram" field. +func AvgVramIsNil() predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldIsNull(FieldAvgVram)) +} + +// AvgVramNotNil applies the NotNil predicate on the "avg_vram" field. +func AvgVramNotNil() predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldNotNull(FieldAvgVram)) +} + +// PeakVramEQ applies the EQ predicate on the "peak_vram" field. +func PeakVramEQ(v int) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldEQ(FieldPeakVram, v)) +} + +// PeakVramNEQ applies the NEQ predicate on the "peak_vram" field. +func PeakVramNEQ(v int) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldNEQ(FieldPeakVram, v)) +} + +// PeakVramIn applies the In predicate on the "peak_vram" field. +func PeakVramIn(vs ...int) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldIn(FieldPeakVram, vs...)) +} + +// PeakVramNotIn applies the NotIn predicate on the "peak_vram" field. +func PeakVramNotIn(vs ...int) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldNotIn(FieldPeakVram, vs...)) +} + +// PeakVramGT applies the GT predicate on the "peak_vram" field. +func PeakVramGT(v int) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldGT(FieldPeakVram, v)) +} + +// PeakVramGTE applies the GTE predicate on the "peak_vram" field. +func PeakVramGTE(v int) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldGTE(FieldPeakVram, v)) +} + +// PeakVramLT applies the LT predicate on the "peak_vram" field. +func PeakVramLT(v int) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldLT(FieldPeakVram, v)) +} + +// PeakVramLTE applies the LTE predicate on the "peak_vram" field. +func PeakVramLTE(v int) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldLTE(FieldPeakVram, v)) +} + +// PeakVramIsNil applies the IsNil predicate on the "peak_vram" field. +func PeakVramIsNil() predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldIsNull(FieldPeakVram)) +} + +// PeakVramNotNil applies the NotNil predicate on the "peak_vram" field. +func PeakVramNotNil() predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldNotNull(FieldPeakVram)) +} + +// JobTriggerUserEQ applies the EQ predicate on the "job_trigger_user" field. +func JobTriggerUserEQ(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldEQ(FieldJobTriggerUser, v)) +} + +// JobTriggerUserNEQ applies the NEQ predicate on the "job_trigger_user" field. +func JobTriggerUserNEQ(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldNEQ(FieldJobTriggerUser, v)) +} + +// JobTriggerUserIn applies the In predicate on the "job_trigger_user" field. +func JobTriggerUserIn(vs ...string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldIn(FieldJobTriggerUser, vs...)) +} + +// JobTriggerUserNotIn applies the NotIn predicate on the "job_trigger_user" field. +func JobTriggerUserNotIn(vs ...string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldNotIn(FieldJobTriggerUser, vs...)) +} + +// JobTriggerUserGT applies the GT predicate on the "job_trigger_user" field. +func JobTriggerUserGT(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldGT(FieldJobTriggerUser, v)) +} + +// JobTriggerUserGTE applies the GTE predicate on the "job_trigger_user" field. +func JobTriggerUserGTE(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldGTE(FieldJobTriggerUser, v)) +} + +// JobTriggerUserLT applies the LT predicate on the "job_trigger_user" field. +func JobTriggerUserLT(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldLT(FieldJobTriggerUser, v)) +} + +// JobTriggerUserLTE applies the LTE predicate on the "job_trigger_user" field. +func JobTriggerUserLTE(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldLTE(FieldJobTriggerUser, v)) +} + +// JobTriggerUserContains applies the Contains predicate on the "job_trigger_user" field. +func JobTriggerUserContains(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldContains(FieldJobTriggerUser, v)) +} + +// JobTriggerUserHasPrefix applies the HasPrefix predicate on the "job_trigger_user" field. +func JobTriggerUserHasPrefix(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldHasPrefix(FieldJobTriggerUser, v)) +} + +// JobTriggerUserHasSuffix applies the HasSuffix predicate on the "job_trigger_user" field. +func JobTriggerUserHasSuffix(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldHasSuffix(FieldJobTriggerUser, v)) +} + +// JobTriggerUserIsNil applies the IsNil predicate on the "job_trigger_user" field. +func JobTriggerUserIsNil() predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldIsNull(FieldJobTriggerUser)) +} + +// JobTriggerUserNotNil applies the NotNil predicate on the "job_trigger_user" field. +func JobTriggerUserNotNil() predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldNotNull(FieldJobTriggerUser)) +} + +// JobTriggerUserEqualFold applies the EqualFold predicate on the "job_trigger_user" field. +func JobTriggerUserEqualFold(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldEqualFold(FieldJobTriggerUser, v)) +} + +// JobTriggerUserContainsFold applies the ContainsFold predicate on the "job_trigger_user" field. +func JobTriggerUserContainsFold(v string) predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldContainsFold(FieldJobTriggerUser, v)) +} + +// MetadataIsNil applies the IsNil predicate on the "metadata" field. +func MetadataIsNil() predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldIsNull(FieldMetadata)) +} + +// MetadataNotNil applies the NotNil predicate on the "metadata" field. +func MetadataNotNil() predicate.CIWorkflowResult { + return predicate.CIWorkflowResult(sql.FieldNotNull(FieldMetadata)) +} + // HasGitcommit applies the HasEdge predicate on the "gitcommit" edge. func HasGitcommit() predicate.CIWorkflowResult { return predicate.CIWorkflowResult(func(s *sql.Selector) { @@ -754,7 +1205,7 @@ func HasStorageFile() predicate.CIWorkflowResult { return predicate.CIWorkflowResult(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.Edge(sqlgraph.M2O, false, StorageFileTable, StorageFileColumn), + sqlgraph.Edge(sqlgraph.O2M, false, StorageFileTable, StorageFileColumn), ) sqlgraph.HasNeighbors(s, step) }) diff --git a/ent/ciworkflowresult_create.go b/ent/ciworkflowresult_create.go index 2c9bc6c..84b0463 100644 --- a/ent/ciworkflowresult_create.go +++ b/ent/ciworkflowresult_create.go @@ -8,6 +8,7 @@ import ( "fmt" "registry-backend/ent/ciworkflowresult" "registry-backend/ent/gitcommit" + "registry-backend/ent/schema" "registry-backend/ent/storagefile" "time" @@ -60,34 +61,6 @@ func (cwrc *CIWorkflowResultCreate) SetOperatingSystem(s string) *CIWorkflowResu return cwrc } -// SetGpuType sets the "gpu_type" field. -func (cwrc *CIWorkflowResultCreate) SetGpuType(s string) *CIWorkflowResultCreate { - cwrc.mutation.SetGpuType(s) - return cwrc -} - -// SetNillableGpuType sets the "gpu_type" field if the given value is not nil. -func (cwrc *CIWorkflowResultCreate) SetNillableGpuType(s *string) *CIWorkflowResultCreate { - if s != nil { - cwrc.SetGpuType(*s) - } - return cwrc -} - -// SetPytorchVersion sets the "pytorch_version" field. -func (cwrc *CIWorkflowResultCreate) SetPytorchVersion(s string) *CIWorkflowResultCreate { - cwrc.mutation.SetPytorchVersion(s) - return cwrc -} - -// SetNillablePytorchVersion sets the "pytorch_version" field if the given value is not nil. -func (cwrc *CIWorkflowResultCreate) SetNillablePytorchVersion(s *string) *CIWorkflowResultCreate { - if s != nil { - cwrc.SetPytorchVersion(*s) - } - return cwrc -} - // SetWorkflowName sets the "workflow_name" field. func (cwrc *CIWorkflowResultCreate) SetWorkflowName(s string) *CIWorkflowResultCreate { cwrc.mutation.SetWorkflowName(s) @@ -116,16 +89,30 @@ func (cwrc *CIWorkflowResultCreate) SetNillableRunID(s *string) *CIWorkflowResul return cwrc } +// SetJobID sets the "job_id" field. +func (cwrc *CIWorkflowResultCreate) SetJobID(s string) *CIWorkflowResultCreate { + cwrc.mutation.SetJobID(s) + return cwrc +} + +// SetNillableJobID sets the "job_id" field if the given value is not nil. +func (cwrc *CIWorkflowResultCreate) SetNillableJobID(s *string) *CIWorkflowResultCreate { + if s != nil { + cwrc.SetJobID(*s) + } + return cwrc +} + // SetStatus sets the "status" field. -func (cwrc *CIWorkflowResultCreate) SetStatus(s string) *CIWorkflowResultCreate { - cwrc.mutation.SetStatus(s) +func (cwrc *CIWorkflowResultCreate) SetStatus(srst schema.WorkflowRunStatusType) *CIWorkflowResultCreate { + cwrc.mutation.SetStatus(srst) return cwrc } // SetNillableStatus sets the "status" field if the given value is not nil. -func (cwrc *CIWorkflowResultCreate) SetNillableStatus(s *string) *CIWorkflowResultCreate { - if s != nil { - cwrc.SetStatus(*s) +func (cwrc *CIWorkflowResultCreate) SetNillableStatus(srst *schema.WorkflowRunStatusType) *CIWorkflowResultCreate { + if srst != nil { + cwrc.SetStatus(*srst) } return cwrc } @@ -158,6 +145,110 @@ func (cwrc *CIWorkflowResultCreate) SetNillableEndTime(i *int64) *CIWorkflowResu return cwrc } +// SetPythonVersion sets the "python_version" field. +func (cwrc *CIWorkflowResultCreate) SetPythonVersion(s string) *CIWorkflowResultCreate { + cwrc.mutation.SetPythonVersion(s) + return cwrc +} + +// SetNillablePythonVersion sets the "python_version" field if the given value is not nil. +func (cwrc *CIWorkflowResultCreate) SetNillablePythonVersion(s *string) *CIWorkflowResultCreate { + if s != nil { + cwrc.SetPythonVersion(*s) + } + return cwrc +} + +// SetPytorchVersion sets the "pytorch_version" field. +func (cwrc *CIWorkflowResultCreate) SetPytorchVersion(s string) *CIWorkflowResultCreate { + cwrc.mutation.SetPytorchVersion(s) + return cwrc +} + +// SetNillablePytorchVersion sets the "pytorch_version" field if the given value is not nil. +func (cwrc *CIWorkflowResultCreate) SetNillablePytorchVersion(s *string) *CIWorkflowResultCreate { + if s != nil { + cwrc.SetPytorchVersion(*s) + } + return cwrc +} + +// SetCudaVersion sets the "cuda_version" field. +func (cwrc *CIWorkflowResultCreate) SetCudaVersion(s string) *CIWorkflowResultCreate { + cwrc.mutation.SetCudaVersion(s) + return cwrc +} + +// SetNillableCudaVersion sets the "cuda_version" field if the given value is not nil. +func (cwrc *CIWorkflowResultCreate) SetNillableCudaVersion(s *string) *CIWorkflowResultCreate { + if s != nil { + cwrc.SetCudaVersion(*s) + } + return cwrc +} + +// SetComfyRunFlags sets the "comfy_run_flags" field. +func (cwrc *CIWorkflowResultCreate) SetComfyRunFlags(s string) *CIWorkflowResultCreate { + cwrc.mutation.SetComfyRunFlags(s) + return cwrc +} + +// SetNillableComfyRunFlags sets the "comfy_run_flags" field if the given value is not nil. +func (cwrc *CIWorkflowResultCreate) SetNillableComfyRunFlags(s *string) *CIWorkflowResultCreate { + if s != nil { + cwrc.SetComfyRunFlags(*s) + } + return cwrc +} + +// SetAvgVram sets the "avg_vram" field. +func (cwrc *CIWorkflowResultCreate) SetAvgVram(i int) *CIWorkflowResultCreate { + cwrc.mutation.SetAvgVram(i) + return cwrc +} + +// SetNillableAvgVram sets the "avg_vram" field if the given value is not nil. +func (cwrc *CIWorkflowResultCreate) SetNillableAvgVram(i *int) *CIWorkflowResultCreate { + if i != nil { + cwrc.SetAvgVram(*i) + } + return cwrc +} + +// SetPeakVram sets the "peak_vram" field. +func (cwrc *CIWorkflowResultCreate) SetPeakVram(i int) *CIWorkflowResultCreate { + cwrc.mutation.SetPeakVram(i) + return cwrc +} + +// SetNillablePeakVram sets the "peak_vram" field if the given value is not nil. +func (cwrc *CIWorkflowResultCreate) SetNillablePeakVram(i *int) *CIWorkflowResultCreate { + if i != nil { + cwrc.SetPeakVram(*i) + } + return cwrc +} + +// SetJobTriggerUser sets the "job_trigger_user" field. +func (cwrc *CIWorkflowResultCreate) SetJobTriggerUser(s string) *CIWorkflowResultCreate { + cwrc.mutation.SetJobTriggerUser(s) + return cwrc +} + +// SetNillableJobTriggerUser sets the "job_trigger_user" field if the given value is not nil. +func (cwrc *CIWorkflowResultCreate) SetNillableJobTriggerUser(s *string) *CIWorkflowResultCreate { + if s != nil { + cwrc.SetJobTriggerUser(*s) + } + return cwrc +} + +// SetMetadata sets the "metadata" field. +func (cwrc *CIWorkflowResultCreate) SetMetadata(m map[string]interface{}) *CIWorkflowResultCreate { + cwrc.mutation.SetMetadata(m) + return cwrc +} + // SetID sets the "id" field. func (cwrc *CIWorkflowResultCreate) SetID(u uuid.UUID) *CIWorkflowResultCreate { cwrc.mutation.SetID(u) @@ -191,23 +282,19 @@ func (cwrc *CIWorkflowResultCreate) SetGitcommit(g *GitCommit) *CIWorkflowResult return cwrc.SetGitcommitID(g.ID) } -// SetStorageFileID sets the "storage_file" edge to the StorageFile entity by ID. -func (cwrc *CIWorkflowResultCreate) SetStorageFileID(id uuid.UUID) *CIWorkflowResultCreate { - cwrc.mutation.SetStorageFileID(id) +// AddStorageFileIDs adds the "storage_file" edge to the StorageFile entity by IDs. +func (cwrc *CIWorkflowResultCreate) AddStorageFileIDs(ids ...uuid.UUID) *CIWorkflowResultCreate { + cwrc.mutation.AddStorageFileIDs(ids...) return cwrc } -// SetNillableStorageFileID sets the "storage_file" edge to the StorageFile entity by ID if the given value is not nil. -func (cwrc *CIWorkflowResultCreate) SetNillableStorageFileID(id *uuid.UUID) *CIWorkflowResultCreate { - if id != nil { - cwrc = cwrc.SetStorageFileID(*id) +// AddStorageFile adds the "storage_file" edges to the StorageFile entity. +func (cwrc *CIWorkflowResultCreate) AddStorageFile(s ...*StorageFile) *CIWorkflowResultCreate { + ids := make([]uuid.UUID, len(s)) + for i := range s { + ids[i] = s[i].ID } - return cwrc -} - -// SetStorageFile sets the "storage_file" edge to the StorageFile entity. -func (cwrc *CIWorkflowResultCreate) SetStorageFile(s *StorageFile) *CIWorkflowResultCreate { - return cwrc.SetStorageFileID(s.ID) + return cwrc.AddStorageFileIDs(ids...) } // Mutation returns the CIWorkflowResultMutation object of the builder. @@ -253,6 +340,10 @@ func (cwrc *CIWorkflowResultCreate) defaults() { v := ciworkflowresult.DefaultUpdateTime() cwrc.mutation.SetUpdateTime(v) } + if _, ok := cwrc.mutation.Status(); !ok { + v := ciworkflowresult.DefaultStatus + cwrc.mutation.SetStatus(v) + } if _, ok := cwrc.mutation.ID(); !ok { v := ciworkflowresult.DefaultID() cwrc.mutation.SetID(v) @@ -270,6 +361,9 @@ func (cwrc *CIWorkflowResultCreate) check() error { if _, ok := cwrc.mutation.OperatingSystem(); !ok { return &ValidationError{Name: "operating_system", err: errors.New(`ent: missing required field "CIWorkflowResult.operating_system"`)} } + if _, ok := cwrc.mutation.Status(); !ok { + return &ValidationError{Name: "status", err: errors.New(`ent: missing required field "CIWorkflowResult.status"`)} + } return nil } @@ -318,14 +412,6 @@ func (cwrc *CIWorkflowResultCreate) createSpec() (*CIWorkflowResult, *sqlgraph.C _spec.SetField(ciworkflowresult.FieldOperatingSystem, field.TypeString, value) _node.OperatingSystem = value } - if value, ok := cwrc.mutation.GpuType(); ok { - _spec.SetField(ciworkflowresult.FieldGpuType, field.TypeString, value) - _node.GpuType = value - } - if value, ok := cwrc.mutation.PytorchVersion(); ok { - _spec.SetField(ciworkflowresult.FieldPytorchVersion, field.TypeString, value) - _node.PytorchVersion = value - } if value, ok := cwrc.mutation.WorkflowName(); ok { _spec.SetField(ciworkflowresult.FieldWorkflowName, field.TypeString, value) _node.WorkflowName = value @@ -334,6 +420,10 @@ func (cwrc *CIWorkflowResultCreate) createSpec() (*CIWorkflowResult, *sqlgraph.C _spec.SetField(ciworkflowresult.FieldRunID, field.TypeString, value) _node.RunID = value } + if value, ok := cwrc.mutation.JobID(); ok { + _spec.SetField(ciworkflowresult.FieldJobID, field.TypeString, value) + _node.JobID = value + } if value, ok := cwrc.mutation.Status(); ok { _spec.SetField(ciworkflowresult.FieldStatus, field.TypeString, value) _node.Status = value @@ -346,6 +436,38 @@ func (cwrc *CIWorkflowResultCreate) createSpec() (*CIWorkflowResult, *sqlgraph.C _spec.SetField(ciworkflowresult.FieldEndTime, field.TypeInt64, value) _node.EndTime = value } + if value, ok := cwrc.mutation.PythonVersion(); ok { + _spec.SetField(ciworkflowresult.FieldPythonVersion, field.TypeString, value) + _node.PythonVersion = value + } + if value, ok := cwrc.mutation.PytorchVersion(); ok { + _spec.SetField(ciworkflowresult.FieldPytorchVersion, field.TypeString, value) + _node.PytorchVersion = value + } + if value, ok := cwrc.mutation.CudaVersion(); ok { + _spec.SetField(ciworkflowresult.FieldCudaVersion, field.TypeString, value) + _node.CudaVersion = value + } + if value, ok := cwrc.mutation.ComfyRunFlags(); ok { + _spec.SetField(ciworkflowresult.FieldComfyRunFlags, field.TypeString, value) + _node.ComfyRunFlags = value + } + if value, ok := cwrc.mutation.AvgVram(); ok { + _spec.SetField(ciworkflowresult.FieldAvgVram, field.TypeInt, value) + _node.AvgVram = value + } + if value, ok := cwrc.mutation.PeakVram(); ok { + _spec.SetField(ciworkflowresult.FieldPeakVram, field.TypeInt, value) + _node.PeakVram = value + } + if value, ok := cwrc.mutation.JobTriggerUser(); ok { + _spec.SetField(ciworkflowresult.FieldJobTriggerUser, field.TypeString, value) + _node.JobTriggerUser = value + } + if value, ok := cwrc.mutation.Metadata(); ok { + _spec.SetField(ciworkflowresult.FieldMetadata, field.TypeJSON, value) + _node.Metadata = value + } if nodes := cwrc.mutation.GitcommitIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, @@ -365,7 +487,7 @@ func (cwrc *CIWorkflowResultCreate) createSpec() (*CIWorkflowResult, *sqlgraph.C } if nodes := cwrc.mutation.StorageFileIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ - Rel: sqlgraph.M2O, + Rel: sqlgraph.O2M, Inverse: false, Table: ciworkflowresult.StorageFileTable, Columns: []string{ciworkflowresult.StorageFileColumn}, @@ -377,7 +499,6 @@ func (cwrc *CIWorkflowResultCreate) createSpec() (*CIWorkflowResult, *sqlgraph.C for _, k := range nodes { edge.Target.Nodes = append(edge.Target.Nodes, k) } - _node.ci_workflow_result_storage_file = &nodes[0] _spec.Edges = append(_spec.Edges, edge) } return _node, _spec @@ -456,42 +577,6 @@ func (u *CIWorkflowResultUpsert) UpdateOperatingSystem() *CIWorkflowResultUpsert return u } -// SetGpuType sets the "gpu_type" field. -func (u *CIWorkflowResultUpsert) SetGpuType(v string) *CIWorkflowResultUpsert { - u.Set(ciworkflowresult.FieldGpuType, v) - return u -} - -// UpdateGpuType sets the "gpu_type" field to the value that was provided on create. -func (u *CIWorkflowResultUpsert) UpdateGpuType() *CIWorkflowResultUpsert { - u.SetExcluded(ciworkflowresult.FieldGpuType) - return u -} - -// ClearGpuType clears the value of the "gpu_type" field. -func (u *CIWorkflowResultUpsert) ClearGpuType() *CIWorkflowResultUpsert { - u.SetNull(ciworkflowresult.FieldGpuType) - return u -} - -// SetPytorchVersion sets the "pytorch_version" field. -func (u *CIWorkflowResultUpsert) SetPytorchVersion(v string) *CIWorkflowResultUpsert { - u.Set(ciworkflowresult.FieldPytorchVersion, v) - return u -} - -// UpdatePytorchVersion sets the "pytorch_version" field to the value that was provided on create. -func (u *CIWorkflowResultUpsert) UpdatePytorchVersion() *CIWorkflowResultUpsert { - u.SetExcluded(ciworkflowresult.FieldPytorchVersion) - return u -} - -// ClearPytorchVersion clears the value of the "pytorch_version" field. -func (u *CIWorkflowResultUpsert) ClearPytorchVersion() *CIWorkflowResultUpsert { - u.SetNull(ciworkflowresult.FieldPytorchVersion) - return u -} - // SetWorkflowName sets the "workflow_name" field. func (u *CIWorkflowResultUpsert) SetWorkflowName(v string) *CIWorkflowResultUpsert { u.Set(ciworkflowresult.FieldWorkflowName, v) @@ -528,8 +613,26 @@ func (u *CIWorkflowResultUpsert) ClearRunID() *CIWorkflowResultUpsert { return u } +// SetJobID sets the "job_id" field. +func (u *CIWorkflowResultUpsert) SetJobID(v string) *CIWorkflowResultUpsert { + u.Set(ciworkflowresult.FieldJobID, v) + return u +} + +// UpdateJobID sets the "job_id" field to the value that was provided on create. +func (u *CIWorkflowResultUpsert) UpdateJobID() *CIWorkflowResultUpsert { + u.SetExcluded(ciworkflowresult.FieldJobID) + return u +} + +// ClearJobID clears the value of the "job_id" field. +func (u *CIWorkflowResultUpsert) ClearJobID() *CIWorkflowResultUpsert { + u.SetNull(ciworkflowresult.FieldJobID) + return u +} + // SetStatus sets the "status" field. -func (u *CIWorkflowResultUpsert) SetStatus(v string) *CIWorkflowResultUpsert { +func (u *CIWorkflowResultUpsert) SetStatus(v schema.WorkflowRunStatusType) *CIWorkflowResultUpsert { u.Set(ciworkflowresult.FieldStatus, v) return u } @@ -540,12 +643,6 @@ func (u *CIWorkflowResultUpsert) UpdateStatus() *CIWorkflowResultUpsert { return u } -// ClearStatus clears the value of the "status" field. -func (u *CIWorkflowResultUpsert) ClearStatus() *CIWorkflowResultUpsert { - u.SetNull(ciworkflowresult.FieldStatus) - return u -} - // SetStartTime sets the "start_time" field. func (u *CIWorkflowResultUpsert) SetStartTime(v int64) *CIWorkflowResultUpsert { u.Set(ciworkflowresult.FieldStartTime, v) @@ -594,6 +691,162 @@ func (u *CIWorkflowResultUpsert) ClearEndTime() *CIWorkflowResultUpsert { return u } +// SetPythonVersion sets the "python_version" field. +func (u *CIWorkflowResultUpsert) SetPythonVersion(v string) *CIWorkflowResultUpsert { + u.Set(ciworkflowresult.FieldPythonVersion, v) + return u +} + +// UpdatePythonVersion sets the "python_version" field to the value that was provided on create. +func (u *CIWorkflowResultUpsert) UpdatePythonVersion() *CIWorkflowResultUpsert { + u.SetExcluded(ciworkflowresult.FieldPythonVersion) + return u +} + +// ClearPythonVersion clears the value of the "python_version" field. +func (u *CIWorkflowResultUpsert) ClearPythonVersion() *CIWorkflowResultUpsert { + u.SetNull(ciworkflowresult.FieldPythonVersion) + return u +} + +// SetPytorchVersion sets the "pytorch_version" field. +func (u *CIWorkflowResultUpsert) SetPytorchVersion(v string) *CIWorkflowResultUpsert { + u.Set(ciworkflowresult.FieldPytorchVersion, v) + return u +} + +// UpdatePytorchVersion sets the "pytorch_version" field to the value that was provided on create. +func (u *CIWorkflowResultUpsert) UpdatePytorchVersion() *CIWorkflowResultUpsert { + u.SetExcluded(ciworkflowresult.FieldPytorchVersion) + return u +} + +// ClearPytorchVersion clears the value of the "pytorch_version" field. +func (u *CIWorkflowResultUpsert) ClearPytorchVersion() *CIWorkflowResultUpsert { + u.SetNull(ciworkflowresult.FieldPytorchVersion) + return u +} + +// SetCudaVersion sets the "cuda_version" field. +func (u *CIWorkflowResultUpsert) SetCudaVersion(v string) *CIWorkflowResultUpsert { + u.Set(ciworkflowresult.FieldCudaVersion, v) + return u +} + +// UpdateCudaVersion sets the "cuda_version" field to the value that was provided on create. +func (u *CIWorkflowResultUpsert) UpdateCudaVersion() *CIWorkflowResultUpsert { + u.SetExcluded(ciworkflowresult.FieldCudaVersion) + return u +} + +// ClearCudaVersion clears the value of the "cuda_version" field. +func (u *CIWorkflowResultUpsert) ClearCudaVersion() *CIWorkflowResultUpsert { + u.SetNull(ciworkflowresult.FieldCudaVersion) + return u +} + +// SetComfyRunFlags sets the "comfy_run_flags" field. +func (u *CIWorkflowResultUpsert) SetComfyRunFlags(v string) *CIWorkflowResultUpsert { + u.Set(ciworkflowresult.FieldComfyRunFlags, v) + return u +} + +// UpdateComfyRunFlags sets the "comfy_run_flags" field to the value that was provided on create. +func (u *CIWorkflowResultUpsert) UpdateComfyRunFlags() *CIWorkflowResultUpsert { + u.SetExcluded(ciworkflowresult.FieldComfyRunFlags) + return u +} + +// ClearComfyRunFlags clears the value of the "comfy_run_flags" field. +func (u *CIWorkflowResultUpsert) ClearComfyRunFlags() *CIWorkflowResultUpsert { + u.SetNull(ciworkflowresult.FieldComfyRunFlags) + return u +} + +// SetAvgVram sets the "avg_vram" field. +func (u *CIWorkflowResultUpsert) SetAvgVram(v int) *CIWorkflowResultUpsert { + u.Set(ciworkflowresult.FieldAvgVram, v) + return u +} + +// UpdateAvgVram sets the "avg_vram" field to the value that was provided on create. +func (u *CIWorkflowResultUpsert) UpdateAvgVram() *CIWorkflowResultUpsert { + u.SetExcluded(ciworkflowresult.FieldAvgVram) + return u +} + +// AddAvgVram adds v to the "avg_vram" field. +func (u *CIWorkflowResultUpsert) AddAvgVram(v int) *CIWorkflowResultUpsert { + u.Add(ciworkflowresult.FieldAvgVram, v) + return u +} + +// ClearAvgVram clears the value of the "avg_vram" field. +func (u *CIWorkflowResultUpsert) ClearAvgVram() *CIWorkflowResultUpsert { + u.SetNull(ciworkflowresult.FieldAvgVram) + return u +} + +// SetPeakVram sets the "peak_vram" field. +func (u *CIWorkflowResultUpsert) SetPeakVram(v int) *CIWorkflowResultUpsert { + u.Set(ciworkflowresult.FieldPeakVram, v) + return u +} + +// UpdatePeakVram sets the "peak_vram" field to the value that was provided on create. +func (u *CIWorkflowResultUpsert) UpdatePeakVram() *CIWorkflowResultUpsert { + u.SetExcluded(ciworkflowresult.FieldPeakVram) + return u +} + +// AddPeakVram adds v to the "peak_vram" field. +func (u *CIWorkflowResultUpsert) AddPeakVram(v int) *CIWorkflowResultUpsert { + u.Add(ciworkflowresult.FieldPeakVram, v) + return u +} + +// ClearPeakVram clears the value of the "peak_vram" field. +func (u *CIWorkflowResultUpsert) ClearPeakVram() *CIWorkflowResultUpsert { + u.SetNull(ciworkflowresult.FieldPeakVram) + return u +} + +// SetJobTriggerUser sets the "job_trigger_user" field. +func (u *CIWorkflowResultUpsert) SetJobTriggerUser(v string) *CIWorkflowResultUpsert { + u.Set(ciworkflowresult.FieldJobTriggerUser, v) + return u +} + +// UpdateJobTriggerUser sets the "job_trigger_user" field to the value that was provided on create. +func (u *CIWorkflowResultUpsert) UpdateJobTriggerUser() *CIWorkflowResultUpsert { + u.SetExcluded(ciworkflowresult.FieldJobTriggerUser) + return u +} + +// ClearJobTriggerUser clears the value of the "job_trigger_user" field. +func (u *CIWorkflowResultUpsert) ClearJobTriggerUser() *CIWorkflowResultUpsert { + u.SetNull(ciworkflowresult.FieldJobTriggerUser) + return u +} + +// SetMetadata sets the "metadata" field. +func (u *CIWorkflowResultUpsert) SetMetadata(v map[string]interface{}) *CIWorkflowResultUpsert { + u.Set(ciworkflowresult.FieldMetadata, v) + return u +} + +// UpdateMetadata sets the "metadata" field to the value that was provided on create. +func (u *CIWorkflowResultUpsert) UpdateMetadata() *CIWorkflowResultUpsert { + u.SetExcluded(ciworkflowresult.FieldMetadata) + return u +} + +// ClearMetadata clears the value of the "metadata" field. +func (u *CIWorkflowResultUpsert) ClearMetadata() *CIWorkflowResultUpsert { + u.SetNull(ciworkflowresult.FieldMetadata) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create except the ID field. // Using this option is equivalent to using: // @@ -673,92 +926,71 @@ func (u *CIWorkflowResultUpsertOne) UpdateOperatingSystem() *CIWorkflowResultUps }) } -// SetGpuType sets the "gpu_type" field. -func (u *CIWorkflowResultUpsertOne) SetGpuType(v string) *CIWorkflowResultUpsertOne { +// SetWorkflowName sets the "workflow_name" field. +func (u *CIWorkflowResultUpsertOne) SetWorkflowName(v string) *CIWorkflowResultUpsertOne { return u.Update(func(s *CIWorkflowResultUpsert) { - s.SetGpuType(v) + s.SetWorkflowName(v) }) } -// UpdateGpuType sets the "gpu_type" field to the value that was provided on create. -func (u *CIWorkflowResultUpsertOne) UpdateGpuType() *CIWorkflowResultUpsertOne { +// UpdateWorkflowName sets the "workflow_name" field to the value that was provided on create. +func (u *CIWorkflowResultUpsertOne) UpdateWorkflowName() *CIWorkflowResultUpsertOne { return u.Update(func(s *CIWorkflowResultUpsert) { - s.UpdateGpuType() + s.UpdateWorkflowName() }) } -// ClearGpuType clears the value of the "gpu_type" field. -func (u *CIWorkflowResultUpsertOne) ClearGpuType() *CIWorkflowResultUpsertOne { +// ClearWorkflowName clears the value of the "workflow_name" field. +func (u *CIWorkflowResultUpsertOne) ClearWorkflowName() *CIWorkflowResultUpsertOne { return u.Update(func(s *CIWorkflowResultUpsert) { - s.ClearGpuType() + s.ClearWorkflowName() }) } -// SetPytorchVersion sets the "pytorch_version" field. -func (u *CIWorkflowResultUpsertOne) SetPytorchVersion(v string) *CIWorkflowResultUpsertOne { - return u.Update(func(s *CIWorkflowResultUpsert) { - s.SetPytorchVersion(v) - }) -} - -// UpdatePytorchVersion sets the "pytorch_version" field to the value that was provided on create. -func (u *CIWorkflowResultUpsertOne) UpdatePytorchVersion() *CIWorkflowResultUpsertOne { - return u.Update(func(s *CIWorkflowResultUpsert) { - s.UpdatePytorchVersion() - }) -} - -// ClearPytorchVersion clears the value of the "pytorch_version" field. -func (u *CIWorkflowResultUpsertOne) ClearPytorchVersion() *CIWorkflowResultUpsertOne { - return u.Update(func(s *CIWorkflowResultUpsert) { - s.ClearPytorchVersion() - }) -} - -// SetWorkflowName sets the "workflow_name" field. -func (u *CIWorkflowResultUpsertOne) SetWorkflowName(v string) *CIWorkflowResultUpsertOne { +// SetRunID sets the "run_id" field. +func (u *CIWorkflowResultUpsertOne) SetRunID(v string) *CIWorkflowResultUpsertOne { return u.Update(func(s *CIWorkflowResultUpsert) { - s.SetWorkflowName(v) + s.SetRunID(v) }) } -// UpdateWorkflowName sets the "workflow_name" field to the value that was provided on create. -func (u *CIWorkflowResultUpsertOne) UpdateWorkflowName() *CIWorkflowResultUpsertOne { +// UpdateRunID sets the "run_id" field to the value that was provided on create. +func (u *CIWorkflowResultUpsertOne) UpdateRunID() *CIWorkflowResultUpsertOne { return u.Update(func(s *CIWorkflowResultUpsert) { - s.UpdateWorkflowName() + s.UpdateRunID() }) } -// ClearWorkflowName clears the value of the "workflow_name" field. -func (u *CIWorkflowResultUpsertOne) ClearWorkflowName() *CIWorkflowResultUpsertOne { +// ClearRunID clears the value of the "run_id" field. +func (u *CIWorkflowResultUpsertOne) ClearRunID() *CIWorkflowResultUpsertOne { return u.Update(func(s *CIWorkflowResultUpsert) { - s.ClearWorkflowName() + s.ClearRunID() }) } -// SetRunID sets the "run_id" field. -func (u *CIWorkflowResultUpsertOne) SetRunID(v string) *CIWorkflowResultUpsertOne { +// SetJobID sets the "job_id" field. +func (u *CIWorkflowResultUpsertOne) SetJobID(v string) *CIWorkflowResultUpsertOne { return u.Update(func(s *CIWorkflowResultUpsert) { - s.SetRunID(v) + s.SetJobID(v) }) } -// UpdateRunID sets the "run_id" field to the value that was provided on create. -func (u *CIWorkflowResultUpsertOne) UpdateRunID() *CIWorkflowResultUpsertOne { +// UpdateJobID sets the "job_id" field to the value that was provided on create. +func (u *CIWorkflowResultUpsertOne) UpdateJobID() *CIWorkflowResultUpsertOne { return u.Update(func(s *CIWorkflowResultUpsert) { - s.UpdateRunID() + s.UpdateJobID() }) } -// ClearRunID clears the value of the "run_id" field. -func (u *CIWorkflowResultUpsertOne) ClearRunID() *CIWorkflowResultUpsertOne { +// ClearJobID clears the value of the "job_id" field. +func (u *CIWorkflowResultUpsertOne) ClearJobID() *CIWorkflowResultUpsertOne { return u.Update(func(s *CIWorkflowResultUpsert) { - s.ClearRunID() + s.ClearJobID() }) } // SetStatus sets the "status" field. -func (u *CIWorkflowResultUpsertOne) SetStatus(v string) *CIWorkflowResultUpsertOne { +func (u *CIWorkflowResultUpsertOne) SetStatus(v schema.WorkflowRunStatusType) *CIWorkflowResultUpsertOne { return u.Update(func(s *CIWorkflowResultUpsert) { s.SetStatus(v) }) @@ -771,13 +1003,6 @@ func (u *CIWorkflowResultUpsertOne) UpdateStatus() *CIWorkflowResultUpsertOne { }) } -// ClearStatus clears the value of the "status" field. -func (u *CIWorkflowResultUpsertOne) ClearStatus() *CIWorkflowResultUpsertOne { - return u.Update(func(s *CIWorkflowResultUpsert) { - s.ClearStatus() - }) -} - // SetStartTime sets the "start_time" field. func (u *CIWorkflowResultUpsertOne) SetStartTime(v int64) *CIWorkflowResultUpsertOne { return u.Update(func(s *CIWorkflowResultUpsert) { @@ -834,6 +1059,188 @@ func (u *CIWorkflowResultUpsertOne) ClearEndTime() *CIWorkflowResultUpsertOne { }) } +// SetPythonVersion sets the "python_version" field. +func (u *CIWorkflowResultUpsertOne) SetPythonVersion(v string) *CIWorkflowResultUpsertOne { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.SetPythonVersion(v) + }) +} + +// UpdatePythonVersion sets the "python_version" field to the value that was provided on create. +func (u *CIWorkflowResultUpsertOne) UpdatePythonVersion() *CIWorkflowResultUpsertOne { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.UpdatePythonVersion() + }) +} + +// ClearPythonVersion clears the value of the "python_version" field. +func (u *CIWorkflowResultUpsertOne) ClearPythonVersion() *CIWorkflowResultUpsertOne { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.ClearPythonVersion() + }) +} + +// SetPytorchVersion sets the "pytorch_version" field. +func (u *CIWorkflowResultUpsertOne) SetPytorchVersion(v string) *CIWorkflowResultUpsertOne { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.SetPytorchVersion(v) + }) +} + +// UpdatePytorchVersion sets the "pytorch_version" field to the value that was provided on create. +func (u *CIWorkflowResultUpsertOne) UpdatePytorchVersion() *CIWorkflowResultUpsertOne { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.UpdatePytorchVersion() + }) +} + +// ClearPytorchVersion clears the value of the "pytorch_version" field. +func (u *CIWorkflowResultUpsertOne) ClearPytorchVersion() *CIWorkflowResultUpsertOne { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.ClearPytorchVersion() + }) +} + +// SetCudaVersion sets the "cuda_version" field. +func (u *CIWorkflowResultUpsertOne) SetCudaVersion(v string) *CIWorkflowResultUpsertOne { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.SetCudaVersion(v) + }) +} + +// UpdateCudaVersion sets the "cuda_version" field to the value that was provided on create. +func (u *CIWorkflowResultUpsertOne) UpdateCudaVersion() *CIWorkflowResultUpsertOne { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.UpdateCudaVersion() + }) +} + +// ClearCudaVersion clears the value of the "cuda_version" field. +func (u *CIWorkflowResultUpsertOne) ClearCudaVersion() *CIWorkflowResultUpsertOne { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.ClearCudaVersion() + }) +} + +// SetComfyRunFlags sets the "comfy_run_flags" field. +func (u *CIWorkflowResultUpsertOne) SetComfyRunFlags(v string) *CIWorkflowResultUpsertOne { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.SetComfyRunFlags(v) + }) +} + +// UpdateComfyRunFlags sets the "comfy_run_flags" field to the value that was provided on create. +func (u *CIWorkflowResultUpsertOne) UpdateComfyRunFlags() *CIWorkflowResultUpsertOne { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.UpdateComfyRunFlags() + }) +} + +// ClearComfyRunFlags clears the value of the "comfy_run_flags" field. +func (u *CIWorkflowResultUpsertOne) ClearComfyRunFlags() *CIWorkflowResultUpsertOne { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.ClearComfyRunFlags() + }) +} + +// SetAvgVram sets the "avg_vram" field. +func (u *CIWorkflowResultUpsertOne) SetAvgVram(v int) *CIWorkflowResultUpsertOne { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.SetAvgVram(v) + }) +} + +// AddAvgVram adds v to the "avg_vram" field. +func (u *CIWorkflowResultUpsertOne) AddAvgVram(v int) *CIWorkflowResultUpsertOne { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.AddAvgVram(v) + }) +} + +// UpdateAvgVram sets the "avg_vram" field to the value that was provided on create. +func (u *CIWorkflowResultUpsertOne) UpdateAvgVram() *CIWorkflowResultUpsertOne { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.UpdateAvgVram() + }) +} + +// ClearAvgVram clears the value of the "avg_vram" field. +func (u *CIWorkflowResultUpsertOne) ClearAvgVram() *CIWorkflowResultUpsertOne { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.ClearAvgVram() + }) +} + +// SetPeakVram sets the "peak_vram" field. +func (u *CIWorkflowResultUpsertOne) SetPeakVram(v int) *CIWorkflowResultUpsertOne { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.SetPeakVram(v) + }) +} + +// AddPeakVram adds v to the "peak_vram" field. +func (u *CIWorkflowResultUpsertOne) AddPeakVram(v int) *CIWorkflowResultUpsertOne { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.AddPeakVram(v) + }) +} + +// UpdatePeakVram sets the "peak_vram" field to the value that was provided on create. +func (u *CIWorkflowResultUpsertOne) UpdatePeakVram() *CIWorkflowResultUpsertOne { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.UpdatePeakVram() + }) +} + +// ClearPeakVram clears the value of the "peak_vram" field. +func (u *CIWorkflowResultUpsertOne) ClearPeakVram() *CIWorkflowResultUpsertOne { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.ClearPeakVram() + }) +} + +// SetJobTriggerUser sets the "job_trigger_user" field. +func (u *CIWorkflowResultUpsertOne) SetJobTriggerUser(v string) *CIWorkflowResultUpsertOne { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.SetJobTriggerUser(v) + }) +} + +// UpdateJobTriggerUser sets the "job_trigger_user" field to the value that was provided on create. +func (u *CIWorkflowResultUpsertOne) UpdateJobTriggerUser() *CIWorkflowResultUpsertOne { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.UpdateJobTriggerUser() + }) +} + +// ClearJobTriggerUser clears the value of the "job_trigger_user" field. +func (u *CIWorkflowResultUpsertOne) ClearJobTriggerUser() *CIWorkflowResultUpsertOne { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.ClearJobTriggerUser() + }) +} + +// SetMetadata sets the "metadata" field. +func (u *CIWorkflowResultUpsertOne) SetMetadata(v map[string]interface{}) *CIWorkflowResultUpsertOne { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.SetMetadata(v) + }) +} + +// UpdateMetadata sets the "metadata" field to the value that was provided on create. +func (u *CIWorkflowResultUpsertOne) UpdateMetadata() *CIWorkflowResultUpsertOne { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.UpdateMetadata() + }) +} + +// ClearMetadata clears the value of the "metadata" field. +func (u *CIWorkflowResultUpsertOne) ClearMetadata() *CIWorkflowResultUpsertOne { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.ClearMetadata() + }) +} + // Exec executes the query. func (u *CIWorkflowResultUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -1080,48 +1487,6 @@ func (u *CIWorkflowResultUpsertBulk) UpdateOperatingSystem() *CIWorkflowResultUp }) } -// SetGpuType sets the "gpu_type" field. -func (u *CIWorkflowResultUpsertBulk) SetGpuType(v string) *CIWorkflowResultUpsertBulk { - return u.Update(func(s *CIWorkflowResultUpsert) { - s.SetGpuType(v) - }) -} - -// UpdateGpuType sets the "gpu_type" field to the value that was provided on create. -func (u *CIWorkflowResultUpsertBulk) UpdateGpuType() *CIWorkflowResultUpsertBulk { - return u.Update(func(s *CIWorkflowResultUpsert) { - s.UpdateGpuType() - }) -} - -// ClearGpuType clears the value of the "gpu_type" field. -func (u *CIWorkflowResultUpsertBulk) ClearGpuType() *CIWorkflowResultUpsertBulk { - return u.Update(func(s *CIWorkflowResultUpsert) { - s.ClearGpuType() - }) -} - -// SetPytorchVersion sets the "pytorch_version" field. -func (u *CIWorkflowResultUpsertBulk) SetPytorchVersion(v string) *CIWorkflowResultUpsertBulk { - return u.Update(func(s *CIWorkflowResultUpsert) { - s.SetPytorchVersion(v) - }) -} - -// UpdatePytorchVersion sets the "pytorch_version" field to the value that was provided on create. -func (u *CIWorkflowResultUpsertBulk) UpdatePytorchVersion() *CIWorkflowResultUpsertBulk { - return u.Update(func(s *CIWorkflowResultUpsert) { - s.UpdatePytorchVersion() - }) -} - -// ClearPytorchVersion clears the value of the "pytorch_version" field. -func (u *CIWorkflowResultUpsertBulk) ClearPytorchVersion() *CIWorkflowResultUpsertBulk { - return u.Update(func(s *CIWorkflowResultUpsert) { - s.ClearPytorchVersion() - }) -} - // SetWorkflowName sets the "workflow_name" field. func (u *CIWorkflowResultUpsertBulk) SetWorkflowName(v string) *CIWorkflowResultUpsertBulk { return u.Update(func(s *CIWorkflowResultUpsert) { @@ -1164,8 +1529,29 @@ func (u *CIWorkflowResultUpsertBulk) ClearRunID() *CIWorkflowResultUpsertBulk { }) } +// SetJobID sets the "job_id" field. +func (u *CIWorkflowResultUpsertBulk) SetJobID(v string) *CIWorkflowResultUpsertBulk { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.SetJobID(v) + }) +} + +// UpdateJobID sets the "job_id" field to the value that was provided on create. +func (u *CIWorkflowResultUpsertBulk) UpdateJobID() *CIWorkflowResultUpsertBulk { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.UpdateJobID() + }) +} + +// ClearJobID clears the value of the "job_id" field. +func (u *CIWorkflowResultUpsertBulk) ClearJobID() *CIWorkflowResultUpsertBulk { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.ClearJobID() + }) +} + // SetStatus sets the "status" field. -func (u *CIWorkflowResultUpsertBulk) SetStatus(v string) *CIWorkflowResultUpsertBulk { +func (u *CIWorkflowResultUpsertBulk) SetStatus(v schema.WorkflowRunStatusType) *CIWorkflowResultUpsertBulk { return u.Update(func(s *CIWorkflowResultUpsert) { s.SetStatus(v) }) @@ -1178,13 +1564,6 @@ func (u *CIWorkflowResultUpsertBulk) UpdateStatus() *CIWorkflowResultUpsertBulk }) } -// ClearStatus clears the value of the "status" field. -func (u *CIWorkflowResultUpsertBulk) ClearStatus() *CIWorkflowResultUpsertBulk { - return u.Update(func(s *CIWorkflowResultUpsert) { - s.ClearStatus() - }) -} - // SetStartTime sets the "start_time" field. func (u *CIWorkflowResultUpsertBulk) SetStartTime(v int64) *CIWorkflowResultUpsertBulk { return u.Update(func(s *CIWorkflowResultUpsert) { @@ -1241,6 +1620,188 @@ func (u *CIWorkflowResultUpsertBulk) ClearEndTime() *CIWorkflowResultUpsertBulk }) } +// SetPythonVersion sets the "python_version" field. +func (u *CIWorkflowResultUpsertBulk) SetPythonVersion(v string) *CIWorkflowResultUpsertBulk { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.SetPythonVersion(v) + }) +} + +// UpdatePythonVersion sets the "python_version" field to the value that was provided on create. +func (u *CIWorkflowResultUpsertBulk) UpdatePythonVersion() *CIWorkflowResultUpsertBulk { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.UpdatePythonVersion() + }) +} + +// ClearPythonVersion clears the value of the "python_version" field. +func (u *CIWorkflowResultUpsertBulk) ClearPythonVersion() *CIWorkflowResultUpsertBulk { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.ClearPythonVersion() + }) +} + +// SetPytorchVersion sets the "pytorch_version" field. +func (u *CIWorkflowResultUpsertBulk) SetPytorchVersion(v string) *CIWorkflowResultUpsertBulk { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.SetPytorchVersion(v) + }) +} + +// UpdatePytorchVersion sets the "pytorch_version" field to the value that was provided on create. +func (u *CIWorkflowResultUpsertBulk) UpdatePytorchVersion() *CIWorkflowResultUpsertBulk { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.UpdatePytorchVersion() + }) +} + +// ClearPytorchVersion clears the value of the "pytorch_version" field. +func (u *CIWorkflowResultUpsertBulk) ClearPytorchVersion() *CIWorkflowResultUpsertBulk { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.ClearPytorchVersion() + }) +} + +// SetCudaVersion sets the "cuda_version" field. +func (u *CIWorkflowResultUpsertBulk) SetCudaVersion(v string) *CIWorkflowResultUpsertBulk { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.SetCudaVersion(v) + }) +} + +// UpdateCudaVersion sets the "cuda_version" field to the value that was provided on create. +func (u *CIWorkflowResultUpsertBulk) UpdateCudaVersion() *CIWorkflowResultUpsertBulk { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.UpdateCudaVersion() + }) +} + +// ClearCudaVersion clears the value of the "cuda_version" field. +func (u *CIWorkflowResultUpsertBulk) ClearCudaVersion() *CIWorkflowResultUpsertBulk { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.ClearCudaVersion() + }) +} + +// SetComfyRunFlags sets the "comfy_run_flags" field. +func (u *CIWorkflowResultUpsertBulk) SetComfyRunFlags(v string) *CIWorkflowResultUpsertBulk { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.SetComfyRunFlags(v) + }) +} + +// UpdateComfyRunFlags sets the "comfy_run_flags" field to the value that was provided on create. +func (u *CIWorkflowResultUpsertBulk) UpdateComfyRunFlags() *CIWorkflowResultUpsertBulk { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.UpdateComfyRunFlags() + }) +} + +// ClearComfyRunFlags clears the value of the "comfy_run_flags" field. +func (u *CIWorkflowResultUpsertBulk) ClearComfyRunFlags() *CIWorkflowResultUpsertBulk { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.ClearComfyRunFlags() + }) +} + +// SetAvgVram sets the "avg_vram" field. +func (u *CIWorkflowResultUpsertBulk) SetAvgVram(v int) *CIWorkflowResultUpsertBulk { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.SetAvgVram(v) + }) +} + +// AddAvgVram adds v to the "avg_vram" field. +func (u *CIWorkflowResultUpsertBulk) AddAvgVram(v int) *CIWorkflowResultUpsertBulk { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.AddAvgVram(v) + }) +} + +// UpdateAvgVram sets the "avg_vram" field to the value that was provided on create. +func (u *CIWorkflowResultUpsertBulk) UpdateAvgVram() *CIWorkflowResultUpsertBulk { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.UpdateAvgVram() + }) +} + +// ClearAvgVram clears the value of the "avg_vram" field. +func (u *CIWorkflowResultUpsertBulk) ClearAvgVram() *CIWorkflowResultUpsertBulk { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.ClearAvgVram() + }) +} + +// SetPeakVram sets the "peak_vram" field. +func (u *CIWorkflowResultUpsertBulk) SetPeakVram(v int) *CIWorkflowResultUpsertBulk { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.SetPeakVram(v) + }) +} + +// AddPeakVram adds v to the "peak_vram" field. +func (u *CIWorkflowResultUpsertBulk) AddPeakVram(v int) *CIWorkflowResultUpsertBulk { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.AddPeakVram(v) + }) +} + +// UpdatePeakVram sets the "peak_vram" field to the value that was provided on create. +func (u *CIWorkflowResultUpsertBulk) UpdatePeakVram() *CIWorkflowResultUpsertBulk { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.UpdatePeakVram() + }) +} + +// ClearPeakVram clears the value of the "peak_vram" field. +func (u *CIWorkflowResultUpsertBulk) ClearPeakVram() *CIWorkflowResultUpsertBulk { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.ClearPeakVram() + }) +} + +// SetJobTriggerUser sets the "job_trigger_user" field. +func (u *CIWorkflowResultUpsertBulk) SetJobTriggerUser(v string) *CIWorkflowResultUpsertBulk { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.SetJobTriggerUser(v) + }) +} + +// UpdateJobTriggerUser sets the "job_trigger_user" field to the value that was provided on create. +func (u *CIWorkflowResultUpsertBulk) UpdateJobTriggerUser() *CIWorkflowResultUpsertBulk { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.UpdateJobTriggerUser() + }) +} + +// ClearJobTriggerUser clears the value of the "job_trigger_user" field. +func (u *CIWorkflowResultUpsertBulk) ClearJobTriggerUser() *CIWorkflowResultUpsertBulk { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.ClearJobTriggerUser() + }) +} + +// SetMetadata sets the "metadata" field. +func (u *CIWorkflowResultUpsertBulk) SetMetadata(v map[string]interface{}) *CIWorkflowResultUpsertBulk { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.SetMetadata(v) + }) +} + +// UpdateMetadata sets the "metadata" field to the value that was provided on create. +func (u *CIWorkflowResultUpsertBulk) UpdateMetadata() *CIWorkflowResultUpsertBulk { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.UpdateMetadata() + }) +} + +// ClearMetadata clears the value of the "metadata" field. +func (u *CIWorkflowResultUpsertBulk) ClearMetadata() *CIWorkflowResultUpsertBulk { + return u.Update(func(s *CIWorkflowResultUpsert) { + s.ClearMetadata() + }) +} + // Exec executes the query. func (u *CIWorkflowResultUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/ent/ciworkflowresult_query.go b/ent/ciworkflowresult_query.go index dd4b8b0..5b438d3 100644 --- a/ent/ciworkflowresult_query.go +++ b/ent/ciworkflowresult_query.go @@ -4,6 +4,7 @@ package ent import ( "context" + "database/sql/driver" "fmt" "math" "registry-backend/ent/ciworkflowresult" @@ -101,7 +102,7 @@ func (cwrq *CIWorkflowResultQuery) QueryStorageFile() *StorageFileQuery { step := sqlgraph.NewStep( sqlgraph.From(ciworkflowresult.Table, ciworkflowresult.FieldID, selector), sqlgraph.To(storagefile.Table, storagefile.FieldID), - sqlgraph.Edge(sqlgraph.M2O, false, ciworkflowresult.StorageFileTable, ciworkflowresult.StorageFileColumn), + sqlgraph.Edge(sqlgraph.O2M, false, ciworkflowresult.StorageFileTable, ciworkflowresult.StorageFileColumn), ) fromU = sqlgraph.SetNeighbors(cwrq.driver.Dialect(), step) return fromU, nil @@ -415,7 +416,7 @@ func (cwrq *CIWorkflowResultQuery) sqlAll(ctx context.Context, hooks ...queryHoo cwrq.withStorageFile != nil, } ) - if cwrq.withGitcommit != nil || cwrq.withStorageFile != nil { + if cwrq.withGitcommit != nil { withFKs = true } if withFKs { @@ -449,8 +450,9 @@ func (cwrq *CIWorkflowResultQuery) sqlAll(ctx context.Context, hooks ...queryHoo } } if query := cwrq.withStorageFile; query != nil { - if err := cwrq.loadStorageFile(ctx, query, nodes, nil, - func(n *CIWorkflowResult, e *StorageFile) { n.Edges.StorageFile = e }); err != nil { + if err := cwrq.loadStorageFile(ctx, query, nodes, + func(n *CIWorkflowResult) { n.Edges.StorageFile = []*StorageFile{} }, + func(n *CIWorkflowResult, e *StorageFile) { n.Edges.StorageFile = append(n.Edges.StorageFile, e) }); err != nil { return nil, err } } @@ -490,34 +492,33 @@ func (cwrq *CIWorkflowResultQuery) loadGitcommit(ctx context.Context, query *Git return nil } func (cwrq *CIWorkflowResultQuery) loadStorageFile(ctx context.Context, query *StorageFileQuery, nodes []*CIWorkflowResult, init func(*CIWorkflowResult), assign func(*CIWorkflowResult, *StorageFile)) error { - ids := make([]uuid.UUID, 0, len(nodes)) - nodeids := make(map[uuid.UUID][]*CIWorkflowResult) + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[uuid.UUID]*CIWorkflowResult) for i := range nodes { - if nodes[i].ci_workflow_result_storage_file == nil { - continue - } - fk := *nodes[i].ci_workflow_result_storage_file - if _, ok := nodeids[fk]; !ok { - ids = append(ids, fk) + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) } - nodeids[fk] = append(nodeids[fk], nodes[i]) - } - if len(ids) == 0 { - return nil } - query.Where(storagefile.IDIn(ids...)) + query.withFKs = true + query.Where(predicate.StorageFile(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(ciworkflowresult.StorageFileColumn), fks...)) + })) neighbors, err := query.All(ctx) if err != nil { return err } for _, n := range neighbors { - nodes, ok := nodeids[n.ID] - if !ok { - return fmt.Errorf(`unexpected foreign-key "ci_workflow_result_storage_file" returned %v`, n.ID) + fk := n.ci_workflow_result_storage_file + if fk == nil { + return fmt.Errorf(`foreign-key "ci_workflow_result_storage_file" is nil for node %v`, n.ID) } - for i := range nodes { - assign(nodes[i], n) + node, ok := nodeids[*fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "ci_workflow_result_storage_file" returned %v for node %v`, *fk, n.ID) } + assign(node, n) } return nil } @@ -635,6 +636,12 @@ func (cwrq *CIWorkflowResultQuery) ForShare(opts ...sql.LockOption) *CIWorkflowR return cwrq } +// Modify adds a query modifier for attaching custom logic to queries. +func (cwrq *CIWorkflowResultQuery) Modify(modifiers ...func(s *sql.Selector)) *CIWorkflowResultSelect { + cwrq.modifiers = append(cwrq.modifiers, modifiers...) + return cwrq.Select() +} + // CIWorkflowResultGroupBy is the group-by builder for CIWorkflowResult entities. type CIWorkflowResultGroupBy struct { selector @@ -724,3 +731,9 @@ func (cwrs *CIWorkflowResultSelect) sqlScan(ctx context.Context, root *CIWorkflo defer rows.Close() return sql.ScanSlice(rows, v) } + +// Modify adds a query modifier for attaching custom logic to queries. +func (cwrs *CIWorkflowResultSelect) Modify(modifiers ...func(s *sql.Selector)) *CIWorkflowResultSelect { + cwrs.modifiers = append(cwrs.modifiers, modifiers...) + return cwrs +} diff --git a/ent/ciworkflowresult_update.go b/ent/ciworkflowresult_update.go index 29d963f..da88a5a 100644 --- a/ent/ciworkflowresult_update.go +++ b/ent/ciworkflowresult_update.go @@ -9,6 +9,7 @@ import ( "registry-backend/ent/ciworkflowresult" "registry-backend/ent/gitcommit" "registry-backend/ent/predicate" + "registry-backend/ent/schema" "registry-backend/ent/storagefile" "time" @@ -21,8 +22,9 @@ import ( // CIWorkflowResultUpdate is the builder for updating CIWorkflowResult entities. type CIWorkflowResultUpdate struct { config - hooks []Hook - mutation *CIWorkflowResultMutation + hooks []Hook + mutation *CIWorkflowResultMutation + modifiers []func(*sql.UpdateBuilder) } // Where appends a list predicates to the CIWorkflowResultUpdate builder. @@ -51,46 +53,6 @@ func (cwru *CIWorkflowResultUpdate) SetNillableOperatingSystem(s *string) *CIWor return cwru } -// SetGpuType sets the "gpu_type" field. -func (cwru *CIWorkflowResultUpdate) SetGpuType(s string) *CIWorkflowResultUpdate { - cwru.mutation.SetGpuType(s) - return cwru -} - -// SetNillableGpuType sets the "gpu_type" field if the given value is not nil. -func (cwru *CIWorkflowResultUpdate) SetNillableGpuType(s *string) *CIWorkflowResultUpdate { - if s != nil { - cwru.SetGpuType(*s) - } - return cwru -} - -// ClearGpuType clears the value of the "gpu_type" field. -func (cwru *CIWorkflowResultUpdate) ClearGpuType() *CIWorkflowResultUpdate { - cwru.mutation.ClearGpuType() - return cwru -} - -// SetPytorchVersion sets the "pytorch_version" field. -func (cwru *CIWorkflowResultUpdate) SetPytorchVersion(s string) *CIWorkflowResultUpdate { - cwru.mutation.SetPytorchVersion(s) - return cwru -} - -// SetNillablePytorchVersion sets the "pytorch_version" field if the given value is not nil. -func (cwru *CIWorkflowResultUpdate) SetNillablePytorchVersion(s *string) *CIWorkflowResultUpdate { - if s != nil { - cwru.SetPytorchVersion(*s) - } - return cwru -} - -// ClearPytorchVersion clears the value of the "pytorch_version" field. -func (cwru *CIWorkflowResultUpdate) ClearPytorchVersion() *CIWorkflowResultUpdate { - cwru.mutation.ClearPytorchVersion() - return cwru -} - // SetWorkflowName sets the "workflow_name" field. func (cwru *CIWorkflowResultUpdate) SetWorkflowName(s string) *CIWorkflowResultUpdate { cwru.mutation.SetWorkflowName(s) @@ -131,23 +93,37 @@ func (cwru *CIWorkflowResultUpdate) ClearRunID() *CIWorkflowResultUpdate { return cwru } -// SetStatus sets the "status" field. -func (cwru *CIWorkflowResultUpdate) SetStatus(s string) *CIWorkflowResultUpdate { - cwru.mutation.SetStatus(s) +// SetJobID sets the "job_id" field. +func (cwru *CIWorkflowResultUpdate) SetJobID(s string) *CIWorkflowResultUpdate { + cwru.mutation.SetJobID(s) return cwru } -// SetNillableStatus sets the "status" field if the given value is not nil. -func (cwru *CIWorkflowResultUpdate) SetNillableStatus(s *string) *CIWorkflowResultUpdate { +// SetNillableJobID sets the "job_id" field if the given value is not nil. +func (cwru *CIWorkflowResultUpdate) SetNillableJobID(s *string) *CIWorkflowResultUpdate { if s != nil { - cwru.SetStatus(*s) + cwru.SetJobID(*s) } return cwru } -// ClearStatus clears the value of the "status" field. -func (cwru *CIWorkflowResultUpdate) ClearStatus() *CIWorkflowResultUpdate { - cwru.mutation.ClearStatus() +// ClearJobID clears the value of the "job_id" field. +func (cwru *CIWorkflowResultUpdate) ClearJobID() *CIWorkflowResultUpdate { + cwru.mutation.ClearJobID() + return cwru +} + +// SetStatus sets the "status" field. +func (cwru *CIWorkflowResultUpdate) SetStatus(srst schema.WorkflowRunStatusType) *CIWorkflowResultUpdate { + cwru.mutation.SetStatus(srst) + return cwru +} + +// SetNillableStatus sets the "status" field if the given value is not nil. +func (cwru *CIWorkflowResultUpdate) SetNillableStatus(srst *schema.WorkflowRunStatusType) *CIWorkflowResultUpdate { + if srst != nil { + cwru.SetStatus(*srst) + } return cwru } @@ -205,6 +181,172 @@ func (cwru *CIWorkflowResultUpdate) ClearEndTime() *CIWorkflowResultUpdate { return cwru } +// SetPythonVersion sets the "python_version" field. +func (cwru *CIWorkflowResultUpdate) SetPythonVersion(s string) *CIWorkflowResultUpdate { + cwru.mutation.SetPythonVersion(s) + return cwru +} + +// SetNillablePythonVersion sets the "python_version" field if the given value is not nil. +func (cwru *CIWorkflowResultUpdate) SetNillablePythonVersion(s *string) *CIWorkflowResultUpdate { + if s != nil { + cwru.SetPythonVersion(*s) + } + return cwru +} + +// ClearPythonVersion clears the value of the "python_version" field. +func (cwru *CIWorkflowResultUpdate) ClearPythonVersion() *CIWorkflowResultUpdate { + cwru.mutation.ClearPythonVersion() + return cwru +} + +// SetPytorchVersion sets the "pytorch_version" field. +func (cwru *CIWorkflowResultUpdate) SetPytorchVersion(s string) *CIWorkflowResultUpdate { + cwru.mutation.SetPytorchVersion(s) + return cwru +} + +// SetNillablePytorchVersion sets the "pytorch_version" field if the given value is not nil. +func (cwru *CIWorkflowResultUpdate) SetNillablePytorchVersion(s *string) *CIWorkflowResultUpdate { + if s != nil { + cwru.SetPytorchVersion(*s) + } + return cwru +} + +// ClearPytorchVersion clears the value of the "pytorch_version" field. +func (cwru *CIWorkflowResultUpdate) ClearPytorchVersion() *CIWorkflowResultUpdate { + cwru.mutation.ClearPytorchVersion() + return cwru +} + +// SetCudaVersion sets the "cuda_version" field. +func (cwru *CIWorkflowResultUpdate) SetCudaVersion(s string) *CIWorkflowResultUpdate { + cwru.mutation.SetCudaVersion(s) + return cwru +} + +// SetNillableCudaVersion sets the "cuda_version" field if the given value is not nil. +func (cwru *CIWorkflowResultUpdate) SetNillableCudaVersion(s *string) *CIWorkflowResultUpdate { + if s != nil { + cwru.SetCudaVersion(*s) + } + return cwru +} + +// ClearCudaVersion clears the value of the "cuda_version" field. +func (cwru *CIWorkflowResultUpdate) ClearCudaVersion() *CIWorkflowResultUpdate { + cwru.mutation.ClearCudaVersion() + return cwru +} + +// SetComfyRunFlags sets the "comfy_run_flags" field. +func (cwru *CIWorkflowResultUpdate) SetComfyRunFlags(s string) *CIWorkflowResultUpdate { + cwru.mutation.SetComfyRunFlags(s) + return cwru +} + +// SetNillableComfyRunFlags sets the "comfy_run_flags" field if the given value is not nil. +func (cwru *CIWorkflowResultUpdate) SetNillableComfyRunFlags(s *string) *CIWorkflowResultUpdate { + if s != nil { + cwru.SetComfyRunFlags(*s) + } + return cwru +} + +// ClearComfyRunFlags clears the value of the "comfy_run_flags" field. +func (cwru *CIWorkflowResultUpdate) ClearComfyRunFlags() *CIWorkflowResultUpdate { + cwru.mutation.ClearComfyRunFlags() + return cwru +} + +// SetAvgVram sets the "avg_vram" field. +func (cwru *CIWorkflowResultUpdate) SetAvgVram(i int) *CIWorkflowResultUpdate { + cwru.mutation.ResetAvgVram() + cwru.mutation.SetAvgVram(i) + return cwru +} + +// SetNillableAvgVram sets the "avg_vram" field if the given value is not nil. +func (cwru *CIWorkflowResultUpdate) SetNillableAvgVram(i *int) *CIWorkflowResultUpdate { + if i != nil { + cwru.SetAvgVram(*i) + } + return cwru +} + +// AddAvgVram adds i to the "avg_vram" field. +func (cwru *CIWorkflowResultUpdate) AddAvgVram(i int) *CIWorkflowResultUpdate { + cwru.mutation.AddAvgVram(i) + return cwru +} + +// ClearAvgVram clears the value of the "avg_vram" field. +func (cwru *CIWorkflowResultUpdate) ClearAvgVram() *CIWorkflowResultUpdate { + cwru.mutation.ClearAvgVram() + return cwru +} + +// SetPeakVram sets the "peak_vram" field. +func (cwru *CIWorkflowResultUpdate) SetPeakVram(i int) *CIWorkflowResultUpdate { + cwru.mutation.ResetPeakVram() + cwru.mutation.SetPeakVram(i) + return cwru +} + +// SetNillablePeakVram sets the "peak_vram" field if the given value is not nil. +func (cwru *CIWorkflowResultUpdate) SetNillablePeakVram(i *int) *CIWorkflowResultUpdate { + if i != nil { + cwru.SetPeakVram(*i) + } + return cwru +} + +// AddPeakVram adds i to the "peak_vram" field. +func (cwru *CIWorkflowResultUpdate) AddPeakVram(i int) *CIWorkflowResultUpdate { + cwru.mutation.AddPeakVram(i) + return cwru +} + +// ClearPeakVram clears the value of the "peak_vram" field. +func (cwru *CIWorkflowResultUpdate) ClearPeakVram() *CIWorkflowResultUpdate { + cwru.mutation.ClearPeakVram() + return cwru +} + +// SetJobTriggerUser sets the "job_trigger_user" field. +func (cwru *CIWorkflowResultUpdate) SetJobTriggerUser(s string) *CIWorkflowResultUpdate { + cwru.mutation.SetJobTriggerUser(s) + return cwru +} + +// SetNillableJobTriggerUser sets the "job_trigger_user" field if the given value is not nil. +func (cwru *CIWorkflowResultUpdate) SetNillableJobTriggerUser(s *string) *CIWorkflowResultUpdate { + if s != nil { + cwru.SetJobTriggerUser(*s) + } + return cwru +} + +// ClearJobTriggerUser clears the value of the "job_trigger_user" field. +func (cwru *CIWorkflowResultUpdate) ClearJobTriggerUser() *CIWorkflowResultUpdate { + cwru.mutation.ClearJobTriggerUser() + return cwru +} + +// SetMetadata sets the "metadata" field. +func (cwru *CIWorkflowResultUpdate) SetMetadata(m map[string]interface{}) *CIWorkflowResultUpdate { + cwru.mutation.SetMetadata(m) + return cwru +} + +// ClearMetadata clears the value of the "metadata" field. +func (cwru *CIWorkflowResultUpdate) ClearMetadata() *CIWorkflowResultUpdate { + cwru.mutation.ClearMetadata() + return cwru +} + // SetGitcommitID sets the "gitcommit" edge to the GitCommit entity by ID. func (cwru *CIWorkflowResultUpdate) SetGitcommitID(id uuid.UUID) *CIWorkflowResultUpdate { cwru.mutation.SetGitcommitID(id) @@ -224,23 +366,19 @@ func (cwru *CIWorkflowResultUpdate) SetGitcommit(g *GitCommit) *CIWorkflowResult return cwru.SetGitcommitID(g.ID) } -// SetStorageFileID sets the "storage_file" edge to the StorageFile entity by ID. -func (cwru *CIWorkflowResultUpdate) SetStorageFileID(id uuid.UUID) *CIWorkflowResultUpdate { - cwru.mutation.SetStorageFileID(id) +// AddStorageFileIDs adds the "storage_file" edge to the StorageFile entity by IDs. +func (cwru *CIWorkflowResultUpdate) AddStorageFileIDs(ids ...uuid.UUID) *CIWorkflowResultUpdate { + cwru.mutation.AddStorageFileIDs(ids...) return cwru } -// SetNillableStorageFileID sets the "storage_file" edge to the StorageFile entity by ID if the given value is not nil. -func (cwru *CIWorkflowResultUpdate) SetNillableStorageFileID(id *uuid.UUID) *CIWorkflowResultUpdate { - if id != nil { - cwru = cwru.SetStorageFileID(*id) +// AddStorageFile adds the "storage_file" edges to the StorageFile entity. +func (cwru *CIWorkflowResultUpdate) AddStorageFile(s ...*StorageFile) *CIWorkflowResultUpdate { + ids := make([]uuid.UUID, len(s)) + for i := range s { + ids[i] = s[i].ID } - return cwru -} - -// SetStorageFile sets the "storage_file" edge to the StorageFile entity. -func (cwru *CIWorkflowResultUpdate) SetStorageFile(s *StorageFile) *CIWorkflowResultUpdate { - return cwru.SetStorageFileID(s.ID) + return cwru.AddStorageFileIDs(ids...) } // Mutation returns the CIWorkflowResultMutation object of the builder. @@ -254,12 +392,27 @@ func (cwru *CIWorkflowResultUpdate) ClearGitcommit() *CIWorkflowResultUpdate { return cwru } -// ClearStorageFile clears the "storage_file" edge to the StorageFile entity. +// ClearStorageFile clears all "storage_file" edges to the StorageFile entity. func (cwru *CIWorkflowResultUpdate) ClearStorageFile() *CIWorkflowResultUpdate { cwru.mutation.ClearStorageFile() return cwru } +// RemoveStorageFileIDs removes the "storage_file" edge to StorageFile entities by IDs. +func (cwru *CIWorkflowResultUpdate) RemoveStorageFileIDs(ids ...uuid.UUID) *CIWorkflowResultUpdate { + cwru.mutation.RemoveStorageFileIDs(ids...) + return cwru +} + +// RemoveStorageFile removes "storage_file" edges to StorageFile entities. +func (cwru *CIWorkflowResultUpdate) RemoveStorageFile(s ...*StorageFile) *CIWorkflowResultUpdate { + ids := make([]uuid.UUID, len(s)) + for i := range s { + ids[i] = s[i].ID + } + return cwru.RemoveStorageFileIDs(ids...) +} + // Save executes the query and returns the number of nodes affected by the update operation. func (cwru *CIWorkflowResultUpdate) Save(ctx context.Context) (int, error) { cwru.defaults() @@ -296,6 +449,12 @@ func (cwru *CIWorkflowResultUpdate) defaults() { } } +// Modify adds a statement modifier for attaching custom logic to the UPDATE statement. +func (cwru *CIWorkflowResultUpdate) Modify(modifiers ...func(u *sql.UpdateBuilder)) *CIWorkflowResultUpdate { + cwru.modifiers = append(cwru.modifiers, modifiers...) + return cwru +} + func (cwru *CIWorkflowResultUpdate) sqlSave(ctx context.Context) (n int, err error) { _spec := sqlgraph.NewUpdateSpec(ciworkflowresult.Table, ciworkflowresult.Columns, sqlgraph.NewFieldSpec(ciworkflowresult.FieldID, field.TypeUUID)) if ps := cwru.mutation.predicates; len(ps) > 0 { @@ -311,18 +470,6 @@ func (cwru *CIWorkflowResultUpdate) sqlSave(ctx context.Context) (n int, err err if value, ok := cwru.mutation.OperatingSystem(); ok { _spec.SetField(ciworkflowresult.FieldOperatingSystem, field.TypeString, value) } - if value, ok := cwru.mutation.GpuType(); ok { - _spec.SetField(ciworkflowresult.FieldGpuType, field.TypeString, value) - } - if cwru.mutation.GpuTypeCleared() { - _spec.ClearField(ciworkflowresult.FieldGpuType, field.TypeString) - } - if value, ok := cwru.mutation.PytorchVersion(); ok { - _spec.SetField(ciworkflowresult.FieldPytorchVersion, field.TypeString, value) - } - if cwru.mutation.PytorchVersionCleared() { - _spec.ClearField(ciworkflowresult.FieldPytorchVersion, field.TypeString) - } if value, ok := cwru.mutation.WorkflowName(); ok { _spec.SetField(ciworkflowresult.FieldWorkflowName, field.TypeString, value) } @@ -335,12 +482,15 @@ func (cwru *CIWorkflowResultUpdate) sqlSave(ctx context.Context) (n int, err err if cwru.mutation.RunIDCleared() { _spec.ClearField(ciworkflowresult.FieldRunID, field.TypeString) } + if value, ok := cwru.mutation.JobID(); ok { + _spec.SetField(ciworkflowresult.FieldJobID, field.TypeString, value) + } + if cwru.mutation.JobIDCleared() { + _spec.ClearField(ciworkflowresult.FieldJobID, field.TypeString) + } if value, ok := cwru.mutation.Status(); ok { _spec.SetField(ciworkflowresult.FieldStatus, field.TypeString, value) } - if cwru.mutation.StatusCleared() { - _spec.ClearField(ciworkflowresult.FieldStatus, field.TypeString) - } if value, ok := cwru.mutation.StartTime(); ok { _spec.SetField(ciworkflowresult.FieldStartTime, field.TypeInt64, value) } @@ -359,6 +509,60 @@ func (cwru *CIWorkflowResultUpdate) sqlSave(ctx context.Context) (n int, err err if cwru.mutation.EndTimeCleared() { _spec.ClearField(ciworkflowresult.FieldEndTime, field.TypeInt64) } + if value, ok := cwru.mutation.PythonVersion(); ok { + _spec.SetField(ciworkflowresult.FieldPythonVersion, field.TypeString, value) + } + if cwru.mutation.PythonVersionCleared() { + _spec.ClearField(ciworkflowresult.FieldPythonVersion, field.TypeString) + } + if value, ok := cwru.mutation.PytorchVersion(); ok { + _spec.SetField(ciworkflowresult.FieldPytorchVersion, field.TypeString, value) + } + if cwru.mutation.PytorchVersionCleared() { + _spec.ClearField(ciworkflowresult.FieldPytorchVersion, field.TypeString) + } + if value, ok := cwru.mutation.CudaVersion(); ok { + _spec.SetField(ciworkflowresult.FieldCudaVersion, field.TypeString, value) + } + if cwru.mutation.CudaVersionCleared() { + _spec.ClearField(ciworkflowresult.FieldCudaVersion, field.TypeString) + } + if value, ok := cwru.mutation.ComfyRunFlags(); ok { + _spec.SetField(ciworkflowresult.FieldComfyRunFlags, field.TypeString, value) + } + if cwru.mutation.ComfyRunFlagsCleared() { + _spec.ClearField(ciworkflowresult.FieldComfyRunFlags, field.TypeString) + } + if value, ok := cwru.mutation.AvgVram(); ok { + _spec.SetField(ciworkflowresult.FieldAvgVram, field.TypeInt, value) + } + if value, ok := cwru.mutation.AddedAvgVram(); ok { + _spec.AddField(ciworkflowresult.FieldAvgVram, field.TypeInt, value) + } + if cwru.mutation.AvgVramCleared() { + _spec.ClearField(ciworkflowresult.FieldAvgVram, field.TypeInt) + } + if value, ok := cwru.mutation.PeakVram(); ok { + _spec.SetField(ciworkflowresult.FieldPeakVram, field.TypeInt, value) + } + if value, ok := cwru.mutation.AddedPeakVram(); ok { + _spec.AddField(ciworkflowresult.FieldPeakVram, field.TypeInt, value) + } + if cwru.mutation.PeakVramCleared() { + _spec.ClearField(ciworkflowresult.FieldPeakVram, field.TypeInt) + } + if value, ok := cwru.mutation.JobTriggerUser(); ok { + _spec.SetField(ciworkflowresult.FieldJobTriggerUser, field.TypeString, value) + } + if cwru.mutation.JobTriggerUserCleared() { + _spec.ClearField(ciworkflowresult.FieldJobTriggerUser, field.TypeString) + } + if value, ok := cwru.mutation.Metadata(); ok { + _spec.SetField(ciworkflowresult.FieldMetadata, field.TypeJSON, value) + } + if cwru.mutation.MetadataCleared() { + _spec.ClearField(ciworkflowresult.FieldMetadata, field.TypeJSON) + } if cwru.mutation.GitcommitCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, @@ -390,7 +594,7 @@ func (cwru *CIWorkflowResultUpdate) sqlSave(ctx context.Context) (n int, err err } if cwru.mutation.StorageFileCleared() { edge := &sqlgraph.EdgeSpec{ - Rel: sqlgraph.M2O, + Rel: sqlgraph.O2M, Inverse: false, Table: ciworkflowresult.StorageFileTable, Columns: []string{ciworkflowresult.StorageFileColumn}, @@ -401,9 +605,25 @@ func (cwru *CIWorkflowResultUpdate) sqlSave(ctx context.Context) (n int, err err } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) } + if nodes := cwru.mutation.RemovedStorageFileIDs(); len(nodes) > 0 && !cwru.mutation.StorageFileCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: ciworkflowresult.StorageFileTable, + Columns: []string{ciworkflowresult.StorageFileColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(storagefile.FieldID, field.TypeUUID), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } if nodes := cwru.mutation.StorageFileIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ - Rel: sqlgraph.M2O, + Rel: sqlgraph.O2M, Inverse: false, Table: ciworkflowresult.StorageFileTable, Columns: []string{ciworkflowresult.StorageFileColumn}, @@ -417,6 +637,7 @@ func (cwru *CIWorkflowResultUpdate) sqlSave(ctx context.Context) (n int, err err } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + _spec.AddModifiers(cwru.modifiers...) if n, err = sqlgraph.UpdateNodes(ctx, cwru.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{ciworkflowresult.Label} @@ -432,9 +653,10 @@ func (cwru *CIWorkflowResultUpdate) sqlSave(ctx context.Context) (n int, err err // CIWorkflowResultUpdateOne is the builder for updating a single CIWorkflowResult entity. type CIWorkflowResultUpdateOne struct { config - fields []string - hooks []Hook - mutation *CIWorkflowResultMutation + fields []string + hooks []Hook + mutation *CIWorkflowResultMutation + modifiers []func(*sql.UpdateBuilder) } // SetUpdateTime sets the "update_time" field. @@ -457,46 +679,6 @@ func (cwruo *CIWorkflowResultUpdateOne) SetNillableOperatingSystem(s *string) *C return cwruo } -// SetGpuType sets the "gpu_type" field. -func (cwruo *CIWorkflowResultUpdateOne) SetGpuType(s string) *CIWorkflowResultUpdateOne { - cwruo.mutation.SetGpuType(s) - return cwruo -} - -// SetNillableGpuType sets the "gpu_type" field if the given value is not nil. -func (cwruo *CIWorkflowResultUpdateOne) SetNillableGpuType(s *string) *CIWorkflowResultUpdateOne { - if s != nil { - cwruo.SetGpuType(*s) - } - return cwruo -} - -// ClearGpuType clears the value of the "gpu_type" field. -func (cwruo *CIWorkflowResultUpdateOne) ClearGpuType() *CIWorkflowResultUpdateOne { - cwruo.mutation.ClearGpuType() - return cwruo -} - -// SetPytorchVersion sets the "pytorch_version" field. -func (cwruo *CIWorkflowResultUpdateOne) SetPytorchVersion(s string) *CIWorkflowResultUpdateOne { - cwruo.mutation.SetPytorchVersion(s) - return cwruo -} - -// SetNillablePytorchVersion sets the "pytorch_version" field if the given value is not nil. -func (cwruo *CIWorkflowResultUpdateOne) SetNillablePytorchVersion(s *string) *CIWorkflowResultUpdateOne { - if s != nil { - cwruo.SetPytorchVersion(*s) - } - return cwruo -} - -// ClearPytorchVersion clears the value of the "pytorch_version" field. -func (cwruo *CIWorkflowResultUpdateOne) ClearPytorchVersion() *CIWorkflowResultUpdateOne { - cwruo.mutation.ClearPytorchVersion() - return cwruo -} - // SetWorkflowName sets the "workflow_name" field. func (cwruo *CIWorkflowResultUpdateOne) SetWorkflowName(s string) *CIWorkflowResultUpdateOne { cwruo.mutation.SetWorkflowName(s) @@ -537,23 +719,37 @@ func (cwruo *CIWorkflowResultUpdateOne) ClearRunID() *CIWorkflowResultUpdateOne return cwruo } -// SetStatus sets the "status" field. -func (cwruo *CIWorkflowResultUpdateOne) SetStatus(s string) *CIWorkflowResultUpdateOne { - cwruo.mutation.SetStatus(s) +// SetJobID sets the "job_id" field. +func (cwruo *CIWorkflowResultUpdateOne) SetJobID(s string) *CIWorkflowResultUpdateOne { + cwruo.mutation.SetJobID(s) return cwruo } -// SetNillableStatus sets the "status" field if the given value is not nil. -func (cwruo *CIWorkflowResultUpdateOne) SetNillableStatus(s *string) *CIWorkflowResultUpdateOne { +// SetNillableJobID sets the "job_id" field if the given value is not nil. +func (cwruo *CIWorkflowResultUpdateOne) SetNillableJobID(s *string) *CIWorkflowResultUpdateOne { if s != nil { - cwruo.SetStatus(*s) + cwruo.SetJobID(*s) } return cwruo } -// ClearStatus clears the value of the "status" field. -func (cwruo *CIWorkflowResultUpdateOne) ClearStatus() *CIWorkflowResultUpdateOne { - cwruo.mutation.ClearStatus() +// ClearJobID clears the value of the "job_id" field. +func (cwruo *CIWorkflowResultUpdateOne) ClearJobID() *CIWorkflowResultUpdateOne { + cwruo.mutation.ClearJobID() + return cwruo +} + +// SetStatus sets the "status" field. +func (cwruo *CIWorkflowResultUpdateOne) SetStatus(srst schema.WorkflowRunStatusType) *CIWorkflowResultUpdateOne { + cwruo.mutation.SetStatus(srst) + return cwruo +} + +// SetNillableStatus sets the "status" field if the given value is not nil. +func (cwruo *CIWorkflowResultUpdateOne) SetNillableStatus(srst *schema.WorkflowRunStatusType) *CIWorkflowResultUpdateOne { + if srst != nil { + cwruo.SetStatus(*srst) + } return cwruo } @@ -611,6 +807,172 @@ func (cwruo *CIWorkflowResultUpdateOne) ClearEndTime() *CIWorkflowResultUpdateOn return cwruo } +// SetPythonVersion sets the "python_version" field. +func (cwruo *CIWorkflowResultUpdateOne) SetPythonVersion(s string) *CIWorkflowResultUpdateOne { + cwruo.mutation.SetPythonVersion(s) + return cwruo +} + +// SetNillablePythonVersion sets the "python_version" field if the given value is not nil. +func (cwruo *CIWorkflowResultUpdateOne) SetNillablePythonVersion(s *string) *CIWorkflowResultUpdateOne { + if s != nil { + cwruo.SetPythonVersion(*s) + } + return cwruo +} + +// ClearPythonVersion clears the value of the "python_version" field. +func (cwruo *CIWorkflowResultUpdateOne) ClearPythonVersion() *CIWorkflowResultUpdateOne { + cwruo.mutation.ClearPythonVersion() + return cwruo +} + +// SetPytorchVersion sets the "pytorch_version" field. +func (cwruo *CIWorkflowResultUpdateOne) SetPytorchVersion(s string) *CIWorkflowResultUpdateOne { + cwruo.mutation.SetPytorchVersion(s) + return cwruo +} + +// SetNillablePytorchVersion sets the "pytorch_version" field if the given value is not nil. +func (cwruo *CIWorkflowResultUpdateOne) SetNillablePytorchVersion(s *string) *CIWorkflowResultUpdateOne { + if s != nil { + cwruo.SetPytorchVersion(*s) + } + return cwruo +} + +// ClearPytorchVersion clears the value of the "pytorch_version" field. +func (cwruo *CIWorkflowResultUpdateOne) ClearPytorchVersion() *CIWorkflowResultUpdateOne { + cwruo.mutation.ClearPytorchVersion() + return cwruo +} + +// SetCudaVersion sets the "cuda_version" field. +func (cwruo *CIWorkflowResultUpdateOne) SetCudaVersion(s string) *CIWorkflowResultUpdateOne { + cwruo.mutation.SetCudaVersion(s) + return cwruo +} + +// SetNillableCudaVersion sets the "cuda_version" field if the given value is not nil. +func (cwruo *CIWorkflowResultUpdateOne) SetNillableCudaVersion(s *string) *CIWorkflowResultUpdateOne { + if s != nil { + cwruo.SetCudaVersion(*s) + } + return cwruo +} + +// ClearCudaVersion clears the value of the "cuda_version" field. +func (cwruo *CIWorkflowResultUpdateOne) ClearCudaVersion() *CIWorkflowResultUpdateOne { + cwruo.mutation.ClearCudaVersion() + return cwruo +} + +// SetComfyRunFlags sets the "comfy_run_flags" field. +func (cwruo *CIWorkflowResultUpdateOne) SetComfyRunFlags(s string) *CIWorkflowResultUpdateOne { + cwruo.mutation.SetComfyRunFlags(s) + return cwruo +} + +// SetNillableComfyRunFlags sets the "comfy_run_flags" field if the given value is not nil. +func (cwruo *CIWorkflowResultUpdateOne) SetNillableComfyRunFlags(s *string) *CIWorkflowResultUpdateOne { + if s != nil { + cwruo.SetComfyRunFlags(*s) + } + return cwruo +} + +// ClearComfyRunFlags clears the value of the "comfy_run_flags" field. +func (cwruo *CIWorkflowResultUpdateOne) ClearComfyRunFlags() *CIWorkflowResultUpdateOne { + cwruo.mutation.ClearComfyRunFlags() + return cwruo +} + +// SetAvgVram sets the "avg_vram" field. +func (cwruo *CIWorkflowResultUpdateOne) SetAvgVram(i int) *CIWorkflowResultUpdateOne { + cwruo.mutation.ResetAvgVram() + cwruo.mutation.SetAvgVram(i) + return cwruo +} + +// SetNillableAvgVram sets the "avg_vram" field if the given value is not nil. +func (cwruo *CIWorkflowResultUpdateOne) SetNillableAvgVram(i *int) *CIWorkflowResultUpdateOne { + if i != nil { + cwruo.SetAvgVram(*i) + } + return cwruo +} + +// AddAvgVram adds i to the "avg_vram" field. +func (cwruo *CIWorkflowResultUpdateOne) AddAvgVram(i int) *CIWorkflowResultUpdateOne { + cwruo.mutation.AddAvgVram(i) + return cwruo +} + +// ClearAvgVram clears the value of the "avg_vram" field. +func (cwruo *CIWorkflowResultUpdateOne) ClearAvgVram() *CIWorkflowResultUpdateOne { + cwruo.mutation.ClearAvgVram() + return cwruo +} + +// SetPeakVram sets the "peak_vram" field. +func (cwruo *CIWorkflowResultUpdateOne) SetPeakVram(i int) *CIWorkflowResultUpdateOne { + cwruo.mutation.ResetPeakVram() + cwruo.mutation.SetPeakVram(i) + return cwruo +} + +// SetNillablePeakVram sets the "peak_vram" field if the given value is not nil. +func (cwruo *CIWorkflowResultUpdateOne) SetNillablePeakVram(i *int) *CIWorkflowResultUpdateOne { + if i != nil { + cwruo.SetPeakVram(*i) + } + return cwruo +} + +// AddPeakVram adds i to the "peak_vram" field. +func (cwruo *CIWorkflowResultUpdateOne) AddPeakVram(i int) *CIWorkflowResultUpdateOne { + cwruo.mutation.AddPeakVram(i) + return cwruo +} + +// ClearPeakVram clears the value of the "peak_vram" field. +func (cwruo *CIWorkflowResultUpdateOne) ClearPeakVram() *CIWorkflowResultUpdateOne { + cwruo.mutation.ClearPeakVram() + return cwruo +} + +// SetJobTriggerUser sets the "job_trigger_user" field. +func (cwruo *CIWorkflowResultUpdateOne) SetJobTriggerUser(s string) *CIWorkflowResultUpdateOne { + cwruo.mutation.SetJobTriggerUser(s) + return cwruo +} + +// SetNillableJobTriggerUser sets the "job_trigger_user" field if the given value is not nil. +func (cwruo *CIWorkflowResultUpdateOne) SetNillableJobTriggerUser(s *string) *CIWorkflowResultUpdateOne { + if s != nil { + cwruo.SetJobTriggerUser(*s) + } + return cwruo +} + +// ClearJobTriggerUser clears the value of the "job_trigger_user" field. +func (cwruo *CIWorkflowResultUpdateOne) ClearJobTriggerUser() *CIWorkflowResultUpdateOne { + cwruo.mutation.ClearJobTriggerUser() + return cwruo +} + +// SetMetadata sets the "metadata" field. +func (cwruo *CIWorkflowResultUpdateOne) SetMetadata(m map[string]interface{}) *CIWorkflowResultUpdateOne { + cwruo.mutation.SetMetadata(m) + return cwruo +} + +// ClearMetadata clears the value of the "metadata" field. +func (cwruo *CIWorkflowResultUpdateOne) ClearMetadata() *CIWorkflowResultUpdateOne { + cwruo.mutation.ClearMetadata() + return cwruo +} + // SetGitcommitID sets the "gitcommit" edge to the GitCommit entity by ID. func (cwruo *CIWorkflowResultUpdateOne) SetGitcommitID(id uuid.UUID) *CIWorkflowResultUpdateOne { cwruo.mutation.SetGitcommitID(id) @@ -630,23 +992,19 @@ func (cwruo *CIWorkflowResultUpdateOne) SetGitcommit(g *GitCommit) *CIWorkflowRe return cwruo.SetGitcommitID(g.ID) } -// SetStorageFileID sets the "storage_file" edge to the StorageFile entity by ID. -func (cwruo *CIWorkflowResultUpdateOne) SetStorageFileID(id uuid.UUID) *CIWorkflowResultUpdateOne { - cwruo.mutation.SetStorageFileID(id) +// AddStorageFileIDs adds the "storage_file" edge to the StorageFile entity by IDs. +func (cwruo *CIWorkflowResultUpdateOne) AddStorageFileIDs(ids ...uuid.UUID) *CIWorkflowResultUpdateOne { + cwruo.mutation.AddStorageFileIDs(ids...) return cwruo } -// SetNillableStorageFileID sets the "storage_file" edge to the StorageFile entity by ID if the given value is not nil. -func (cwruo *CIWorkflowResultUpdateOne) SetNillableStorageFileID(id *uuid.UUID) *CIWorkflowResultUpdateOne { - if id != nil { - cwruo = cwruo.SetStorageFileID(*id) +// AddStorageFile adds the "storage_file" edges to the StorageFile entity. +func (cwruo *CIWorkflowResultUpdateOne) AddStorageFile(s ...*StorageFile) *CIWorkflowResultUpdateOne { + ids := make([]uuid.UUID, len(s)) + for i := range s { + ids[i] = s[i].ID } - return cwruo -} - -// SetStorageFile sets the "storage_file" edge to the StorageFile entity. -func (cwruo *CIWorkflowResultUpdateOne) SetStorageFile(s *StorageFile) *CIWorkflowResultUpdateOne { - return cwruo.SetStorageFileID(s.ID) + return cwruo.AddStorageFileIDs(ids...) } // Mutation returns the CIWorkflowResultMutation object of the builder. @@ -660,12 +1018,27 @@ func (cwruo *CIWorkflowResultUpdateOne) ClearGitcommit() *CIWorkflowResultUpdate return cwruo } -// ClearStorageFile clears the "storage_file" edge to the StorageFile entity. +// ClearStorageFile clears all "storage_file" edges to the StorageFile entity. func (cwruo *CIWorkflowResultUpdateOne) ClearStorageFile() *CIWorkflowResultUpdateOne { cwruo.mutation.ClearStorageFile() return cwruo } +// RemoveStorageFileIDs removes the "storage_file" edge to StorageFile entities by IDs. +func (cwruo *CIWorkflowResultUpdateOne) RemoveStorageFileIDs(ids ...uuid.UUID) *CIWorkflowResultUpdateOne { + cwruo.mutation.RemoveStorageFileIDs(ids...) + return cwruo +} + +// RemoveStorageFile removes "storage_file" edges to StorageFile entities. +func (cwruo *CIWorkflowResultUpdateOne) RemoveStorageFile(s ...*StorageFile) *CIWorkflowResultUpdateOne { + ids := make([]uuid.UUID, len(s)) + for i := range s { + ids[i] = s[i].ID + } + return cwruo.RemoveStorageFileIDs(ids...) +} + // Where appends a list predicates to the CIWorkflowResultUpdate builder. func (cwruo *CIWorkflowResultUpdateOne) Where(ps ...predicate.CIWorkflowResult) *CIWorkflowResultUpdateOne { cwruo.mutation.Where(ps...) @@ -715,6 +1088,12 @@ func (cwruo *CIWorkflowResultUpdateOne) defaults() { } } +// Modify adds a statement modifier for attaching custom logic to the UPDATE statement. +func (cwruo *CIWorkflowResultUpdateOne) Modify(modifiers ...func(u *sql.UpdateBuilder)) *CIWorkflowResultUpdateOne { + cwruo.modifiers = append(cwruo.modifiers, modifiers...) + return cwruo +} + func (cwruo *CIWorkflowResultUpdateOne) sqlSave(ctx context.Context) (_node *CIWorkflowResult, err error) { _spec := sqlgraph.NewUpdateSpec(ciworkflowresult.Table, ciworkflowresult.Columns, sqlgraph.NewFieldSpec(ciworkflowresult.FieldID, field.TypeUUID)) id, ok := cwruo.mutation.ID() @@ -747,18 +1126,6 @@ func (cwruo *CIWorkflowResultUpdateOne) sqlSave(ctx context.Context) (_node *CIW if value, ok := cwruo.mutation.OperatingSystem(); ok { _spec.SetField(ciworkflowresult.FieldOperatingSystem, field.TypeString, value) } - if value, ok := cwruo.mutation.GpuType(); ok { - _spec.SetField(ciworkflowresult.FieldGpuType, field.TypeString, value) - } - if cwruo.mutation.GpuTypeCleared() { - _spec.ClearField(ciworkflowresult.FieldGpuType, field.TypeString) - } - if value, ok := cwruo.mutation.PytorchVersion(); ok { - _spec.SetField(ciworkflowresult.FieldPytorchVersion, field.TypeString, value) - } - if cwruo.mutation.PytorchVersionCleared() { - _spec.ClearField(ciworkflowresult.FieldPytorchVersion, field.TypeString) - } if value, ok := cwruo.mutation.WorkflowName(); ok { _spec.SetField(ciworkflowresult.FieldWorkflowName, field.TypeString, value) } @@ -771,12 +1138,15 @@ func (cwruo *CIWorkflowResultUpdateOne) sqlSave(ctx context.Context) (_node *CIW if cwruo.mutation.RunIDCleared() { _spec.ClearField(ciworkflowresult.FieldRunID, field.TypeString) } + if value, ok := cwruo.mutation.JobID(); ok { + _spec.SetField(ciworkflowresult.FieldJobID, field.TypeString, value) + } + if cwruo.mutation.JobIDCleared() { + _spec.ClearField(ciworkflowresult.FieldJobID, field.TypeString) + } if value, ok := cwruo.mutation.Status(); ok { _spec.SetField(ciworkflowresult.FieldStatus, field.TypeString, value) } - if cwruo.mutation.StatusCleared() { - _spec.ClearField(ciworkflowresult.FieldStatus, field.TypeString) - } if value, ok := cwruo.mutation.StartTime(); ok { _spec.SetField(ciworkflowresult.FieldStartTime, field.TypeInt64, value) } @@ -795,6 +1165,60 @@ func (cwruo *CIWorkflowResultUpdateOne) sqlSave(ctx context.Context) (_node *CIW if cwruo.mutation.EndTimeCleared() { _spec.ClearField(ciworkflowresult.FieldEndTime, field.TypeInt64) } + if value, ok := cwruo.mutation.PythonVersion(); ok { + _spec.SetField(ciworkflowresult.FieldPythonVersion, field.TypeString, value) + } + if cwruo.mutation.PythonVersionCleared() { + _spec.ClearField(ciworkflowresult.FieldPythonVersion, field.TypeString) + } + if value, ok := cwruo.mutation.PytorchVersion(); ok { + _spec.SetField(ciworkflowresult.FieldPytorchVersion, field.TypeString, value) + } + if cwruo.mutation.PytorchVersionCleared() { + _spec.ClearField(ciworkflowresult.FieldPytorchVersion, field.TypeString) + } + if value, ok := cwruo.mutation.CudaVersion(); ok { + _spec.SetField(ciworkflowresult.FieldCudaVersion, field.TypeString, value) + } + if cwruo.mutation.CudaVersionCleared() { + _spec.ClearField(ciworkflowresult.FieldCudaVersion, field.TypeString) + } + if value, ok := cwruo.mutation.ComfyRunFlags(); ok { + _spec.SetField(ciworkflowresult.FieldComfyRunFlags, field.TypeString, value) + } + if cwruo.mutation.ComfyRunFlagsCleared() { + _spec.ClearField(ciworkflowresult.FieldComfyRunFlags, field.TypeString) + } + if value, ok := cwruo.mutation.AvgVram(); ok { + _spec.SetField(ciworkflowresult.FieldAvgVram, field.TypeInt, value) + } + if value, ok := cwruo.mutation.AddedAvgVram(); ok { + _spec.AddField(ciworkflowresult.FieldAvgVram, field.TypeInt, value) + } + if cwruo.mutation.AvgVramCleared() { + _spec.ClearField(ciworkflowresult.FieldAvgVram, field.TypeInt) + } + if value, ok := cwruo.mutation.PeakVram(); ok { + _spec.SetField(ciworkflowresult.FieldPeakVram, field.TypeInt, value) + } + if value, ok := cwruo.mutation.AddedPeakVram(); ok { + _spec.AddField(ciworkflowresult.FieldPeakVram, field.TypeInt, value) + } + if cwruo.mutation.PeakVramCleared() { + _spec.ClearField(ciworkflowresult.FieldPeakVram, field.TypeInt) + } + if value, ok := cwruo.mutation.JobTriggerUser(); ok { + _spec.SetField(ciworkflowresult.FieldJobTriggerUser, field.TypeString, value) + } + if cwruo.mutation.JobTriggerUserCleared() { + _spec.ClearField(ciworkflowresult.FieldJobTriggerUser, field.TypeString) + } + if value, ok := cwruo.mutation.Metadata(); ok { + _spec.SetField(ciworkflowresult.FieldMetadata, field.TypeJSON, value) + } + if cwruo.mutation.MetadataCleared() { + _spec.ClearField(ciworkflowresult.FieldMetadata, field.TypeJSON) + } if cwruo.mutation.GitcommitCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, @@ -826,7 +1250,7 @@ func (cwruo *CIWorkflowResultUpdateOne) sqlSave(ctx context.Context) (_node *CIW } if cwruo.mutation.StorageFileCleared() { edge := &sqlgraph.EdgeSpec{ - Rel: sqlgraph.M2O, + Rel: sqlgraph.O2M, Inverse: false, Table: ciworkflowresult.StorageFileTable, Columns: []string{ciworkflowresult.StorageFileColumn}, @@ -837,9 +1261,25 @@ func (cwruo *CIWorkflowResultUpdateOne) sqlSave(ctx context.Context) (_node *CIW } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) } + if nodes := cwruo.mutation.RemovedStorageFileIDs(); len(nodes) > 0 && !cwruo.mutation.StorageFileCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: ciworkflowresult.StorageFileTable, + Columns: []string{ciworkflowresult.StorageFileColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(storagefile.FieldID, field.TypeUUID), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } if nodes := cwruo.mutation.StorageFileIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ - Rel: sqlgraph.M2O, + Rel: sqlgraph.O2M, Inverse: false, Table: ciworkflowresult.StorageFileTable, Columns: []string{ciworkflowresult.StorageFileColumn}, @@ -853,6 +1293,7 @@ func (cwruo *CIWorkflowResultUpdateOne) sqlSave(ctx context.Context) (_node *CIW } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + _spec.AddModifiers(cwruo.modifiers...) _node = &CIWorkflowResult{config: cwruo.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/ent/client.go b/ent/client.go index 462edde..04dbce9 100644 --- a/ent/client.go +++ b/ent/client.go @@ -14,6 +14,7 @@ import ( "registry-backend/ent/ciworkflowresult" "registry-backend/ent/gitcommit" "registry-backend/ent/node" + "registry-backend/ent/nodereview" "registry-backend/ent/nodeversion" "registry-backend/ent/personalaccesstoken" "registry-backend/ent/publisher" @@ -39,6 +40,8 @@ type Client struct { GitCommit *GitCommitClient // Node is the client for interacting with the Node builders. Node *NodeClient + // NodeReview is the client for interacting with the NodeReview builders. + NodeReview *NodeReviewClient // NodeVersion is the client for interacting with the NodeVersion builders. NodeVersion *NodeVersionClient // PersonalAccessToken is the client for interacting with the PersonalAccessToken builders. @@ -65,6 +68,7 @@ func (c *Client) init() { c.CIWorkflowResult = NewCIWorkflowResultClient(c.config) c.GitCommit = NewGitCommitClient(c.config) c.Node = NewNodeClient(c.config) + c.NodeReview = NewNodeReviewClient(c.config) c.NodeVersion = NewNodeVersionClient(c.config) c.PersonalAccessToken = NewPersonalAccessTokenClient(c.config) c.Publisher = NewPublisherClient(c.config) @@ -166,6 +170,7 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) { CIWorkflowResult: NewCIWorkflowResultClient(cfg), GitCommit: NewGitCommitClient(cfg), Node: NewNodeClient(cfg), + NodeReview: NewNodeReviewClient(cfg), NodeVersion: NewNodeVersionClient(cfg), PersonalAccessToken: NewPersonalAccessTokenClient(cfg), Publisher: NewPublisherClient(cfg), @@ -194,6 +199,7 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) CIWorkflowResult: NewCIWorkflowResultClient(cfg), GitCommit: NewGitCommitClient(cfg), Node: NewNodeClient(cfg), + NodeReview: NewNodeReviewClient(cfg), NodeVersion: NewNodeVersionClient(cfg), PersonalAccessToken: NewPersonalAccessTokenClient(cfg), Publisher: NewPublisherClient(cfg), @@ -229,8 +235,9 @@ func (c *Client) Close() error { // In order to add hooks to a specific client, call: `client.Node.Use(...)`. func (c *Client) Use(hooks ...Hook) { for _, n := range []interface{ Use(...Hook) }{ - c.CIWorkflowResult, c.GitCommit, c.Node, c.NodeVersion, c.PersonalAccessToken, - c.Publisher, c.PublisherPermission, c.StorageFile, c.User, + c.CIWorkflowResult, c.GitCommit, c.Node, c.NodeReview, c.NodeVersion, + c.PersonalAccessToken, c.Publisher, c.PublisherPermission, c.StorageFile, + c.User, } { n.Use(hooks...) } @@ -240,8 +247,9 @@ func (c *Client) Use(hooks ...Hook) { // In order to add interceptors to a specific client, call: `client.Node.Intercept(...)`. func (c *Client) Intercept(interceptors ...Interceptor) { for _, n := range []interface{ Intercept(...Interceptor) }{ - c.CIWorkflowResult, c.GitCommit, c.Node, c.NodeVersion, c.PersonalAccessToken, - c.Publisher, c.PublisherPermission, c.StorageFile, c.User, + c.CIWorkflowResult, c.GitCommit, c.Node, c.NodeReview, c.NodeVersion, + c.PersonalAccessToken, c.Publisher, c.PublisherPermission, c.StorageFile, + c.User, } { n.Intercept(interceptors...) } @@ -256,6 +264,8 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) { return c.GitCommit.mutate(ctx, m) case *NodeMutation: return c.Node.mutate(ctx, m) + case *NodeReviewMutation: + return c.NodeReview.mutate(ctx, m) case *NodeVersionMutation: return c.NodeVersion.mutate(ctx, m) case *PersonalAccessTokenMutation: @@ -405,7 +415,7 @@ func (c *CIWorkflowResultClient) QueryStorageFile(cwr *CIWorkflowResult) *Storag step := sqlgraph.NewStep( sqlgraph.From(ciworkflowresult.Table, ciworkflowresult.FieldID, id), sqlgraph.To(storagefile.Table, storagefile.FieldID), - sqlgraph.Edge(sqlgraph.M2O, false, ciworkflowresult.StorageFileTable, ciworkflowresult.StorageFileColumn), + sqlgraph.Edge(sqlgraph.O2M, false, ciworkflowresult.StorageFileTable, ciworkflowresult.StorageFileColumn), ) fromV = sqlgraph.Neighbors(cwr.driver.Dialect(), step) return fromV, nil @@ -727,6 +737,22 @@ func (c *NodeClient) QueryVersions(n *Node) *NodeVersionQuery { return query } +// QueryReviews queries the reviews edge of a Node. +func (c *NodeClient) QueryReviews(n *Node) *NodeReviewQuery { + query := (&NodeReviewClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := n.ID + step := sqlgraph.NewStep( + sqlgraph.From(node.Table, node.FieldID, id), + sqlgraph.To(nodereview.Table, nodereview.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, node.ReviewsTable, node.ReviewsColumn), + ) + fromV = sqlgraph.Neighbors(n.driver.Dialect(), step) + return fromV, nil + } + return query +} + // Hooks returns the client hooks. func (c *NodeClient) Hooks() []Hook { return c.hooks.Node @@ -752,6 +778,171 @@ func (c *NodeClient) mutate(ctx context.Context, m *NodeMutation) (Value, error) } } +// NodeReviewClient is a client for the NodeReview schema. +type NodeReviewClient struct { + config +} + +// NewNodeReviewClient returns a client for the NodeReview from the given config. +func NewNodeReviewClient(c config) *NodeReviewClient { + return &NodeReviewClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `nodereview.Hooks(f(g(h())))`. +func (c *NodeReviewClient) Use(hooks ...Hook) { + c.hooks.NodeReview = append(c.hooks.NodeReview, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `nodereview.Intercept(f(g(h())))`. +func (c *NodeReviewClient) Intercept(interceptors ...Interceptor) { + c.inters.NodeReview = append(c.inters.NodeReview, interceptors...) +} + +// Create returns a builder for creating a NodeReview entity. +func (c *NodeReviewClient) Create() *NodeReviewCreate { + mutation := newNodeReviewMutation(c.config, OpCreate) + return &NodeReviewCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of NodeReview entities. +func (c *NodeReviewClient) CreateBulk(builders ...*NodeReviewCreate) *NodeReviewCreateBulk { + return &NodeReviewCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *NodeReviewClient) MapCreateBulk(slice any, setFunc func(*NodeReviewCreate, int)) *NodeReviewCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &NodeReviewCreateBulk{err: fmt.Errorf("calling to NodeReviewClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*NodeReviewCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &NodeReviewCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for NodeReview. +func (c *NodeReviewClient) Update() *NodeReviewUpdate { + mutation := newNodeReviewMutation(c.config, OpUpdate) + return &NodeReviewUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *NodeReviewClient) UpdateOne(nr *NodeReview) *NodeReviewUpdateOne { + mutation := newNodeReviewMutation(c.config, OpUpdateOne, withNodeReview(nr)) + return &NodeReviewUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *NodeReviewClient) UpdateOneID(id uuid.UUID) *NodeReviewUpdateOne { + mutation := newNodeReviewMutation(c.config, OpUpdateOne, withNodeReviewID(id)) + return &NodeReviewUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for NodeReview. +func (c *NodeReviewClient) Delete() *NodeReviewDelete { + mutation := newNodeReviewMutation(c.config, OpDelete) + return &NodeReviewDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *NodeReviewClient) DeleteOne(nr *NodeReview) *NodeReviewDeleteOne { + return c.DeleteOneID(nr.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *NodeReviewClient) DeleteOneID(id uuid.UUID) *NodeReviewDeleteOne { + builder := c.Delete().Where(nodereview.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &NodeReviewDeleteOne{builder} +} + +// Query returns a query builder for NodeReview. +func (c *NodeReviewClient) Query() *NodeReviewQuery { + return &NodeReviewQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeNodeReview}, + inters: c.Interceptors(), + } +} + +// Get returns a NodeReview entity by its id. +func (c *NodeReviewClient) Get(ctx context.Context, id uuid.UUID) (*NodeReview, error) { + return c.Query().Where(nodereview.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *NodeReviewClient) GetX(ctx context.Context, id uuid.UUID) *NodeReview { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// QueryUser queries the user edge of a NodeReview. +func (c *NodeReviewClient) QueryUser(nr *NodeReview) *UserQuery { + query := (&UserClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := nr.ID + step := sqlgraph.NewStep( + sqlgraph.From(nodereview.Table, nodereview.FieldID, id), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, nodereview.UserTable, nodereview.UserColumn), + ) + fromV = sqlgraph.Neighbors(nr.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// QueryNode queries the node edge of a NodeReview. +func (c *NodeReviewClient) QueryNode(nr *NodeReview) *NodeQuery { + query := (&NodeClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := nr.ID + step := sqlgraph.NewStep( + sqlgraph.From(nodereview.Table, nodereview.FieldID, id), + sqlgraph.To(node.Table, node.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, nodereview.NodeTable, nodereview.NodeColumn), + ) + fromV = sqlgraph.Neighbors(nr.driver.Dialect(), step) + return fromV, nil + } + return query +} + +// Hooks returns the client hooks. +func (c *NodeReviewClient) Hooks() []Hook { + return c.hooks.NodeReview +} + +// Interceptors returns the client interceptors. +func (c *NodeReviewClient) Interceptors() []Interceptor { + return c.inters.NodeReview +} + +func (c *NodeReviewClient) mutate(ctx context.Context, m *NodeReviewMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&NodeReviewCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&NodeReviewUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&NodeReviewUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&NodeReviewDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown NodeReview mutation op: %q", m.Op()) + } +} + // NodeVersionClient is a client for the NodeVersion schema. type NodeVersionClient struct { config @@ -1669,6 +1860,22 @@ func (c *UserClient) QueryPublisherPermissions(u *User) *PublisherPermissionQuer return query } +// QueryReviews queries the reviews edge of a User. +func (c *UserClient) QueryReviews(u *User) *NodeReviewQuery { + query := (&NodeReviewClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { + id := u.ID + step := sqlgraph.NewStep( + sqlgraph.From(user.Table, user.FieldID, id), + sqlgraph.To(nodereview.Table, nodereview.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, user.ReviewsTable, user.ReviewsColumn), + ) + fromV = sqlgraph.Neighbors(u.driver.Dialect(), step) + return fromV, nil + } + return query +} + // Hooks returns the client hooks. func (c *UserClient) Hooks() []Hook { return c.hooks.User @@ -1697,11 +1904,11 @@ func (c *UserClient) mutate(ctx context.Context, m *UserMutation) (Value, error) // hooks and interceptors per client, for fast access. type ( hooks struct { - CIWorkflowResult, GitCommit, Node, NodeVersion, PersonalAccessToken, Publisher, - PublisherPermission, StorageFile, User []ent.Hook + CIWorkflowResult, GitCommit, Node, NodeReview, NodeVersion, PersonalAccessToken, + Publisher, PublisherPermission, StorageFile, User []ent.Hook } inters struct { - CIWorkflowResult, GitCommit, Node, NodeVersion, PersonalAccessToken, Publisher, - PublisherPermission, StorageFile, User []ent.Interceptor + CIWorkflowResult, GitCommit, Node, NodeReview, NodeVersion, PersonalAccessToken, + Publisher, PublisherPermission, StorageFile, User []ent.Interceptor } ) diff --git a/ent/ent.go b/ent/ent.go index 52ab9e1..aec797c 100644 --- a/ent/ent.go +++ b/ent/ent.go @@ -10,6 +10,7 @@ import ( "registry-backend/ent/ciworkflowresult" "registry-backend/ent/gitcommit" "registry-backend/ent/node" + "registry-backend/ent/nodereview" "registry-backend/ent/nodeversion" "registry-backend/ent/personalaccesstoken" "registry-backend/ent/publisher" @@ -84,6 +85,7 @@ func checkColumn(table, column string) error { ciworkflowresult.Table: ciworkflowresult.ValidColumn, gitcommit.Table: gitcommit.ValidColumn, node.Table: node.ValidColumn, + nodereview.Table: nodereview.ValidColumn, nodeversion.Table: nodeversion.ValidColumn, personalaccesstoken.Table: personalaccesstoken.ValidColumn, publisher.Table: publisher.ValidColumn, diff --git a/ent/generate.go b/ent/generate.go index 10036c5..2bfde0c 100644 --- a/ent/generate.go +++ b/ent/generate.go @@ -1,3 +1,3 @@ package ent -//go:generate go run -mod=mod entgo.io/ent/cmd/ent generate --feature sql/upsert --feature sql/lock ./schema +//go:generate go run -mod=mod entgo.io/ent/cmd/ent generate --feature sql/upsert --feature sql/lock --feature sql/modifier ./schema diff --git a/ent/gitcommit.go b/ent/gitcommit.go index 1f5480e..e6bb474 100644 --- a/ent/gitcommit.go +++ b/ent/gitcommit.go @@ -36,6 +36,8 @@ type GitCommit struct { Author string `json:"author,omitempty"` // Timestamp holds the value of the "timestamp" field. Timestamp time.Time `json:"timestamp,omitempty"` + // PrNumber holds the value of the "pr_number" field. + PrNumber string `json:"pr_number,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the GitCommitQuery when eager-loading is set. Edges GitCommitEdges `json:"edges"` @@ -65,7 +67,7 @@ func (*GitCommit) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { - case gitcommit.FieldCommitHash, gitcommit.FieldBranchName, gitcommit.FieldRepoName, gitcommit.FieldCommitMessage, gitcommit.FieldAuthor: + case gitcommit.FieldCommitHash, gitcommit.FieldBranchName, gitcommit.FieldRepoName, gitcommit.FieldCommitMessage, gitcommit.FieldAuthor, gitcommit.FieldPrNumber: values[i] = new(sql.NullString) case gitcommit.FieldCreateTime, gitcommit.FieldUpdateTime, gitcommit.FieldCommitTimestamp, gitcommit.FieldTimestamp: values[i] = new(sql.NullTime) @@ -146,6 +148,12 @@ func (gc *GitCommit) assignValues(columns []string, values []any) error { } else if value.Valid { gc.Timestamp = value.Time } + case gitcommit.FieldPrNumber: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field pr_number", values[i]) + } else if value.Valid { + gc.PrNumber = value.String + } default: gc.selectValues.Set(columns[i], values[i]) } @@ -213,6 +221,9 @@ func (gc *GitCommit) String() string { builder.WriteString(", ") builder.WriteString("timestamp=") builder.WriteString(gc.Timestamp.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("pr_number=") + builder.WriteString(gc.PrNumber) builder.WriteByte(')') return builder.String() } diff --git a/ent/gitcommit/gitcommit.go b/ent/gitcommit/gitcommit.go index df2dbbc..c9c4e37 100644 --- a/ent/gitcommit/gitcommit.go +++ b/ent/gitcommit/gitcommit.go @@ -33,6 +33,8 @@ const ( FieldAuthor = "author" // FieldTimestamp holds the string denoting the timestamp field in the database. FieldTimestamp = "timestamp" + // FieldPrNumber holds the string denoting the pr_number field in the database. + FieldPrNumber = "pr_number" // EdgeResults holds the string denoting the results edge name in mutations. EdgeResults = "results" // Table holds the table name of the gitcommit in the database. @@ -58,6 +60,7 @@ var Columns = []string{ FieldCommitTimestamp, FieldAuthor, FieldTimestamp, + FieldPrNumber, } // ValidColumn reports if the column name is valid (part of the table columns). @@ -134,6 +137,11 @@ func ByTimestamp(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldTimestamp, opts...).ToFunc() } +// ByPrNumber orders the results by the pr_number field. +func ByPrNumber(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPrNumber, opts...).ToFunc() +} + // ByResultsCount orders the results by results count. func ByResultsCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { diff --git a/ent/gitcommit/where.go b/ent/gitcommit/where.go index 494acf8..2b6f852 100644 --- a/ent/gitcommit/where.go +++ b/ent/gitcommit/where.go @@ -101,6 +101,11 @@ func Timestamp(v time.Time) predicate.GitCommit { return predicate.GitCommit(sql.FieldEQ(FieldTimestamp, v)) } +// PrNumber applies equality check predicate on the "pr_number" field. It's identical to PrNumberEQ. +func PrNumber(v string) predicate.GitCommit { + return predicate.GitCommit(sql.FieldEQ(FieldPrNumber, v)) +} + // CreateTimeEQ applies the EQ predicate on the "create_time" field. func CreateTimeEQ(v time.Time) predicate.GitCommit { return predicate.GitCommit(sql.FieldEQ(FieldCreateTime, v)) @@ -606,6 +611,81 @@ func TimestampNotNil() predicate.GitCommit { return predicate.GitCommit(sql.FieldNotNull(FieldTimestamp)) } +// PrNumberEQ applies the EQ predicate on the "pr_number" field. +func PrNumberEQ(v string) predicate.GitCommit { + return predicate.GitCommit(sql.FieldEQ(FieldPrNumber, v)) +} + +// PrNumberNEQ applies the NEQ predicate on the "pr_number" field. +func PrNumberNEQ(v string) predicate.GitCommit { + return predicate.GitCommit(sql.FieldNEQ(FieldPrNumber, v)) +} + +// PrNumberIn applies the In predicate on the "pr_number" field. +func PrNumberIn(vs ...string) predicate.GitCommit { + return predicate.GitCommit(sql.FieldIn(FieldPrNumber, vs...)) +} + +// PrNumberNotIn applies the NotIn predicate on the "pr_number" field. +func PrNumberNotIn(vs ...string) predicate.GitCommit { + return predicate.GitCommit(sql.FieldNotIn(FieldPrNumber, vs...)) +} + +// PrNumberGT applies the GT predicate on the "pr_number" field. +func PrNumberGT(v string) predicate.GitCommit { + return predicate.GitCommit(sql.FieldGT(FieldPrNumber, v)) +} + +// PrNumberGTE applies the GTE predicate on the "pr_number" field. +func PrNumberGTE(v string) predicate.GitCommit { + return predicate.GitCommit(sql.FieldGTE(FieldPrNumber, v)) +} + +// PrNumberLT applies the LT predicate on the "pr_number" field. +func PrNumberLT(v string) predicate.GitCommit { + return predicate.GitCommit(sql.FieldLT(FieldPrNumber, v)) +} + +// PrNumberLTE applies the LTE predicate on the "pr_number" field. +func PrNumberLTE(v string) predicate.GitCommit { + return predicate.GitCommit(sql.FieldLTE(FieldPrNumber, v)) +} + +// PrNumberContains applies the Contains predicate on the "pr_number" field. +func PrNumberContains(v string) predicate.GitCommit { + return predicate.GitCommit(sql.FieldContains(FieldPrNumber, v)) +} + +// PrNumberHasPrefix applies the HasPrefix predicate on the "pr_number" field. +func PrNumberHasPrefix(v string) predicate.GitCommit { + return predicate.GitCommit(sql.FieldHasPrefix(FieldPrNumber, v)) +} + +// PrNumberHasSuffix applies the HasSuffix predicate on the "pr_number" field. +func PrNumberHasSuffix(v string) predicate.GitCommit { + return predicate.GitCommit(sql.FieldHasSuffix(FieldPrNumber, v)) +} + +// PrNumberIsNil applies the IsNil predicate on the "pr_number" field. +func PrNumberIsNil() predicate.GitCommit { + return predicate.GitCommit(sql.FieldIsNull(FieldPrNumber)) +} + +// PrNumberNotNil applies the NotNil predicate on the "pr_number" field. +func PrNumberNotNil() predicate.GitCommit { + return predicate.GitCommit(sql.FieldNotNull(FieldPrNumber)) +} + +// PrNumberEqualFold applies the EqualFold predicate on the "pr_number" field. +func PrNumberEqualFold(v string) predicate.GitCommit { + return predicate.GitCommit(sql.FieldEqualFold(FieldPrNumber, v)) +} + +// PrNumberContainsFold applies the ContainsFold predicate on the "pr_number" field. +func PrNumberContainsFold(v string) predicate.GitCommit { + return predicate.GitCommit(sql.FieldContainsFold(FieldPrNumber, v)) +} + // HasResults applies the HasEdge predicate on the "results" edge. func HasResults() predicate.GitCommit { return predicate.GitCommit(func(s *sql.Selector) { diff --git a/ent/gitcommit_create.go b/ent/gitcommit_create.go index 6e74e43..92b7bc4 100644 --- a/ent/gitcommit_create.go +++ b/ent/gitcommit_create.go @@ -111,6 +111,20 @@ func (gcc *GitCommitCreate) SetNillableTimestamp(t *time.Time) *GitCommitCreate return gcc } +// SetPrNumber sets the "pr_number" field. +func (gcc *GitCommitCreate) SetPrNumber(s string) *GitCommitCreate { + gcc.mutation.SetPrNumber(s) + return gcc +} + +// SetNillablePrNumber sets the "pr_number" field if the given value is not nil. +func (gcc *GitCommitCreate) SetNillablePrNumber(s *string) *GitCommitCreate { + if s != nil { + gcc.SetPrNumber(*s) + } + return gcc +} + // SetID sets the "id" field. func (gcc *GitCommitCreate) SetID(u uuid.UUID) *GitCommitCreate { gcc.mutation.SetID(u) @@ -284,6 +298,10 @@ func (gcc *GitCommitCreate) createSpec() (*GitCommit, *sqlgraph.CreateSpec) { _spec.SetField(gitcommit.FieldTimestamp, field.TypeTime, value) _node.Timestamp = value } + if value, ok := gcc.mutation.PrNumber(); ok { + _spec.SetField(gitcommit.FieldPrNumber, field.TypeString, value) + _node.PrNumber = value + } if nodes := gcc.mutation.ResultsIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -460,6 +478,24 @@ func (u *GitCommitUpsert) ClearTimestamp() *GitCommitUpsert { return u } +// SetPrNumber sets the "pr_number" field. +func (u *GitCommitUpsert) SetPrNumber(v string) *GitCommitUpsert { + u.Set(gitcommit.FieldPrNumber, v) + return u +} + +// UpdatePrNumber sets the "pr_number" field to the value that was provided on create. +func (u *GitCommitUpsert) UpdatePrNumber() *GitCommitUpsert { + u.SetExcluded(gitcommit.FieldPrNumber) + return u +} + +// ClearPrNumber clears the value of the "pr_number" field. +func (u *GitCommitUpsert) ClearPrNumber() *GitCommitUpsert { + u.SetNull(gitcommit.FieldPrNumber) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create except the ID field. // Using this option is equivalent to using: // @@ -637,6 +673,27 @@ func (u *GitCommitUpsertOne) ClearTimestamp() *GitCommitUpsertOne { }) } +// SetPrNumber sets the "pr_number" field. +func (u *GitCommitUpsertOne) SetPrNumber(v string) *GitCommitUpsertOne { + return u.Update(func(s *GitCommitUpsert) { + s.SetPrNumber(v) + }) +} + +// UpdatePrNumber sets the "pr_number" field to the value that was provided on create. +func (u *GitCommitUpsertOne) UpdatePrNumber() *GitCommitUpsertOne { + return u.Update(func(s *GitCommitUpsert) { + s.UpdatePrNumber() + }) +} + +// ClearPrNumber clears the value of the "pr_number" field. +func (u *GitCommitUpsertOne) ClearPrNumber() *GitCommitUpsertOne { + return u.Update(func(s *GitCommitUpsert) { + s.ClearPrNumber() + }) +} + // Exec executes the query. func (u *GitCommitUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -981,6 +1038,27 @@ func (u *GitCommitUpsertBulk) ClearTimestamp() *GitCommitUpsertBulk { }) } +// SetPrNumber sets the "pr_number" field. +func (u *GitCommitUpsertBulk) SetPrNumber(v string) *GitCommitUpsertBulk { + return u.Update(func(s *GitCommitUpsert) { + s.SetPrNumber(v) + }) +} + +// UpdatePrNumber sets the "pr_number" field to the value that was provided on create. +func (u *GitCommitUpsertBulk) UpdatePrNumber() *GitCommitUpsertBulk { + return u.Update(func(s *GitCommitUpsert) { + s.UpdatePrNumber() + }) +} + +// ClearPrNumber clears the value of the "pr_number" field. +func (u *GitCommitUpsertBulk) ClearPrNumber() *GitCommitUpsertBulk { + return u.Update(func(s *GitCommitUpsert) { + s.ClearPrNumber() + }) +} + // Exec executes the query. func (u *GitCommitUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/ent/gitcommit_query.go b/ent/gitcommit_query.go index 37cac41..e6c7f0c 100644 --- a/ent/gitcommit_query.go +++ b/ent/gitcommit_query.go @@ -553,6 +553,12 @@ func (gcq *GitCommitQuery) ForShare(opts ...sql.LockOption) *GitCommitQuery { return gcq } +// Modify adds a query modifier for attaching custom logic to queries. +func (gcq *GitCommitQuery) Modify(modifiers ...func(s *sql.Selector)) *GitCommitSelect { + gcq.modifiers = append(gcq.modifiers, modifiers...) + return gcq.Select() +} + // GitCommitGroupBy is the group-by builder for GitCommit entities. type GitCommitGroupBy struct { selector @@ -642,3 +648,9 @@ func (gcs *GitCommitSelect) sqlScan(ctx context.Context, root *GitCommitQuery, v defer rows.Close() return sql.ScanSlice(rows, v) } + +// Modify adds a query modifier for attaching custom logic to queries. +func (gcs *GitCommitSelect) Modify(modifiers ...func(s *sql.Selector)) *GitCommitSelect { + gcs.modifiers = append(gcs.modifiers, modifiers...) + return gcs +} diff --git a/ent/gitcommit_update.go b/ent/gitcommit_update.go index bb347be..3f9126b 100644 --- a/ent/gitcommit_update.go +++ b/ent/gitcommit_update.go @@ -20,8 +20,9 @@ import ( // GitCommitUpdate is the builder for updating GitCommit entities. type GitCommitUpdate struct { config - hooks []Hook - mutation *GitCommitMutation + hooks []Hook + mutation *GitCommitMutation + modifiers []func(*sql.UpdateBuilder) } // Where appends a list predicates to the GitCommitUpdate builder. @@ -146,6 +147,26 @@ func (gcu *GitCommitUpdate) ClearTimestamp() *GitCommitUpdate { return gcu } +// SetPrNumber sets the "pr_number" field. +func (gcu *GitCommitUpdate) SetPrNumber(s string) *GitCommitUpdate { + gcu.mutation.SetPrNumber(s) + return gcu +} + +// SetNillablePrNumber sets the "pr_number" field if the given value is not nil. +func (gcu *GitCommitUpdate) SetNillablePrNumber(s *string) *GitCommitUpdate { + if s != nil { + gcu.SetPrNumber(*s) + } + return gcu +} + +// ClearPrNumber clears the value of the "pr_number" field. +func (gcu *GitCommitUpdate) ClearPrNumber() *GitCommitUpdate { + gcu.mutation.ClearPrNumber() + return gcu +} + // AddResultIDs adds the "results" edge to the CIWorkflowResult entity by IDs. func (gcu *GitCommitUpdate) AddResultIDs(ids ...uuid.UUID) *GitCommitUpdate { gcu.mutation.AddResultIDs(ids...) @@ -223,6 +244,12 @@ func (gcu *GitCommitUpdate) defaults() { } } +// Modify adds a statement modifier for attaching custom logic to the UPDATE statement. +func (gcu *GitCommitUpdate) Modify(modifiers ...func(u *sql.UpdateBuilder)) *GitCommitUpdate { + gcu.modifiers = append(gcu.modifiers, modifiers...) + return gcu +} + func (gcu *GitCommitUpdate) sqlSave(ctx context.Context) (n int, err error) { _spec := sqlgraph.NewUpdateSpec(gitcommit.Table, gitcommit.Columns, sqlgraph.NewFieldSpec(gitcommit.FieldID, field.TypeUUID)) if ps := gcu.mutation.predicates; len(ps) > 0 { @@ -262,6 +289,12 @@ func (gcu *GitCommitUpdate) sqlSave(ctx context.Context) (n int, err error) { if gcu.mutation.TimestampCleared() { _spec.ClearField(gitcommit.FieldTimestamp, field.TypeTime) } + if value, ok := gcu.mutation.PrNumber(); ok { + _spec.SetField(gitcommit.FieldPrNumber, field.TypeString, value) + } + if gcu.mutation.PrNumberCleared() { + _spec.ClearField(gitcommit.FieldPrNumber, field.TypeString) + } if gcu.mutation.ResultsCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -307,6 +340,7 @@ func (gcu *GitCommitUpdate) sqlSave(ctx context.Context) (n int, err error) { } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + _spec.AddModifiers(gcu.modifiers...) if n, err = sqlgraph.UpdateNodes(ctx, gcu.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{gitcommit.Label} @@ -322,9 +356,10 @@ func (gcu *GitCommitUpdate) sqlSave(ctx context.Context) (n int, err error) { // GitCommitUpdateOne is the builder for updating a single GitCommit entity. type GitCommitUpdateOne struct { config - fields []string - hooks []Hook - mutation *GitCommitMutation + fields []string + hooks []Hook + mutation *GitCommitMutation + modifiers []func(*sql.UpdateBuilder) } // SetUpdateTime sets the "update_time" field. @@ -443,6 +478,26 @@ func (gcuo *GitCommitUpdateOne) ClearTimestamp() *GitCommitUpdateOne { return gcuo } +// SetPrNumber sets the "pr_number" field. +func (gcuo *GitCommitUpdateOne) SetPrNumber(s string) *GitCommitUpdateOne { + gcuo.mutation.SetPrNumber(s) + return gcuo +} + +// SetNillablePrNumber sets the "pr_number" field if the given value is not nil. +func (gcuo *GitCommitUpdateOne) SetNillablePrNumber(s *string) *GitCommitUpdateOne { + if s != nil { + gcuo.SetPrNumber(*s) + } + return gcuo +} + +// ClearPrNumber clears the value of the "pr_number" field. +func (gcuo *GitCommitUpdateOne) ClearPrNumber() *GitCommitUpdateOne { + gcuo.mutation.ClearPrNumber() + return gcuo +} + // AddResultIDs adds the "results" edge to the CIWorkflowResult entity by IDs. func (gcuo *GitCommitUpdateOne) AddResultIDs(ids ...uuid.UUID) *GitCommitUpdateOne { gcuo.mutation.AddResultIDs(ids...) @@ -533,6 +588,12 @@ func (gcuo *GitCommitUpdateOne) defaults() { } } +// Modify adds a statement modifier for attaching custom logic to the UPDATE statement. +func (gcuo *GitCommitUpdateOne) Modify(modifiers ...func(u *sql.UpdateBuilder)) *GitCommitUpdateOne { + gcuo.modifiers = append(gcuo.modifiers, modifiers...) + return gcuo +} + func (gcuo *GitCommitUpdateOne) sqlSave(ctx context.Context) (_node *GitCommit, err error) { _spec := sqlgraph.NewUpdateSpec(gitcommit.Table, gitcommit.Columns, sqlgraph.NewFieldSpec(gitcommit.FieldID, field.TypeUUID)) id, ok := gcuo.mutation.ID() @@ -589,6 +650,12 @@ func (gcuo *GitCommitUpdateOne) sqlSave(ctx context.Context) (_node *GitCommit, if gcuo.mutation.TimestampCleared() { _spec.ClearField(gitcommit.FieldTimestamp, field.TypeTime) } + if value, ok := gcuo.mutation.PrNumber(); ok { + _spec.SetField(gitcommit.FieldPrNumber, field.TypeString, value) + } + if gcuo.mutation.PrNumberCleared() { + _spec.ClearField(gitcommit.FieldPrNumber, field.TypeString) + } if gcuo.mutation.ResultsCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -634,6 +701,7 @@ func (gcuo *GitCommitUpdateOne) sqlSave(ctx context.Context) (_node *GitCommit, } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + _spec.AddModifiers(gcuo.modifiers...) _node = &GitCommit{config: gcuo.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/ent/hook/hook.go b/ent/hook/hook.go index fda7f47..4c792eb 100644 --- a/ent/hook/hook.go +++ b/ent/hook/hook.go @@ -44,6 +44,18 @@ func (f NodeFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.NodeMutation", m) } +// The NodeReviewFunc type is an adapter to allow the use of ordinary +// function as NodeReview mutator. +type NodeReviewFunc func(context.Context, *ent.NodeReviewMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f NodeReviewFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.NodeReviewMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.NodeReviewMutation", m) +} + // The NodeVersionFunc type is an adapter to allow the use of ordinary // function as NodeVersion mutator. type NodeVersionFunc func(context.Context, *ent.NodeVersionMutation) (ent.Value, error) diff --git a/ent/migrate/migrations/20240526144817_migration.sql b/ent/migrate/migrations/20240526144817_migration.sql new file mode 100644 index 0000000..ca3f050 --- /dev/null +++ b/ent/migrate/migrations/20240526144817_migration.sql @@ -0,0 +1,26 @@ +-- Create "git_commits" table +CREATE TABLE IF NOT EXISTS "git_commits" ("id" uuid NOT NULL, "create_time" timestamptz NOT NULL, "update_time" timestamptz NOT NULL, "commit_hash" text NOT NULL, "branch_name" text NOT NULL, "repo_name" text NOT NULL, "commit_message" text NOT NULL, "commit_timestamp" timestamptz NOT NULL, "author" text NULL, "timestamp" timestamptz NULL, PRIMARY KEY ("id")); +-- Create index "gitcommit_repo_name_commit_hash" to table: "git_commits" +CREATE UNIQUE INDEX IF NOT EXISTS "gitcommit_repo_name_commit_hash" ON "git_commits" ("repo_name", "commit_hash"); +-- Create "storage_files" table +CREATE TABLE IF NOT EXISTS "storage_files" ("id" uuid NOT NULL, "create_time" timestamptz NOT NULL, "update_time" timestamptz NOT NULL, "bucket_name" text NOT NULL, "object_name" text NULL, "file_path" text NOT NULL, "file_type" text NOT NULL, "file_url" text NULL, PRIMARY KEY ("id")); +-- Create "ci_workflow_results" table +CREATE TABLE IF NOT EXISTS "ci_workflow_results" ("id" uuid NOT NULL, "create_time" timestamptz NOT NULL, "update_time" timestamptz NOT NULL, "operating_system" text NOT NULL, "gpu_type" text NULL, "pytorch_version" text NULL, "workflow_name" text NULL, "run_id" text NULL, "status" text NULL, "start_time" bigint NULL, "end_time" bigint NULL, "ci_workflow_result_storage_file" uuid NULL, "git_commit_results" uuid NULL, PRIMARY KEY ("id"), CONSTRAINT "ci_workflow_results_git_commits_results" FOREIGN KEY ("git_commit_results") REFERENCES "git_commits" ("id") ON UPDATE NO ACTION ON DELETE SET NULL, CONSTRAINT "ci_workflow_results_storage_files_storage_file" FOREIGN KEY ("ci_workflow_result_storage_file") REFERENCES "storage_files" ("id") ON UPDATE NO ACTION ON DELETE SET NULL); +-- Create "publishers" table +CREATE TABLE IF NOT EXISTS "publishers" ("id" text NOT NULL, "create_time" timestamptz NOT NULL, "update_time" timestamptz NOT NULL, "name" text NOT NULL, "description" text NULL, "website" text NULL, "support_email" text NULL, "source_code_repo" text NULL, "logo_url" text NULL, PRIMARY KEY ("id")); +-- Create "nodes" table +CREATE TABLE IF NOT EXISTS "nodes" ("id" text NOT NULL, "create_time" timestamptz NOT NULL, "update_time" timestamptz NOT NULL, "name" text NOT NULL, "description" text NULL, "author" text NULL, "license" text NOT NULL, "repository_url" text NOT NULL, "icon_url" text NULL, "tags" text NOT NULL, "publisher_id" text NOT NULL, PRIMARY KEY ("id"), CONSTRAINT "nodes_publishers_nodes" FOREIGN KEY ("publisher_id") REFERENCES "publishers" ("id") ON UPDATE NO ACTION ON DELETE NO ACTION); +-- Create "node_versions" table +CREATE TABLE IF NOT EXISTS "node_versions" ("id" uuid NOT NULL, "create_time" timestamptz NOT NULL, "update_time" timestamptz NOT NULL, "version" text NOT NULL, "changelog" text NULL, "pip_dependencies" text NOT NULL, "deprecated" boolean NOT NULL DEFAULT false, "node_id" text NOT NULL, "node_version_storage_file" uuid NULL, PRIMARY KEY ("id"), CONSTRAINT "node_versions_nodes_versions" FOREIGN KEY ("node_id") REFERENCES "nodes" ("id") ON UPDATE NO ACTION ON DELETE NO ACTION, CONSTRAINT "node_versions_storage_files_storage_file" FOREIGN KEY ("node_version_storage_file") REFERENCES "storage_files" ("id") ON UPDATE NO ACTION ON DELETE SET NULL); +-- Create index "nodeversion_node_id_version" to table: "node_versions" +CREATE UNIQUE INDEX IF NOT EXISTS "nodeversion_node_id_version" ON "node_versions" ("node_id", "version"); +-- Create "personal_access_tokens" table +CREATE TABLE IF NOT EXISTS "personal_access_tokens" ("id" uuid NOT NULL, "create_time" timestamptz NOT NULL, "update_time" timestamptz NOT NULL, "name" text NOT NULL, "description" text NOT NULL, "token" text NOT NULL, "publisher_id" text NOT NULL, PRIMARY KEY ("id"), CONSTRAINT "personal_access_tokens_publishers_personal_access_tokens" FOREIGN KEY ("publisher_id") REFERENCES "publishers" ("id") ON UPDATE NO ACTION ON DELETE NO ACTION); +-- Create index "personal_access_tokens_token_key" to table: "personal_access_tokens" +CREATE UNIQUE INDEX IF NOT EXISTS "personal_access_tokens_token_key" ON "personal_access_tokens" ("token"); +-- Create index "personalaccesstoken_token" to table: "personal_access_tokens" +CREATE UNIQUE INDEX IF NOT EXISTS "personalaccesstoken_token" ON "personal_access_tokens" ("token"); +-- Create "users" table +CREATE TABLE IF NOT EXISTS "users" ("id" character varying NOT NULL, "create_time" timestamptz NOT NULL, "update_time" timestamptz NOT NULL, "email" character varying NULL, "name" character varying NULL, "is_approved" boolean NOT NULL DEFAULT false, "is_admin" boolean NOT NULL DEFAULT false, PRIMARY KEY ("id")); +-- Create "publisher_permissions" table +CREATE TABLE IF NOT EXISTS "publisher_permissions" ("id" bigint NOT NULL GENERATED BY DEFAULT AS IDENTITY, "permission" character varying NOT NULL, "publisher_id" text NOT NULL, "user_id" character varying NOT NULL, PRIMARY KEY ("id"), CONSTRAINT "publisher_permissions_publishers_publisher_permissions" FOREIGN KEY ("publisher_id") REFERENCES "publishers" ("id") ON UPDATE NO ACTION ON DELETE NO ACTION, CONSTRAINT "publisher_permissions_users_publisher_permissions" FOREIGN KEY ("user_id") REFERENCES "users" ("id") ON UPDATE NO ACTION ON DELETE NO ACTION); diff --git a/ent/migrate/migrations/20240528220411_migration.sql b/ent/migrate/migrations/20240528220411_migration.sql new file mode 100644 index 0000000..720e9ee --- /dev/null +++ b/ent/migrate/migrations/20240528220411_migration.sql @@ -0,0 +1,2 @@ +-- Modify "nodes" table +ALTER TABLE "nodes" ADD COLUMN "test_field" text; diff --git a/ent/migrate/migrations/20240528221846_migration.sql b/ent/migrate/migrations/20240528221846_migration.sql new file mode 100644 index 0000000..058cb32 --- /dev/null +++ b/ent/migrate/migrations/20240528221846_migration.sql @@ -0,0 +1,2 @@ +-- Modify "nodes" table +ALTER TABLE "nodes" ALTER COLUMN "test_field" DROP NOT NULL; diff --git a/ent/migrate/migrations/20240528222851_migration.sql b/ent/migrate/migrations/20240528222851_migration.sql new file mode 100644 index 0000000..cc593e4 --- /dev/null +++ b/ent/migrate/migrations/20240528222851_migration.sql @@ -0,0 +1,2 @@ +-- Modify "nodes" table +ALTER TABLE "nodes" DROP COLUMN "test_field"; diff --git a/ent/migrate/migrations/20240601211932_migration.sql b/ent/migrate/migrations/20240601211932_migration.sql new file mode 100644 index 0000000..1ec88da --- /dev/null +++ b/ent/migrate/migrations/20240601211932_migration.sql @@ -0,0 +1,6 @@ +-- Modify "nodes" table +ALTER TABLE "nodes" ADD COLUMN "total_install" bigint NOT NULL DEFAULT 0, ADD COLUMN "total_star" bigint NOT NULL DEFAULT 0, ADD COLUMN "total_review" bigint NOT NULL DEFAULT 0; +-- Create "node_reviews" table +CREATE TABLE "node_reviews" ("id" uuid NOT NULL, "star" bigint NOT NULL DEFAULT 0, "node_id" text NOT NULL, "user_id" character varying NOT NULL, PRIMARY KEY ("id"), CONSTRAINT "node_reviews_nodes_reviews" FOREIGN KEY ("node_id") REFERENCES "nodes" ("id") ON UPDATE NO ACTION ON DELETE NO ACTION, CONSTRAINT "node_reviews_users_reviews" FOREIGN KEY ("user_id") REFERENCES "users" ("id") ON UPDATE NO ACTION ON DELETE NO ACTION); +-- Create index "nodereview_node_id_user_id" to table: "node_reviews" +CREATE UNIQUE INDEX "nodereview_node_id_user_id" ON "node_reviews" ("node_id", "user_id"); diff --git a/ent/migrate/migrations/20240613231838_migration.sql b/ent/migrate/migrations/20240613231838_migration.sql new file mode 100644 index 0000000..4d6d985 --- /dev/null +++ b/ent/migrate/migrations/20240613231838_migration.sql @@ -0,0 +1,2 @@ +-- Modify "nodes" table +ALTER TABLE "nodes" ADD COLUMN "status" character varying NOT NULL DEFAULT 'pending'; diff --git a/ent/migrate/migrations/20240614133606_migration.sql b/ent/migrate/migrations/20240614133606_migration.sql new file mode 100644 index 0000000..ddd402d --- /dev/null +++ b/ent/migrate/migrations/20240614133606_migration.sql @@ -0,0 +1,4 @@ +-- Modify "publishers" table +ALTER TABLE "publishers" ADD COLUMN "status" character varying NOT NULL DEFAULT 'ACTIVE'; +-- Modify "users" table +ALTER TABLE "users" ADD COLUMN "status" character varying NOT NULL DEFAULT 'ACTIVE'; diff --git a/ent/migrate/migrations/20240614211957_migration.sql b/ent/migrate/migrations/20240614211957_migration.sql new file mode 100644 index 0000000..46f02af --- /dev/null +++ b/ent/migrate/migrations/20240614211957_migration.sql @@ -0,0 +1,2 @@ +-- Modify "nodes" table +ALTER TABLE "nodes" ALTER COLUMN "status" SET DEFAULT 'active'; diff --git a/ent/migrate/migrations/20240614213347_migration.sql b/ent/migrate/migrations/20240614213347_migration.sql new file mode 100644 index 0000000..6920c11 --- /dev/null +++ b/ent/migrate/migrations/20240614213347_migration.sql @@ -0,0 +1,2 @@ +-- Modify "node_versions" table +ALTER TABLE "node_versions" ADD COLUMN "status" character varying NOT NULL DEFAULT 'pending'; diff --git a/ent/migrate/migrations/20240615214552_migration.sql b/ent/migrate/migrations/20240615214552_migration.sql new file mode 100644 index 0000000..7b9a0ec --- /dev/null +++ b/ent/migrate/migrations/20240615214552_migration.sql @@ -0,0 +1,2 @@ +-- Modify "nodes" table +ALTER TABLE "nodes" ADD COLUMN "category" text NULL, ADD COLUMN "status_detail" text NULL; diff --git a/ent/migrate/migrations/20240615222435_migration.sql b/ent/migrate/migrations/20240615222435_migration.sql new file mode 100644 index 0000000..6440c8c --- /dev/null +++ b/ent/migrate/migrations/20240615222435_migration.sql @@ -0,0 +1,2 @@ +-- Modify "node_versions" table +ALTER TABLE "node_versions" ADD COLUMN "status_reason" text NOT NULL DEFAULT ''; diff --git a/ent/migrate/migrations/20240710235003_migration.sql b/ent/migrate/migrations/20240710235003_migration.sql new file mode 100644 index 0000000..c78d05c --- /dev/null +++ b/ent/migrate/migrations/20240710235003_migration.sql @@ -0,0 +1,7 @@ +UPDATE "ci_workflow_results" +SET "status" = 'STARTED' +WHERE "status" IS NULL; +-- Modify "ci_workflow_results" table +ALTER TABLE "ci_workflow_results" ALTER COLUMN "status" TYPE character varying, ALTER COLUMN "status" SET NOT NULL, ALTER COLUMN "status" SET DEFAULT 'STARTED', ADD COLUMN "python_version" text NULL, ADD COLUMN "vram" bigint NULL, ADD COLUMN "job_trigger_user" text NULL; +-- Modify "git_commits" table +ALTER TABLE "git_commits" ADD COLUMN "pr_number" text NULL; diff --git a/ent/migrate/migrations/20240711000828_migration.sql b/ent/migrate/migrations/20240711000828_migration.sql new file mode 100644 index 0000000..5a541d7 --- /dev/null +++ b/ent/migrate/migrations/20240711000828_migration.sql @@ -0,0 +1,4 @@ +-- Modify "ci_workflow_results" table +ALTER TABLE "ci_workflow_results" ADD COLUMN "peak_vram" bigint NULL; +-- Rename a column from "vram" to "avg_vram" +ALTER TABLE "ci_workflow_results" RENAME COLUMN "vram" TO "avg_vram"; diff --git a/ent/migrate/migrations/20240711020205_migration.sql b/ent/migrate/migrations/20240711020205_migration.sql new file mode 100644 index 0000000..2876338 --- /dev/null +++ b/ent/migrate/migrations/20240711020205_migration.sql @@ -0,0 +1,4 @@ +-- Modify "ci_workflow_results" table +ALTER TABLE "ci_workflow_results" DROP COLUMN "ci_workflow_result_storage_file"; +-- Modify "storage_files" table +ALTER TABLE "storage_files" ADD COLUMN "ci_workflow_result_storage_file" uuid NULL, ADD CONSTRAINT "storage_files_ci_workflow_results_storage_file" FOREIGN KEY ("ci_workflow_result_storage_file") REFERENCES "ci_workflow_results" ("id") ON UPDATE NO ACTION ON DELETE SET NULL; diff --git a/ent/migrate/migrations/20240711181827_migration.sql b/ent/migrate/migrations/20240711181827_migration.sql new file mode 100644 index 0000000..3ec085b --- /dev/null +++ b/ent/migrate/migrations/20240711181827_migration.sql @@ -0,0 +1,2 @@ +-- Modify "ci_workflow_results" table +ALTER TABLE "ci_workflow_results" ADD COLUMN "metadata" jsonb NULL; diff --git a/ent/migrate/migrations/20240711235600_migration.sql b/ent/migrate/migrations/20240711235600_migration.sql new file mode 100644 index 0000000..2b513b4 --- /dev/null +++ b/ent/migrate/migrations/20240711235600_migration.sql @@ -0,0 +1,2 @@ +-- Modify "ci_workflow_results" table +ALTER TABLE "ci_workflow_results" DROP COLUMN "gpu_type", ADD COLUMN "job_id" text NULL, ADD COLUMN "cuda_version" text NULL; diff --git a/ent/migrate/migrations/20240712001321_migration.sql b/ent/migrate/migrations/20240712001321_migration.sql new file mode 100644 index 0000000..e218762 --- /dev/null +++ b/ent/migrate/migrations/20240712001321_migration.sql @@ -0,0 +1,2 @@ +-- Modify "ci_workflow_results" table +ALTER TABLE "ci_workflow_results" ADD COLUMN "comfy_run_flags" text NULL; diff --git a/ent/migrate/migrations/atlas.sum b/ent/migrate/migrations/atlas.sum new file mode 100644 index 0000000..cb9b83f --- /dev/null +++ b/ent/migrate/migrations/atlas.sum @@ -0,0 +1,18 @@ +h1:pgh6GSexz56KIn8mQB7L61FtVDJP9QfFFwN6SMoqMwA= +20240526144817_migration.sql h1:sP6keX+oMyLL2qpIFx0Ns0WYfWM5hJ4zkFPmLWT68fM= +20240528220411_migration.sql h1:SR44sOEaWbDgYCKJZIKcGCI7Ta+LqL71z225Nhs2+HM= +20240528221846_migration.sql h1:EkUonGI9Bu689qWX4pG3PRC+On4f6u7UvwDbaR8mCNk= +20240528222851_migration.sql h1:VaQhEaDGe8M2kuNtKVjuMWMLJ9RhJVraVgSM4rm/XcQ= +20240601211932_migration.sql h1:zTofjRbLfoZZF8k6dvAMUDJGRENHAG3m1qqtLgTXUCQ= +20240613231838_migration.sql h1:fbgEWDDA7hcQP5gdX6WDpPWG+340dT40CgkwIG7KGxw= +20240614133606_migration.sql h1:ydCsnRUhSVP0WY7ey7DNHcNQnABPuOt7Z/b0mj7+s6M= +20240614211957_migration.sql h1:qUD4bs8JikDO/0TLD91fhkHzyesdyn2yKqmwkBcWVMc= +20240614213347_migration.sql h1:iz6Lzy1OdtArw2xgRICa0Kktj+BmBji0JUbi+k5WlU8= +20240615214552_migration.sql h1:KEmZcB8I6dlnCp4fXwcUETnxUBVf9JY53CUgpkX/jnw= +20240615222435_migration.sql h1:n4J0/FWg+mu6kAUTtHMiwG1mveFMsf2NT/jkhtczo48= +20240710235003_migration.sql h1:YsK6hRl55LOfXExL/h2yLkRUuRup3pvJoavi9XjhQ54= +20240711000828_migration.sql h1:0499gZvTBMtGMhIPdN8RcouyluWleX++sbnvfO5R5jE= +20240711020205_migration.sql h1:Ubcs7O+oFT6RC0uiWPIyZKRz3Iyfv6g8t3WDohrgla8= +20240711181827_migration.sql h1:X03aAvDOn41Sp66cN/Xr0mv9zm34cXHq18L9U4Y8ZkU= +20240711235600_migration.sql h1:vT82ssDcIMCHPQQdaJjqovKTXg6amMp5W+lz7ilIowY= +20240712001321_migration.sql h1:fK0ePEw3Xr82hRFD/634KLX2tmczgvbUE6KMVlT/4fc= diff --git a/ent/migrate/schema.go b/ent/migrate/schema.go index 87c9449..dea74aa 100644 --- a/ent/migrate/schema.go +++ b/ent/migrate/schema.go @@ -14,14 +14,20 @@ var ( {Name: "create_time", Type: field.TypeTime}, {Name: "update_time", Type: field.TypeTime}, {Name: "operating_system", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, - {Name: "gpu_type", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, - {Name: "pytorch_version", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, {Name: "workflow_name", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, {Name: "run_id", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, - {Name: "status", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "job_id", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "status", Type: field.TypeString, Default: "STARTED"}, {Name: "start_time", Type: field.TypeInt64, Nullable: true}, {Name: "end_time", Type: field.TypeInt64, Nullable: true}, - {Name: "ci_workflow_result_storage_file", Type: field.TypeUUID, Nullable: true}, + {Name: "python_version", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "pytorch_version", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "cuda_version", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "comfy_run_flags", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "avg_vram", Type: field.TypeInt, Nullable: true}, + {Name: "peak_vram", Type: field.TypeInt, Nullable: true}, + {Name: "job_trigger_user", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "metadata", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}}, {Name: "git_commit_results", Type: field.TypeUUID, Nullable: true}, } // CiWorkflowResultsTable holds the schema information for the "ci_workflow_results" table. @@ -30,15 +36,9 @@ var ( Columns: CiWorkflowResultsColumns, PrimaryKey: []*schema.Column{CiWorkflowResultsColumns[0]}, ForeignKeys: []*schema.ForeignKey{ - { - Symbol: "ci_workflow_results_storage_files_storage_file", - Columns: []*schema.Column{CiWorkflowResultsColumns[11]}, - RefColumns: []*schema.Column{StorageFilesColumns[0]}, - OnDelete: schema.SetNull, - }, { Symbol: "ci_workflow_results_git_commits_results", - Columns: []*schema.Column{CiWorkflowResultsColumns[12]}, + Columns: []*schema.Column{CiWorkflowResultsColumns[18]}, RefColumns: []*schema.Column{GitCommitsColumns[0]}, OnDelete: schema.SetNull, }, @@ -56,6 +56,7 @@ var ( {Name: "commit_timestamp", Type: field.TypeTime}, {Name: "author", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, {Name: "timestamp", Type: field.TypeTime, Nullable: true}, + {Name: "pr_number", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, } // GitCommitsTable holds the schema information for the "git_commits" table. GitCommitsTable = &schema.Table{ @@ -77,11 +78,17 @@ var ( {Name: "update_time", Type: field.TypeTime}, {Name: "name", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, {Name: "description", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "category", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, {Name: "author", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, {Name: "license", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, {Name: "repository_url", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, {Name: "icon_url", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, {Name: "tags", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "total_install", Type: field.TypeInt64, Default: 0}, + {Name: "total_star", Type: field.TypeInt64, Default: 0}, + {Name: "total_review", Type: field.TypeInt64, Default: 0}, + {Name: "status", Type: field.TypeEnum, Enums: []string{"active", "banned", "deleted"}, Default: "active"}, + {Name: "status_detail", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, {Name: "publisher_id", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, } // NodesTable holds the schema information for the "nodes" table. @@ -92,12 +99,46 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "nodes_publishers_nodes", - Columns: []*schema.Column{NodesColumns[10]}, + Columns: []*schema.Column{NodesColumns[16]}, RefColumns: []*schema.Column{PublishersColumns[0]}, OnDelete: schema.NoAction, }, }, } + // NodeReviewsColumns holds the columns for the "node_reviews" table. + NodeReviewsColumns = []*schema.Column{ + {Name: "id", Type: field.TypeUUID}, + {Name: "star", Type: field.TypeInt, Default: 0}, + {Name: "node_id", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "user_id", Type: field.TypeString}, + } + // NodeReviewsTable holds the schema information for the "node_reviews" table. + NodeReviewsTable = &schema.Table{ + Name: "node_reviews", + Columns: NodeReviewsColumns, + PrimaryKey: []*schema.Column{NodeReviewsColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "node_reviews_nodes_reviews", + Columns: []*schema.Column{NodeReviewsColumns[2]}, + RefColumns: []*schema.Column{NodesColumns[0]}, + OnDelete: schema.NoAction, + }, + { + Symbol: "node_reviews_users_reviews", + Columns: []*schema.Column{NodeReviewsColumns[3]}, + RefColumns: []*schema.Column{UsersColumns[0]}, + OnDelete: schema.NoAction, + }, + }, + Indexes: []*schema.Index{ + { + Name: "nodereview_node_id_user_id", + Unique: true, + Columns: []*schema.Column{NodeReviewsColumns[2], NodeReviewsColumns[3]}, + }, + }, + } // NodeVersionsColumns holds the columns for the "node_versions" table. NodeVersionsColumns = []*schema.Column{ {Name: "id", Type: field.TypeUUID}, @@ -107,6 +148,8 @@ var ( {Name: "changelog", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, {Name: "pip_dependencies", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "text"}}, {Name: "deprecated", Type: field.TypeBool, Default: false}, + {Name: "status", Type: field.TypeEnum, Enums: []string{"active", "banned", "deleted", "pending", "flagged"}, Default: "pending"}, + {Name: "status_reason", Type: field.TypeString, Default: "", SchemaType: map[string]string{"postgres": "text"}}, {Name: "node_id", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, {Name: "node_version_storage_file", Type: field.TypeUUID, Nullable: true}, } @@ -118,13 +161,13 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "node_versions_nodes_versions", - Columns: []*schema.Column{NodeVersionsColumns[7]}, + Columns: []*schema.Column{NodeVersionsColumns[9]}, RefColumns: []*schema.Column{NodesColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "node_versions_storage_files_storage_file", - Columns: []*schema.Column{NodeVersionsColumns[8]}, + Columns: []*schema.Column{NodeVersionsColumns[10]}, RefColumns: []*schema.Column{StorageFilesColumns[0]}, OnDelete: schema.SetNull, }, @@ -133,7 +176,7 @@ var ( { Name: "nodeversion_node_id_version", Unique: true, - Columns: []*schema.Column{NodeVersionsColumns[7], NodeVersionsColumns[3]}, + Columns: []*schema.Column{NodeVersionsColumns[9], NodeVersionsColumns[3]}, }, }, } @@ -179,6 +222,7 @@ var ( {Name: "support_email", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, {Name: "source_code_repo", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, {Name: "logo_url", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "status", Type: field.TypeEnum, Enums: []string{"ACTIVE", "BANNED"}, Default: "ACTIVE"}, } // PublishersTable holds the schema information for the "publishers" table. PublishersTable = &schema.Table{ @@ -223,12 +267,21 @@ var ( {Name: "file_path", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, {Name: "file_type", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, {Name: "file_url", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, + {Name: "ci_workflow_result_storage_file", Type: field.TypeUUID, Nullable: true}, } // StorageFilesTable holds the schema information for the "storage_files" table. StorageFilesTable = &schema.Table{ Name: "storage_files", Columns: StorageFilesColumns, PrimaryKey: []*schema.Column{StorageFilesColumns[0]}, + ForeignKeys: []*schema.ForeignKey{ + { + Symbol: "storage_files_ci_workflow_results_storage_file", + Columns: []*schema.Column{StorageFilesColumns[8]}, + RefColumns: []*schema.Column{CiWorkflowResultsColumns[0]}, + OnDelete: schema.SetNull, + }, + }, } // UsersColumns holds the columns for the "users" table. UsersColumns = []*schema.Column{ @@ -239,6 +292,7 @@ var ( {Name: "name", Type: field.TypeString, Nullable: true}, {Name: "is_approved", Type: field.TypeBool, Default: false}, {Name: "is_admin", Type: field.TypeBool, Default: false}, + {Name: "status", Type: field.TypeEnum, Enums: []string{"ACTIVE", "BANNED"}, Default: "ACTIVE"}, } // UsersTable holds the schema information for the "users" table. UsersTable = &schema.Table{ @@ -251,6 +305,7 @@ var ( CiWorkflowResultsTable, GitCommitsTable, NodesTable, + NodeReviewsTable, NodeVersionsTable, PersonalAccessTokensTable, PublishersTable, @@ -261,12 +316,14 @@ var ( ) func init() { - CiWorkflowResultsTable.ForeignKeys[0].RefTable = StorageFilesTable - CiWorkflowResultsTable.ForeignKeys[1].RefTable = GitCommitsTable + CiWorkflowResultsTable.ForeignKeys[0].RefTable = GitCommitsTable NodesTable.ForeignKeys[0].RefTable = PublishersTable + NodeReviewsTable.ForeignKeys[0].RefTable = NodesTable + NodeReviewsTable.ForeignKeys[1].RefTable = UsersTable NodeVersionsTable.ForeignKeys[0].RefTable = NodesTable NodeVersionsTable.ForeignKeys[1].RefTable = StorageFilesTable PersonalAccessTokensTable.ForeignKeys[0].RefTable = PublishersTable PublisherPermissionsTable.ForeignKeys[0].RefTable = PublishersTable PublisherPermissionsTable.ForeignKeys[1].RefTable = UsersTable + StorageFilesTable.ForeignKeys[0].RefTable = CiWorkflowResultsTable } diff --git a/ent/mutation.go b/ent/mutation.go index 8a60a91..fa6b9e6 100644 --- a/ent/mutation.go +++ b/ent/mutation.go @@ -9,6 +9,7 @@ import ( "registry-backend/ent/ciworkflowresult" "registry-backend/ent/gitcommit" "registry-backend/ent/node" + "registry-backend/ent/nodereview" "registry-backend/ent/nodeversion" "registry-backend/ent/personalaccesstoken" "registry-backend/ent/predicate" @@ -37,6 +38,7 @@ const ( TypeCIWorkflowResult = "CIWorkflowResult" TypeGitCommit = "GitCommit" TypeNode = "Node" + TypeNodeReview = "NodeReview" TypeNodeVersion = "NodeVersion" TypePersonalAccessToken = "PersonalAccessToken" TypePublisher = "Publisher" @@ -54,19 +56,29 @@ type CIWorkflowResultMutation struct { create_time *time.Time update_time *time.Time operating_system *string - gpu_type *string - pytorch_version *string workflow_name *string run_id *string - status *string + job_id *string + status *schema.WorkflowRunStatusType start_time *int64 addstart_time *int64 end_time *int64 addend_time *int64 + python_version *string + pytorch_version *string + cuda_version *string + comfy_run_flags *string + avg_vram *int + addavg_vram *int + peak_vram *int + addpeak_vram *int + job_trigger_user *string + metadata *map[string]interface{} clearedFields map[string]struct{} gitcommit *uuid.UUID clearedgitcommit bool - storage_file *uuid.UUID + storage_file map[uuid.UUID]struct{} + removedstorage_file map[uuid.UUID]struct{} clearedstorage_file bool done bool oldValue func(context.Context) (*CIWorkflowResult, error) @@ -285,104 +297,6 @@ func (m *CIWorkflowResultMutation) ResetOperatingSystem() { m.operating_system = nil } -// SetGpuType sets the "gpu_type" field. -func (m *CIWorkflowResultMutation) SetGpuType(s string) { - m.gpu_type = &s -} - -// GpuType returns the value of the "gpu_type" field in the mutation. -func (m *CIWorkflowResultMutation) GpuType() (r string, exists bool) { - v := m.gpu_type - if v == nil { - return - } - return *v, true -} - -// OldGpuType returns the old "gpu_type" field's value of the CIWorkflowResult entity. -// If the CIWorkflowResult object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *CIWorkflowResultMutation) OldGpuType(ctx context.Context) (v string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldGpuType is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldGpuType requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldGpuType: %w", err) - } - return oldValue.GpuType, nil -} - -// ClearGpuType clears the value of the "gpu_type" field. -func (m *CIWorkflowResultMutation) ClearGpuType() { - m.gpu_type = nil - m.clearedFields[ciworkflowresult.FieldGpuType] = struct{}{} -} - -// GpuTypeCleared returns if the "gpu_type" field was cleared in this mutation. -func (m *CIWorkflowResultMutation) GpuTypeCleared() bool { - _, ok := m.clearedFields[ciworkflowresult.FieldGpuType] - return ok -} - -// ResetGpuType resets all changes to the "gpu_type" field. -func (m *CIWorkflowResultMutation) ResetGpuType() { - m.gpu_type = nil - delete(m.clearedFields, ciworkflowresult.FieldGpuType) -} - -// SetPytorchVersion sets the "pytorch_version" field. -func (m *CIWorkflowResultMutation) SetPytorchVersion(s string) { - m.pytorch_version = &s -} - -// PytorchVersion returns the value of the "pytorch_version" field in the mutation. -func (m *CIWorkflowResultMutation) PytorchVersion() (r string, exists bool) { - v := m.pytorch_version - if v == nil { - return - } - return *v, true -} - -// OldPytorchVersion returns the old "pytorch_version" field's value of the CIWorkflowResult entity. -// If the CIWorkflowResult object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *CIWorkflowResultMutation) OldPytorchVersion(ctx context.Context) (v string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldPytorchVersion is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldPytorchVersion requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldPytorchVersion: %w", err) - } - return oldValue.PytorchVersion, nil -} - -// ClearPytorchVersion clears the value of the "pytorch_version" field. -func (m *CIWorkflowResultMutation) ClearPytorchVersion() { - m.pytorch_version = nil - m.clearedFields[ciworkflowresult.FieldPytorchVersion] = struct{}{} -} - -// PytorchVersionCleared returns if the "pytorch_version" field was cleared in this mutation. -func (m *CIWorkflowResultMutation) PytorchVersionCleared() bool { - _, ok := m.clearedFields[ciworkflowresult.FieldPytorchVersion] - return ok -} - -// ResetPytorchVersion resets all changes to the "pytorch_version" field. -func (m *CIWorkflowResultMutation) ResetPytorchVersion() { - m.pytorch_version = nil - delete(m.clearedFields, ciworkflowresult.FieldPytorchVersion) -} - // SetWorkflowName sets the "workflow_name" field. func (m *CIWorkflowResultMutation) SetWorkflowName(s string) { m.workflow_name = &s @@ -481,13 +395,62 @@ func (m *CIWorkflowResultMutation) ResetRunID() { delete(m.clearedFields, ciworkflowresult.FieldRunID) } +// SetJobID sets the "job_id" field. +func (m *CIWorkflowResultMutation) SetJobID(s string) { + m.job_id = &s +} + +// JobID returns the value of the "job_id" field in the mutation. +func (m *CIWorkflowResultMutation) JobID() (r string, exists bool) { + v := m.job_id + if v == nil { + return + } + return *v, true +} + +// OldJobID returns the old "job_id" field's value of the CIWorkflowResult entity. +// If the CIWorkflowResult object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *CIWorkflowResultMutation) OldJobID(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldJobID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldJobID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldJobID: %w", err) + } + return oldValue.JobID, nil +} + +// ClearJobID clears the value of the "job_id" field. +func (m *CIWorkflowResultMutation) ClearJobID() { + m.job_id = nil + m.clearedFields[ciworkflowresult.FieldJobID] = struct{}{} +} + +// JobIDCleared returns if the "job_id" field was cleared in this mutation. +func (m *CIWorkflowResultMutation) JobIDCleared() bool { + _, ok := m.clearedFields[ciworkflowresult.FieldJobID] + return ok +} + +// ResetJobID resets all changes to the "job_id" field. +func (m *CIWorkflowResultMutation) ResetJobID() { + m.job_id = nil + delete(m.clearedFields, ciworkflowresult.FieldJobID) +} + // SetStatus sets the "status" field. -func (m *CIWorkflowResultMutation) SetStatus(s string) { - m.status = &s +func (m *CIWorkflowResultMutation) SetStatus(srst schema.WorkflowRunStatusType) { + m.status = &srst } // Status returns the value of the "status" field in the mutation. -func (m *CIWorkflowResultMutation) Status() (r string, exists bool) { +func (m *CIWorkflowResultMutation) Status() (r schema.WorkflowRunStatusType, exists bool) { v := m.status if v == nil { return @@ -498,7 +461,7 @@ func (m *CIWorkflowResultMutation) Status() (r string, exists bool) { // OldStatus returns the old "status" field's value of the CIWorkflowResult entity. // If the CIWorkflowResult object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *CIWorkflowResultMutation) OldStatus(ctx context.Context) (v string, err error) { +func (m *CIWorkflowResultMutation) OldStatus(ctx context.Context) (v schema.WorkflowRunStatusType, err error) { if !m.op.Is(OpUpdateOne) { return v, errors.New("OldStatus is only allowed on UpdateOne operations") } @@ -512,22 +475,9 @@ func (m *CIWorkflowResultMutation) OldStatus(ctx context.Context) (v string, err return oldValue.Status, nil } -// ClearStatus clears the value of the "status" field. -func (m *CIWorkflowResultMutation) ClearStatus() { - m.status = nil - m.clearedFields[ciworkflowresult.FieldStatus] = struct{}{} -} - -// StatusCleared returns if the "status" field was cleared in this mutation. -func (m *CIWorkflowResultMutation) StatusCleared() bool { - _, ok := m.clearedFields[ciworkflowresult.FieldStatus] - return ok -} - // ResetStatus resets all changes to the "status" field. func (m *CIWorkflowResultMutation) ResetStatus() { m.status = nil - delete(m.clearedFields, ciworkflowresult.FieldStatus) } // SetStartTime sets the "start_time" field. @@ -670,2065 +620,4034 @@ func (m *CIWorkflowResultMutation) ResetEndTime() { delete(m.clearedFields, ciworkflowresult.FieldEndTime) } -// SetGitcommitID sets the "gitcommit" edge to the GitCommit entity by id. -func (m *CIWorkflowResultMutation) SetGitcommitID(id uuid.UUID) { - m.gitcommit = &id +// SetPythonVersion sets the "python_version" field. +func (m *CIWorkflowResultMutation) SetPythonVersion(s string) { + m.python_version = &s } -// ClearGitcommit clears the "gitcommit" edge to the GitCommit entity. -func (m *CIWorkflowResultMutation) ClearGitcommit() { - m.clearedgitcommit = true +// PythonVersion returns the value of the "python_version" field in the mutation. +func (m *CIWorkflowResultMutation) PythonVersion() (r string, exists bool) { + v := m.python_version + if v == nil { + return + } + return *v, true } -// GitcommitCleared reports if the "gitcommit" edge to the GitCommit entity was cleared. -func (m *CIWorkflowResultMutation) GitcommitCleared() bool { - return m.clearedgitcommit +// OldPythonVersion returns the old "python_version" field's value of the CIWorkflowResult entity. +// If the CIWorkflowResult object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *CIWorkflowResultMutation) OldPythonVersion(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPythonVersion is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPythonVersion requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPythonVersion: %w", err) + } + return oldValue.PythonVersion, nil } -// GitcommitID returns the "gitcommit" edge ID in the mutation. -func (m *CIWorkflowResultMutation) GitcommitID() (id uuid.UUID, exists bool) { - if m.gitcommit != nil { - return *m.gitcommit, true - } - return +// ClearPythonVersion clears the value of the "python_version" field. +func (m *CIWorkflowResultMutation) ClearPythonVersion() { + m.python_version = nil + m.clearedFields[ciworkflowresult.FieldPythonVersion] = struct{}{} } -// GitcommitIDs returns the "gitcommit" edge IDs in the mutation. -// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use -// GitcommitID instead. It exists only for internal usage by the builders. -func (m *CIWorkflowResultMutation) GitcommitIDs() (ids []uuid.UUID) { - if id := m.gitcommit; id != nil { - ids = append(ids, *id) +// PythonVersionCleared returns if the "python_version" field was cleared in this mutation. +func (m *CIWorkflowResultMutation) PythonVersionCleared() bool { + _, ok := m.clearedFields[ciworkflowresult.FieldPythonVersion] + return ok +} + +// ResetPythonVersion resets all changes to the "python_version" field. +func (m *CIWorkflowResultMutation) ResetPythonVersion() { + m.python_version = nil + delete(m.clearedFields, ciworkflowresult.FieldPythonVersion) +} + +// SetPytorchVersion sets the "pytorch_version" field. +func (m *CIWorkflowResultMutation) SetPytorchVersion(s string) { + m.pytorch_version = &s +} + +// PytorchVersion returns the value of the "pytorch_version" field in the mutation. +func (m *CIWorkflowResultMutation) PytorchVersion() (r string, exists bool) { + v := m.pytorch_version + if v == nil { + return } - return + return *v, true } -// ResetGitcommit resets all changes to the "gitcommit" edge. -func (m *CIWorkflowResultMutation) ResetGitcommit() { - m.gitcommit = nil - m.clearedgitcommit = false +// OldPytorchVersion returns the old "pytorch_version" field's value of the CIWorkflowResult entity. +// If the CIWorkflowResult object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *CIWorkflowResultMutation) OldPytorchVersion(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPytorchVersion is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPytorchVersion requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPytorchVersion: %w", err) + } + return oldValue.PytorchVersion, nil } -// SetStorageFileID sets the "storage_file" edge to the StorageFile entity by id. -func (m *CIWorkflowResultMutation) SetStorageFileID(id uuid.UUID) { - m.storage_file = &id +// ClearPytorchVersion clears the value of the "pytorch_version" field. +func (m *CIWorkflowResultMutation) ClearPytorchVersion() { + m.pytorch_version = nil + m.clearedFields[ciworkflowresult.FieldPytorchVersion] = struct{}{} } -// ClearStorageFile clears the "storage_file" edge to the StorageFile entity. -func (m *CIWorkflowResultMutation) ClearStorageFile() { - m.clearedstorage_file = true +// PytorchVersionCleared returns if the "pytorch_version" field was cleared in this mutation. +func (m *CIWorkflowResultMutation) PytorchVersionCleared() bool { + _, ok := m.clearedFields[ciworkflowresult.FieldPytorchVersion] + return ok } -// StorageFileCleared reports if the "storage_file" edge to the StorageFile entity was cleared. -func (m *CIWorkflowResultMutation) StorageFileCleared() bool { - return m.clearedstorage_file +// ResetPytorchVersion resets all changes to the "pytorch_version" field. +func (m *CIWorkflowResultMutation) ResetPytorchVersion() { + m.pytorch_version = nil + delete(m.clearedFields, ciworkflowresult.FieldPytorchVersion) } -// StorageFileID returns the "storage_file" edge ID in the mutation. -func (m *CIWorkflowResultMutation) StorageFileID() (id uuid.UUID, exists bool) { - if m.storage_file != nil { - return *m.storage_file, true - } - return +// SetCudaVersion sets the "cuda_version" field. +func (m *CIWorkflowResultMutation) SetCudaVersion(s string) { + m.cuda_version = &s } -// StorageFileIDs returns the "storage_file" edge IDs in the mutation. -// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use -// StorageFileID instead. It exists only for internal usage by the builders. -func (m *CIWorkflowResultMutation) StorageFileIDs() (ids []uuid.UUID) { - if id := m.storage_file; id != nil { - ids = append(ids, *id) +// CudaVersion returns the value of the "cuda_version" field in the mutation. +func (m *CIWorkflowResultMutation) CudaVersion() (r string, exists bool) { + v := m.cuda_version + if v == nil { + return } - return + return *v, true } -// ResetStorageFile resets all changes to the "storage_file" edge. -func (m *CIWorkflowResultMutation) ResetStorageFile() { - m.storage_file = nil - m.clearedstorage_file = false +// OldCudaVersion returns the old "cuda_version" field's value of the CIWorkflowResult entity. +// If the CIWorkflowResult object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *CIWorkflowResultMutation) OldCudaVersion(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCudaVersion is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCudaVersion requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCudaVersion: %w", err) + } + return oldValue.CudaVersion, nil } -// Where appends a list predicates to the CIWorkflowResultMutation builder. -func (m *CIWorkflowResultMutation) Where(ps ...predicate.CIWorkflowResult) { - m.predicates = append(m.predicates, ps...) +// ClearCudaVersion clears the value of the "cuda_version" field. +func (m *CIWorkflowResultMutation) ClearCudaVersion() { + m.cuda_version = nil + m.clearedFields[ciworkflowresult.FieldCudaVersion] = struct{}{} } -// WhereP appends storage-level predicates to the CIWorkflowResultMutation builder. Using this method, -// users can use type-assertion to append predicates that do not depend on any generated package. -func (m *CIWorkflowResultMutation) WhereP(ps ...func(*sql.Selector)) { - p := make([]predicate.CIWorkflowResult, len(ps)) - for i := range ps { - p[i] = ps[i] - } - m.Where(p...) +// CudaVersionCleared returns if the "cuda_version" field was cleared in this mutation. +func (m *CIWorkflowResultMutation) CudaVersionCleared() bool { + _, ok := m.clearedFields[ciworkflowresult.FieldCudaVersion] + return ok } -// Op returns the operation name. -func (m *CIWorkflowResultMutation) Op() Op { - return m.op +// ResetCudaVersion resets all changes to the "cuda_version" field. +func (m *CIWorkflowResultMutation) ResetCudaVersion() { + m.cuda_version = nil + delete(m.clearedFields, ciworkflowresult.FieldCudaVersion) } -// SetOp allows setting the mutation operation. -func (m *CIWorkflowResultMutation) SetOp(op Op) { - m.op = op +// SetComfyRunFlags sets the "comfy_run_flags" field. +func (m *CIWorkflowResultMutation) SetComfyRunFlags(s string) { + m.comfy_run_flags = &s } -// Type returns the node type of this mutation (CIWorkflowResult). -func (m *CIWorkflowResultMutation) Type() string { - return m.typ +// ComfyRunFlags returns the value of the "comfy_run_flags" field in the mutation. +func (m *CIWorkflowResultMutation) ComfyRunFlags() (r string, exists bool) { + v := m.comfy_run_flags + if v == nil { + return + } + return *v, true } -// Fields returns all fields that were changed during this mutation. Note that in -// order to get all numeric fields that were incremented/decremented, call -// AddedFields(). -func (m *CIWorkflowResultMutation) Fields() []string { - fields := make([]string, 0, 10) - if m.create_time != nil { - fields = append(fields, ciworkflowresult.FieldCreateTime) - } - if m.update_time != nil { - fields = append(fields, ciworkflowresult.FieldUpdateTime) - } - if m.operating_system != nil { - fields = append(fields, ciworkflowresult.FieldOperatingSystem) - } - if m.gpu_type != nil { - fields = append(fields, ciworkflowresult.FieldGpuType) +// OldComfyRunFlags returns the old "comfy_run_flags" field's value of the CIWorkflowResult entity. +// If the CIWorkflowResult object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *CIWorkflowResultMutation) OldComfyRunFlags(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldComfyRunFlags is only allowed on UpdateOne operations") } - if m.pytorch_version != nil { - fields = append(fields, ciworkflowresult.FieldPytorchVersion) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldComfyRunFlags requires an ID field in the mutation") } - if m.workflow_name != nil { - fields = append(fields, ciworkflowresult.FieldWorkflowName) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldComfyRunFlags: %w", err) } - if m.run_id != nil { - fields = append(fields, ciworkflowresult.FieldRunID) + return oldValue.ComfyRunFlags, nil +} + +// ClearComfyRunFlags clears the value of the "comfy_run_flags" field. +func (m *CIWorkflowResultMutation) ClearComfyRunFlags() { + m.comfy_run_flags = nil + m.clearedFields[ciworkflowresult.FieldComfyRunFlags] = struct{}{} +} + +// ComfyRunFlagsCleared returns if the "comfy_run_flags" field was cleared in this mutation. +func (m *CIWorkflowResultMutation) ComfyRunFlagsCleared() bool { + _, ok := m.clearedFields[ciworkflowresult.FieldComfyRunFlags] + return ok +} + +// ResetComfyRunFlags resets all changes to the "comfy_run_flags" field. +func (m *CIWorkflowResultMutation) ResetComfyRunFlags() { + m.comfy_run_flags = nil + delete(m.clearedFields, ciworkflowresult.FieldComfyRunFlags) +} + +// SetAvgVram sets the "avg_vram" field. +func (m *CIWorkflowResultMutation) SetAvgVram(i int) { + m.avg_vram = &i + m.addavg_vram = nil +} + +// AvgVram returns the value of the "avg_vram" field in the mutation. +func (m *CIWorkflowResultMutation) AvgVram() (r int, exists bool) { + v := m.avg_vram + if v == nil { + return } - if m.status != nil { - fields = append(fields, ciworkflowresult.FieldStatus) + return *v, true +} + +// OldAvgVram returns the old "avg_vram" field's value of the CIWorkflowResult entity. +// If the CIWorkflowResult object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *CIWorkflowResultMutation) OldAvgVram(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAvgVram is only allowed on UpdateOne operations") } - if m.start_time != nil { - fields = append(fields, ciworkflowresult.FieldStartTime) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAvgVram requires an ID field in the mutation") } - if m.end_time != nil { - fields = append(fields, ciworkflowresult.FieldEndTime) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAvgVram: %w", err) } - return fields + return oldValue.AvgVram, nil } -// Field returns the value of a field with the given name. The second boolean -// return value indicates that this field was not set, or was not defined in the -// schema. -func (m *CIWorkflowResultMutation) Field(name string) (ent.Value, bool) { - switch name { - case ciworkflowresult.FieldCreateTime: - return m.CreateTime() - case ciworkflowresult.FieldUpdateTime: - return m.UpdateTime() - case ciworkflowresult.FieldOperatingSystem: - return m.OperatingSystem() - case ciworkflowresult.FieldGpuType: - return m.GpuType() - case ciworkflowresult.FieldPytorchVersion: - return m.PytorchVersion() - case ciworkflowresult.FieldWorkflowName: - return m.WorkflowName() - case ciworkflowresult.FieldRunID: - return m.RunID() - case ciworkflowresult.FieldStatus: - return m.Status() - case ciworkflowresult.FieldStartTime: - return m.StartTime() - case ciworkflowresult.FieldEndTime: - return m.EndTime() +// AddAvgVram adds i to the "avg_vram" field. +func (m *CIWorkflowResultMutation) AddAvgVram(i int) { + if m.addavg_vram != nil { + *m.addavg_vram += i + } else { + m.addavg_vram = &i } - return nil, false } -// OldField returns the old value of the field from the database. An error is -// returned if the mutation operation is not UpdateOne, or the query to the -// database failed. -func (m *CIWorkflowResultMutation) OldField(ctx context.Context, name string) (ent.Value, error) { - switch name { - case ciworkflowresult.FieldCreateTime: - return m.OldCreateTime(ctx) - case ciworkflowresult.FieldUpdateTime: - return m.OldUpdateTime(ctx) - case ciworkflowresult.FieldOperatingSystem: - return m.OldOperatingSystem(ctx) - case ciworkflowresult.FieldGpuType: - return m.OldGpuType(ctx) - case ciworkflowresult.FieldPytorchVersion: - return m.OldPytorchVersion(ctx) - case ciworkflowresult.FieldWorkflowName: - return m.OldWorkflowName(ctx) - case ciworkflowresult.FieldRunID: - return m.OldRunID(ctx) - case ciworkflowresult.FieldStatus: - return m.OldStatus(ctx) - case ciworkflowresult.FieldStartTime: - return m.OldStartTime(ctx) - case ciworkflowresult.FieldEndTime: - return m.OldEndTime(ctx) +// AddedAvgVram returns the value that was added to the "avg_vram" field in this mutation. +func (m *CIWorkflowResultMutation) AddedAvgVram() (r int, exists bool) { + v := m.addavg_vram + if v == nil { + return } - return nil, fmt.Errorf("unknown CIWorkflowResult field %s", name) + return *v, true } -// SetField sets the value of a field with the given name. It returns an error if -// the field is not defined in the schema, or if the type mismatched the field -// type. -func (m *CIWorkflowResultMutation) SetField(name string, value ent.Value) error { - switch name { - case ciworkflowresult.FieldCreateTime: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetCreateTime(v) - return nil - case ciworkflowresult.FieldUpdateTime: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetUpdateTime(v) - return nil - case ciworkflowresult.FieldOperatingSystem: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetOperatingSystem(v) - return nil - case ciworkflowresult.FieldGpuType: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetGpuType(v) - return nil - case ciworkflowresult.FieldPytorchVersion: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetPytorchVersion(v) - return nil - case ciworkflowresult.FieldWorkflowName: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetWorkflowName(v) - return nil - case ciworkflowresult.FieldRunID: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetRunID(v) - return nil - case ciworkflowresult.FieldStatus: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetStatus(v) - return nil - case ciworkflowresult.FieldStartTime: - v, ok := value.(int64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetStartTime(v) - return nil - case ciworkflowresult.FieldEndTime: - v, ok := value.(int64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetEndTime(v) - return nil - } - return fmt.Errorf("unknown CIWorkflowResult field %s", name) +// ClearAvgVram clears the value of the "avg_vram" field. +func (m *CIWorkflowResultMutation) ClearAvgVram() { + m.avg_vram = nil + m.addavg_vram = nil + m.clearedFields[ciworkflowresult.FieldAvgVram] = struct{}{} } -// AddedFields returns all numeric fields that were incremented/decremented during -// this mutation. -func (m *CIWorkflowResultMutation) AddedFields() []string { - var fields []string - if m.addstart_time != nil { - fields = append(fields, ciworkflowresult.FieldStartTime) - } - if m.addend_time != nil { - fields = append(fields, ciworkflowresult.FieldEndTime) - } - return fields +// AvgVramCleared returns if the "avg_vram" field was cleared in this mutation. +func (m *CIWorkflowResultMutation) AvgVramCleared() bool { + _, ok := m.clearedFields[ciworkflowresult.FieldAvgVram] + return ok } -// AddedField returns the numeric value that was incremented/decremented on a field -// with the given name. The second boolean return value indicates that this field -// was not set, or was not defined in the schema. -func (m *CIWorkflowResultMutation) AddedField(name string) (ent.Value, bool) { - switch name { - case ciworkflowresult.FieldStartTime: - return m.AddedStartTime() - case ciworkflowresult.FieldEndTime: - return m.AddedEndTime() - } - return nil, false +// ResetAvgVram resets all changes to the "avg_vram" field. +func (m *CIWorkflowResultMutation) ResetAvgVram() { + m.avg_vram = nil + m.addavg_vram = nil + delete(m.clearedFields, ciworkflowresult.FieldAvgVram) } -// AddField adds the value to the field with the given name. It returns an error if -// the field is not defined in the schema, or if the type mismatched the field -// type. -func (m *CIWorkflowResultMutation) AddField(name string, value ent.Value) error { - switch name { - case ciworkflowresult.FieldStartTime: - v, ok := value.(int64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddStartTime(v) - return nil - case ciworkflowresult.FieldEndTime: - v, ok := value.(int64) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.AddEndTime(v) - return nil - } - return fmt.Errorf("unknown CIWorkflowResult numeric field %s", name) +// SetPeakVram sets the "peak_vram" field. +func (m *CIWorkflowResultMutation) SetPeakVram(i int) { + m.peak_vram = &i + m.addpeak_vram = nil } -// ClearedFields returns all nullable fields that were cleared during this -// mutation. -func (m *CIWorkflowResultMutation) ClearedFields() []string { - var fields []string - if m.FieldCleared(ciworkflowresult.FieldGpuType) { - fields = append(fields, ciworkflowresult.FieldGpuType) - } - if m.FieldCleared(ciworkflowresult.FieldPytorchVersion) { - fields = append(fields, ciworkflowresult.FieldPytorchVersion) - } - if m.FieldCleared(ciworkflowresult.FieldWorkflowName) { - fields = append(fields, ciworkflowresult.FieldWorkflowName) +// PeakVram returns the value of the "peak_vram" field in the mutation. +func (m *CIWorkflowResultMutation) PeakVram() (r int, exists bool) { + v := m.peak_vram + if v == nil { + return } - if m.FieldCleared(ciworkflowresult.FieldRunID) { - fields = append(fields, ciworkflowresult.FieldRunID) + return *v, true +} + +// OldPeakVram returns the old "peak_vram" field's value of the CIWorkflowResult entity. +// If the CIWorkflowResult object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *CIWorkflowResultMutation) OldPeakVram(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPeakVram is only allowed on UpdateOne operations") } - if m.FieldCleared(ciworkflowresult.FieldStatus) { - fields = append(fields, ciworkflowresult.FieldStatus) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPeakVram requires an ID field in the mutation") } - if m.FieldCleared(ciworkflowresult.FieldStartTime) { - fields = append(fields, ciworkflowresult.FieldStartTime) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPeakVram: %w", err) } - if m.FieldCleared(ciworkflowresult.FieldEndTime) { - fields = append(fields, ciworkflowresult.FieldEndTime) + return oldValue.PeakVram, nil +} + +// AddPeakVram adds i to the "peak_vram" field. +func (m *CIWorkflowResultMutation) AddPeakVram(i int) { + if m.addpeak_vram != nil { + *m.addpeak_vram += i + } else { + m.addpeak_vram = &i } - return fields } -// FieldCleared returns a boolean indicating if a field with the given name was -// cleared in this mutation. -func (m *CIWorkflowResultMutation) FieldCleared(name string) bool { - _, ok := m.clearedFields[name] - return ok -} - -// ClearField clears the value of the field with the given name. It returns an -// error if the field is not defined in the schema. -func (m *CIWorkflowResultMutation) ClearField(name string) error { - switch name { - case ciworkflowresult.FieldGpuType: - m.ClearGpuType() - return nil - case ciworkflowresult.FieldPytorchVersion: - m.ClearPytorchVersion() - return nil - case ciworkflowresult.FieldWorkflowName: - m.ClearWorkflowName() - return nil - case ciworkflowresult.FieldRunID: - m.ClearRunID() - return nil - case ciworkflowresult.FieldStatus: - m.ClearStatus() - return nil - case ciworkflowresult.FieldStartTime: - m.ClearStartTime() - return nil - case ciworkflowresult.FieldEndTime: - m.ClearEndTime() - return nil +// AddedPeakVram returns the value that was added to the "peak_vram" field in this mutation. +func (m *CIWorkflowResultMutation) AddedPeakVram() (r int, exists bool) { + v := m.addpeak_vram + if v == nil { + return } - return fmt.Errorf("unknown CIWorkflowResult nullable field %s", name) + return *v, true } -// ResetField resets all changes in the mutation for the field with the given name. -// It returns an error if the field is not defined in the schema. -func (m *CIWorkflowResultMutation) ResetField(name string) error { - switch name { - case ciworkflowresult.FieldCreateTime: - m.ResetCreateTime() - return nil - case ciworkflowresult.FieldUpdateTime: - m.ResetUpdateTime() - return nil - case ciworkflowresult.FieldOperatingSystem: - m.ResetOperatingSystem() - return nil - case ciworkflowresult.FieldGpuType: - m.ResetGpuType() - return nil - case ciworkflowresult.FieldPytorchVersion: - m.ResetPytorchVersion() - return nil - case ciworkflowresult.FieldWorkflowName: - m.ResetWorkflowName() - return nil - case ciworkflowresult.FieldRunID: - m.ResetRunID() - return nil - case ciworkflowresult.FieldStatus: - m.ResetStatus() - return nil - case ciworkflowresult.FieldStartTime: - m.ResetStartTime() - return nil - case ciworkflowresult.FieldEndTime: - m.ResetEndTime() - return nil - } - return fmt.Errorf("unknown CIWorkflowResult field %s", name) +// ClearPeakVram clears the value of the "peak_vram" field. +func (m *CIWorkflowResultMutation) ClearPeakVram() { + m.peak_vram = nil + m.addpeak_vram = nil + m.clearedFields[ciworkflowresult.FieldPeakVram] = struct{}{} } -// AddedEdges returns all edge names that were set/added in this mutation. -func (m *CIWorkflowResultMutation) AddedEdges() []string { - edges := make([]string, 0, 2) - if m.gitcommit != nil { - edges = append(edges, ciworkflowresult.EdgeGitcommit) - } - if m.storage_file != nil { - edges = append(edges, ciworkflowresult.EdgeStorageFile) - } - return edges +// PeakVramCleared returns if the "peak_vram" field was cleared in this mutation. +func (m *CIWorkflowResultMutation) PeakVramCleared() bool { + _, ok := m.clearedFields[ciworkflowresult.FieldPeakVram] + return ok } -// AddedIDs returns all IDs (to other nodes) that were added for the given edge -// name in this mutation. -func (m *CIWorkflowResultMutation) AddedIDs(name string) []ent.Value { - switch name { - case ciworkflowresult.EdgeGitcommit: - if id := m.gitcommit; id != nil { - return []ent.Value{*id} - } - case ciworkflowresult.EdgeStorageFile: - if id := m.storage_file; id != nil { - return []ent.Value{*id} - } - } - return nil +// ResetPeakVram resets all changes to the "peak_vram" field. +func (m *CIWorkflowResultMutation) ResetPeakVram() { + m.peak_vram = nil + m.addpeak_vram = nil + delete(m.clearedFields, ciworkflowresult.FieldPeakVram) } -// RemovedEdges returns all edge names that were removed in this mutation. -func (m *CIWorkflowResultMutation) RemovedEdges() []string { - edges := make([]string, 0, 2) - return edges +// SetJobTriggerUser sets the "job_trigger_user" field. +func (m *CIWorkflowResultMutation) SetJobTriggerUser(s string) { + m.job_trigger_user = &s } -// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with -// the given name in this mutation. -func (m *CIWorkflowResultMutation) RemovedIDs(name string) []ent.Value { - return nil +// JobTriggerUser returns the value of the "job_trigger_user" field in the mutation. +func (m *CIWorkflowResultMutation) JobTriggerUser() (r string, exists bool) { + v := m.job_trigger_user + if v == nil { + return + } + return *v, true } -// ClearedEdges returns all edge names that were cleared in this mutation. -func (m *CIWorkflowResultMutation) ClearedEdges() []string { - edges := make([]string, 0, 2) - if m.clearedgitcommit { - edges = append(edges, ciworkflowresult.EdgeGitcommit) +// OldJobTriggerUser returns the old "job_trigger_user" field's value of the CIWorkflowResult entity. +// If the CIWorkflowResult object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *CIWorkflowResultMutation) OldJobTriggerUser(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldJobTriggerUser is only allowed on UpdateOne operations") } - if m.clearedstorage_file { - edges = append(edges, ciworkflowresult.EdgeStorageFile) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldJobTriggerUser requires an ID field in the mutation") } - return edges -} - -// EdgeCleared returns a boolean which indicates if the edge with the given name -// was cleared in this mutation. -func (m *CIWorkflowResultMutation) EdgeCleared(name string) bool { - switch name { - case ciworkflowresult.EdgeGitcommit: - return m.clearedgitcommit - case ciworkflowresult.EdgeStorageFile: - return m.clearedstorage_file + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldJobTriggerUser: %w", err) } - return false + return oldValue.JobTriggerUser, nil } -// ClearEdge clears the value of the edge with the given name. It returns an error -// if that edge is not defined in the schema. -func (m *CIWorkflowResultMutation) ClearEdge(name string) error { - switch name { - case ciworkflowresult.EdgeGitcommit: - m.ClearGitcommit() - return nil - case ciworkflowresult.EdgeStorageFile: - m.ClearStorageFile() - return nil - } - return fmt.Errorf("unknown CIWorkflowResult unique edge %s", name) +// ClearJobTriggerUser clears the value of the "job_trigger_user" field. +func (m *CIWorkflowResultMutation) ClearJobTriggerUser() { + m.job_trigger_user = nil + m.clearedFields[ciworkflowresult.FieldJobTriggerUser] = struct{}{} } -// ResetEdge resets all changes to the edge with the given name in this mutation. -// It returns an error if the edge is not defined in the schema. -func (m *CIWorkflowResultMutation) ResetEdge(name string) error { - switch name { - case ciworkflowresult.EdgeGitcommit: - m.ResetGitcommit() - return nil - case ciworkflowresult.EdgeStorageFile: - m.ResetStorageFile() - return nil - } - return fmt.Errorf("unknown CIWorkflowResult edge %s", name) +// JobTriggerUserCleared returns if the "job_trigger_user" field was cleared in this mutation. +func (m *CIWorkflowResultMutation) JobTriggerUserCleared() bool { + _, ok := m.clearedFields[ciworkflowresult.FieldJobTriggerUser] + return ok } -// GitCommitMutation represents an operation that mutates the GitCommit nodes in the graph. -type GitCommitMutation struct { - config - op Op - typ string - id *uuid.UUID - create_time *time.Time - update_time *time.Time - commit_hash *string - branch_name *string - repo_name *string - commit_message *string - commit_timestamp *time.Time - author *string - timestamp *time.Time - clearedFields map[string]struct{} - results map[uuid.UUID]struct{} - removedresults map[uuid.UUID]struct{} - clearedresults bool - done bool - oldValue func(context.Context) (*GitCommit, error) - predicates []predicate.GitCommit +// ResetJobTriggerUser resets all changes to the "job_trigger_user" field. +func (m *CIWorkflowResultMutation) ResetJobTriggerUser() { + m.job_trigger_user = nil + delete(m.clearedFields, ciworkflowresult.FieldJobTriggerUser) } -var _ ent.Mutation = (*GitCommitMutation)(nil) +// SetMetadata sets the "metadata" field. +func (m *CIWorkflowResultMutation) SetMetadata(value map[string]interface{}) { + m.metadata = &value +} -// gitcommitOption allows management of the mutation configuration using functional options. -type gitcommitOption func(*GitCommitMutation) +// Metadata returns the value of the "metadata" field in the mutation. +func (m *CIWorkflowResultMutation) Metadata() (r map[string]interface{}, exists bool) { + v := m.metadata + if v == nil { + return + } + return *v, true +} -// newGitCommitMutation creates new mutation for the GitCommit entity. -func newGitCommitMutation(c config, op Op, opts ...gitcommitOption) *GitCommitMutation { - m := &GitCommitMutation{ - config: c, - op: op, - typ: TypeGitCommit, - clearedFields: make(map[string]struct{}), +// OldMetadata returns the old "metadata" field's value of the CIWorkflowResult entity. +// If the CIWorkflowResult object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *CIWorkflowResultMutation) OldMetadata(ctx context.Context) (v map[string]interface{}, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldMetadata is only allowed on UpdateOne operations") } - for _, opt := range opts { - opt(m) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldMetadata requires an ID field in the mutation") } - return m + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldMetadata: %w", err) + } + return oldValue.Metadata, nil } -// withGitCommitID sets the ID field of the mutation. -func withGitCommitID(id uuid.UUID) gitcommitOption { - return func(m *GitCommitMutation) { - var ( - err error - once sync.Once - value *GitCommit - ) - m.oldValue = func(ctx context.Context) (*GitCommit, error) { - once.Do(func() { - if m.done { - err = errors.New("querying old values post mutation is not allowed") - } else { - value, err = m.Client().GitCommit.Get(ctx, id) - } - }) - return value, err - } - m.id = &id - } +// ClearMetadata clears the value of the "metadata" field. +func (m *CIWorkflowResultMutation) ClearMetadata() { + m.metadata = nil + m.clearedFields[ciworkflowresult.FieldMetadata] = struct{}{} } -// withGitCommit sets the old GitCommit of the mutation. -func withGitCommit(node *GitCommit) gitcommitOption { - return func(m *GitCommitMutation) { - m.oldValue = func(context.Context) (*GitCommit, error) { - return node, nil - } - m.id = &node.ID - } +// MetadataCleared returns if the "metadata" field was cleared in this mutation. +func (m *CIWorkflowResultMutation) MetadataCleared() bool { + _, ok := m.clearedFields[ciworkflowresult.FieldMetadata] + return ok } -// Client returns a new `ent.Client` from the mutation. If the mutation was -// executed in a transaction (ent.Tx), a transactional client is returned. -func (m GitCommitMutation) Client() *Client { - client := &Client{config: m.config} - client.init() - return client +// ResetMetadata resets all changes to the "metadata" field. +func (m *CIWorkflowResultMutation) ResetMetadata() { + m.metadata = nil + delete(m.clearedFields, ciworkflowresult.FieldMetadata) } -// Tx returns an `ent.Tx` for mutations that were executed in transactions; -// it returns an error otherwise. -func (m GitCommitMutation) Tx() (*Tx, error) { - if _, ok := m.driver.(*txDriver); !ok { - return nil, errors.New("ent: mutation is not running in a transaction") - } - tx := &Tx{config: m.config} - tx.init() - return tx, nil +// SetGitcommitID sets the "gitcommit" edge to the GitCommit entity by id. +func (m *CIWorkflowResultMutation) SetGitcommitID(id uuid.UUID) { + m.gitcommit = &id } -// SetID sets the value of the id field. Note that this -// operation is only accepted on creation of GitCommit entities. -func (m *GitCommitMutation) SetID(id uuid.UUID) { - m.id = &id +// ClearGitcommit clears the "gitcommit" edge to the GitCommit entity. +func (m *CIWorkflowResultMutation) ClearGitcommit() { + m.clearedgitcommit = true } -// ID returns the ID value in the mutation. Note that the ID is only available -// if it was provided to the builder or after it was returned from the database. -func (m *GitCommitMutation) ID() (id uuid.UUID, exists bool) { - if m.id == nil { - return - } - return *m.id, true +// GitcommitCleared reports if the "gitcommit" edge to the GitCommit entity was cleared. +func (m *CIWorkflowResultMutation) GitcommitCleared() bool { + return m.clearedgitcommit } -// IDs queries the database and returns the entity ids that match the mutation's predicate. -// That means, if the mutation is applied within a transaction with an isolation level such -// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated -// or updated by the mutation. -func (m *GitCommitMutation) IDs(ctx context.Context) ([]uuid.UUID, error) { - switch { - case m.op.Is(OpUpdateOne | OpDeleteOne): - id, exists := m.ID() - if exists { - return []uuid.UUID{id}, nil - } - fallthrough - case m.op.Is(OpUpdate | OpDelete): - return m.Client().GitCommit.Query().Where(m.predicates...).IDs(ctx) - default: - return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) +// GitcommitID returns the "gitcommit" edge ID in the mutation. +func (m *CIWorkflowResultMutation) GitcommitID() (id uuid.UUID, exists bool) { + if m.gitcommit != nil { + return *m.gitcommit, true } + return } -// SetCreateTime sets the "create_time" field. -func (m *GitCommitMutation) SetCreateTime(t time.Time) { - m.create_time = &t +// GitcommitIDs returns the "gitcommit" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// GitcommitID instead. It exists only for internal usage by the builders. +func (m *CIWorkflowResultMutation) GitcommitIDs() (ids []uuid.UUID) { + if id := m.gitcommit; id != nil { + ids = append(ids, *id) + } + return } -// CreateTime returns the value of the "create_time" field in the mutation. -func (m *GitCommitMutation) CreateTime() (r time.Time, exists bool) { - v := m.create_time - if v == nil { - return - } - return *v, true +// ResetGitcommit resets all changes to the "gitcommit" edge. +func (m *CIWorkflowResultMutation) ResetGitcommit() { + m.gitcommit = nil + m.clearedgitcommit = false } -// OldCreateTime returns the old "create_time" field's value of the GitCommit entity. -// If the GitCommit object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GitCommitMutation) OldCreateTime(ctx context.Context) (v time.Time, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldCreateTime is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldCreateTime requires an ID field in the mutation") +// AddStorageFileIDs adds the "storage_file" edge to the StorageFile entity by ids. +func (m *CIWorkflowResultMutation) AddStorageFileIDs(ids ...uuid.UUID) { + if m.storage_file == nil { + m.storage_file = make(map[uuid.UUID]struct{}) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldCreateTime: %w", err) + for i := range ids { + m.storage_file[ids[i]] = struct{}{} } - return oldValue.CreateTime, nil -} - -// ResetCreateTime resets all changes to the "create_time" field. -func (m *GitCommitMutation) ResetCreateTime() { - m.create_time = nil } -// SetUpdateTime sets the "update_time" field. -func (m *GitCommitMutation) SetUpdateTime(t time.Time) { - m.update_time = &t +// ClearStorageFile clears the "storage_file" edge to the StorageFile entity. +func (m *CIWorkflowResultMutation) ClearStorageFile() { + m.clearedstorage_file = true } -// UpdateTime returns the value of the "update_time" field in the mutation. -func (m *GitCommitMutation) UpdateTime() (r time.Time, exists bool) { - v := m.update_time - if v == nil { - return - } - return *v, true +// StorageFileCleared reports if the "storage_file" edge to the StorageFile entity was cleared. +func (m *CIWorkflowResultMutation) StorageFileCleared() bool { + return m.clearedstorage_file } -// OldUpdateTime returns the old "update_time" field's value of the GitCommit entity. -// If the GitCommit object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GitCommitMutation) OldUpdateTime(ctx context.Context) (v time.Time, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldUpdateTime is only allowed on UpdateOne operations") +// RemoveStorageFileIDs removes the "storage_file" edge to the StorageFile entity by IDs. +func (m *CIWorkflowResultMutation) RemoveStorageFileIDs(ids ...uuid.UUID) { + if m.removedstorage_file == nil { + m.removedstorage_file = make(map[uuid.UUID]struct{}) } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldUpdateTime requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldUpdateTime: %w", err) + for i := range ids { + delete(m.storage_file, ids[i]) + m.removedstorage_file[ids[i]] = struct{}{} } - return oldValue.UpdateTime, nil -} - -// ResetUpdateTime resets all changes to the "update_time" field. -func (m *GitCommitMutation) ResetUpdateTime() { - m.update_time = nil -} - -// SetCommitHash sets the "commit_hash" field. -func (m *GitCommitMutation) SetCommitHash(s string) { - m.commit_hash = &s } -// CommitHash returns the value of the "commit_hash" field in the mutation. -func (m *GitCommitMutation) CommitHash() (r string, exists bool) { - v := m.commit_hash - if v == nil { - return +// RemovedStorageFile returns the removed IDs of the "storage_file" edge to the StorageFile entity. +func (m *CIWorkflowResultMutation) RemovedStorageFileIDs() (ids []uuid.UUID) { + for id := range m.removedstorage_file { + ids = append(ids, id) } - return *v, true + return } -// OldCommitHash returns the old "commit_hash" field's value of the GitCommit entity. -// If the GitCommit object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GitCommitMutation) OldCommitHash(ctx context.Context) (v string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldCommitHash is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldCommitHash requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldCommitHash: %w", err) +// StorageFileIDs returns the "storage_file" edge IDs in the mutation. +func (m *CIWorkflowResultMutation) StorageFileIDs() (ids []uuid.UUID) { + for id := range m.storage_file { + ids = append(ids, id) } - return oldValue.CommitHash, nil + return } -// ResetCommitHash resets all changes to the "commit_hash" field. -func (m *GitCommitMutation) ResetCommitHash() { - m.commit_hash = nil +// ResetStorageFile resets all changes to the "storage_file" edge. +func (m *CIWorkflowResultMutation) ResetStorageFile() { + m.storage_file = nil + m.clearedstorage_file = false + m.removedstorage_file = nil } -// SetBranchName sets the "branch_name" field. -func (m *GitCommitMutation) SetBranchName(s string) { - m.branch_name = &s +// Where appends a list predicates to the CIWorkflowResultMutation builder. +func (m *CIWorkflowResultMutation) Where(ps ...predicate.CIWorkflowResult) { + m.predicates = append(m.predicates, ps...) } -// BranchName returns the value of the "branch_name" field in the mutation. -func (m *GitCommitMutation) BranchName() (r string, exists bool) { - v := m.branch_name - if v == nil { - return +// WhereP appends storage-level predicates to the CIWorkflowResultMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *CIWorkflowResultMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.CIWorkflowResult, len(ps)) + for i := range ps { + p[i] = ps[i] } - return *v, true + m.Where(p...) } -// OldBranchName returns the old "branch_name" field's value of the GitCommit entity. -// If the GitCommit object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GitCommitMutation) OldBranchName(ctx context.Context) (v string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldBranchName is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldBranchName requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldBranchName: %w", err) - } - return oldValue.BranchName, nil +// Op returns the operation name. +func (m *CIWorkflowResultMutation) Op() Op { + return m.op } -// ResetBranchName resets all changes to the "branch_name" field. -func (m *GitCommitMutation) ResetBranchName() { - m.branch_name = nil +// SetOp allows setting the mutation operation. +func (m *CIWorkflowResultMutation) SetOp(op Op) { + m.op = op } -// SetRepoName sets the "repo_name" field. -func (m *GitCommitMutation) SetRepoName(s string) { - m.repo_name = &s +// Type returns the node type of this mutation (CIWorkflowResult). +func (m *CIWorkflowResultMutation) Type() string { + return m.typ } -// RepoName returns the value of the "repo_name" field in the mutation. -func (m *GitCommitMutation) RepoName() (r string, exists bool) { - v := m.repo_name - if v == nil { - return +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *CIWorkflowResultMutation) Fields() []string { + fields := make([]string, 0, 17) + if m.create_time != nil { + fields = append(fields, ciworkflowresult.FieldCreateTime) } - return *v, true -} - -// OldRepoName returns the old "repo_name" field's value of the GitCommit entity. -// If the GitCommit object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GitCommitMutation) OldRepoName(ctx context.Context) (v string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldRepoName is only allowed on UpdateOne operations") + if m.update_time != nil { + fields = append(fields, ciworkflowresult.FieldUpdateTime) } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldRepoName requires an ID field in the mutation") + if m.operating_system != nil { + fields = append(fields, ciworkflowresult.FieldOperatingSystem) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldRepoName: %w", err) + if m.workflow_name != nil { + fields = append(fields, ciworkflowresult.FieldWorkflowName) } - return oldValue.RepoName, nil -} - -// ResetRepoName resets all changes to the "repo_name" field. -func (m *GitCommitMutation) ResetRepoName() { - m.repo_name = nil -} - -// SetCommitMessage sets the "commit_message" field. -func (m *GitCommitMutation) SetCommitMessage(s string) { - m.commit_message = &s -} - -// CommitMessage returns the value of the "commit_message" field in the mutation. -func (m *GitCommitMutation) CommitMessage() (r string, exists bool) { - v := m.commit_message - if v == nil { - return + if m.run_id != nil { + fields = append(fields, ciworkflowresult.FieldRunID) } - return *v, true -} - -// OldCommitMessage returns the old "commit_message" field's value of the GitCommit entity. -// If the GitCommit object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GitCommitMutation) OldCommitMessage(ctx context.Context) (v string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldCommitMessage is only allowed on UpdateOne operations") + if m.job_id != nil { + fields = append(fields, ciworkflowresult.FieldJobID) } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldCommitMessage requires an ID field in the mutation") + if m.status != nil { + fields = append(fields, ciworkflowresult.FieldStatus) + } + if m.start_time != nil { + fields = append(fields, ciworkflowresult.FieldStartTime) + } + if m.end_time != nil { + fields = append(fields, ciworkflowresult.FieldEndTime) + } + if m.python_version != nil { + fields = append(fields, ciworkflowresult.FieldPythonVersion) + } + if m.pytorch_version != nil { + fields = append(fields, ciworkflowresult.FieldPytorchVersion) + } + if m.cuda_version != nil { + fields = append(fields, ciworkflowresult.FieldCudaVersion) + } + if m.comfy_run_flags != nil { + fields = append(fields, ciworkflowresult.FieldComfyRunFlags) + } + if m.avg_vram != nil { + fields = append(fields, ciworkflowresult.FieldAvgVram) + } + if m.peak_vram != nil { + fields = append(fields, ciworkflowresult.FieldPeakVram) + } + if m.job_trigger_user != nil { + fields = append(fields, ciworkflowresult.FieldJobTriggerUser) + } + if m.metadata != nil { + fields = append(fields, ciworkflowresult.FieldMetadata) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *CIWorkflowResultMutation) Field(name string) (ent.Value, bool) { + switch name { + case ciworkflowresult.FieldCreateTime: + return m.CreateTime() + case ciworkflowresult.FieldUpdateTime: + return m.UpdateTime() + case ciworkflowresult.FieldOperatingSystem: + return m.OperatingSystem() + case ciworkflowresult.FieldWorkflowName: + return m.WorkflowName() + case ciworkflowresult.FieldRunID: + return m.RunID() + case ciworkflowresult.FieldJobID: + return m.JobID() + case ciworkflowresult.FieldStatus: + return m.Status() + case ciworkflowresult.FieldStartTime: + return m.StartTime() + case ciworkflowresult.FieldEndTime: + return m.EndTime() + case ciworkflowresult.FieldPythonVersion: + return m.PythonVersion() + case ciworkflowresult.FieldPytorchVersion: + return m.PytorchVersion() + case ciworkflowresult.FieldCudaVersion: + return m.CudaVersion() + case ciworkflowresult.FieldComfyRunFlags: + return m.ComfyRunFlags() + case ciworkflowresult.FieldAvgVram: + return m.AvgVram() + case ciworkflowresult.FieldPeakVram: + return m.PeakVram() + case ciworkflowresult.FieldJobTriggerUser: + return m.JobTriggerUser() + case ciworkflowresult.FieldMetadata: + return m.Metadata() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *CIWorkflowResultMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case ciworkflowresult.FieldCreateTime: + return m.OldCreateTime(ctx) + case ciworkflowresult.FieldUpdateTime: + return m.OldUpdateTime(ctx) + case ciworkflowresult.FieldOperatingSystem: + return m.OldOperatingSystem(ctx) + case ciworkflowresult.FieldWorkflowName: + return m.OldWorkflowName(ctx) + case ciworkflowresult.FieldRunID: + return m.OldRunID(ctx) + case ciworkflowresult.FieldJobID: + return m.OldJobID(ctx) + case ciworkflowresult.FieldStatus: + return m.OldStatus(ctx) + case ciworkflowresult.FieldStartTime: + return m.OldStartTime(ctx) + case ciworkflowresult.FieldEndTime: + return m.OldEndTime(ctx) + case ciworkflowresult.FieldPythonVersion: + return m.OldPythonVersion(ctx) + case ciworkflowresult.FieldPytorchVersion: + return m.OldPytorchVersion(ctx) + case ciworkflowresult.FieldCudaVersion: + return m.OldCudaVersion(ctx) + case ciworkflowresult.FieldComfyRunFlags: + return m.OldComfyRunFlags(ctx) + case ciworkflowresult.FieldAvgVram: + return m.OldAvgVram(ctx) + case ciworkflowresult.FieldPeakVram: + return m.OldPeakVram(ctx) + case ciworkflowresult.FieldJobTriggerUser: + return m.OldJobTriggerUser(ctx) + case ciworkflowresult.FieldMetadata: + return m.OldMetadata(ctx) + } + return nil, fmt.Errorf("unknown CIWorkflowResult field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *CIWorkflowResultMutation) SetField(name string, value ent.Value) error { + switch name { + case ciworkflowresult.FieldCreateTime: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreateTime(v) + return nil + case ciworkflowresult.FieldUpdateTime: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdateTime(v) + return nil + case ciworkflowresult.FieldOperatingSystem: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetOperatingSystem(v) + return nil + case ciworkflowresult.FieldWorkflowName: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetWorkflowName(v) + return nil + case ciworkflowresult.FieldRunID: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRunID(v) + return nil + case ciworkflowresult.FieldJobID: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetJobID(v) + return nil + case ciworkflowresult.FieldStatus: + v, ok := value.(schema.WorkflowRunStatusType) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetStatus(v) + return nil + case ciworkflowresult.FieldStartTime: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetStartTime(v) + return nil + case ciworkflowresult.FieldEndTime: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetEndTime(v) + return nil + case ciworkflowresult.FieldPythonVersion: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPythonVersion(v) + return nil + case ciworkflowresult.FieldPytorchVersion: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPytorchVersion(v) + return nil + case ciworkflowresult.FieldCudaVersion: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCudaVersion(v) + return nil + case ciworkflowresult.FieldComfyRunFlags: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetComfyRunFlags(v) + return nil + case ciworkflowresult.FieldAvgVram: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAvgVram(v) + return nil + case ciworkflowresult.FieldPeakVram: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPeakVram(v) + return nil + case ciworkflowresult.FieldJobTriggerUser: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetJobTriggerUser(v) + return nil + case ciworkflowresult.FieldMetadata: + v, ok := value.(map[string]interface{}) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMetadata(v) + return nil + } + return fmt.Errorf("unknown CIWorkflowResult field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *CIWorkflowResultMutation) AddedFields() []string { + var fields []string + if m.addstart_time != nil { + fields = append(fields, ciworkflowresult.FieldStartTime) + } + if m.addend_time != nil { + fields = append(fields, ciworkflowresult.FieldEndTime) + } + if m.addavg_vram != nil { + fields = append(fields, ciworkflowresult.FieldAvgVram) + } + if m.addpeak_vram != nil { + fields = append(fields, ciworkflowresult.FieldPeakVram) + } + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *CIWorkflowResultMutation) AddedField(name string) (ent.Value, bool) { + switch name { + case ciworkflowresult.FieldStartTime: + return m.AddedStartTime() + case ciworkflowresult.FieldEndTime: + return m.AddedEndTime() + case ciworkflowresult.FieldAvgVram: + return m.AddedAvgVram() + case ciworkflowresult.FieldPeakVram: + return m.AddedPeakVram() + } + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *CIWorkflowResultMutation) AddField(name string, value ent.Value) error { + switch name { + case ciworkflowresult.FieldStartTime: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddStartTime(v) + return nil + case ciworkflowresult.FieldEndTime: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddEndTime(v) + return nil + case ciworkflowresult.FieldAvgVram: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddAvgVram(v) + return nil + case ciworkflowresult.FieldPeakVram: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddPeakVram(v) + return nil + } + return fmt.Errorf("unknown CIWorkflowResult numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *CIWorkflowResultMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(ciworkflowresult.FieldWorkflowName) { + fields = append(fields, ciworkflowresult.FieldWorkflowName) + } + if m.FieldCleared(ciworkflowresult.FieldRunID) { + fields = append(fields, ciworkflowresult.FieldRunID) + } + if m.FieldCleared(ciworkflowresult.FieldJobID) { + fields = append(fields, ciworkflowresult.FieldJobID) + } + if m.FieldCleared(ciworkflowresult.FieldStartTime) { + fields = append(fields, ciworkflowresult.FieldStartTime) + } + if m.FieldCleared(ciworkflowresult.FieldEndTime) { + fields = append(fields, ciworkflowresult.FieldEndTime) + } + if m.FieldCleared(ciworkflowresult.FieldPythonVersion) { + fields = append(fields, ciworkflowresult.FieldPythonVersion) + } + if m.FieldCleared(ciworkflowresult.FieldPytorchVersion) { + fields = append(fields, ciworkflowresult.FieldPytorchVersion) + } + if m.FieldCleared(ciworkflowresult.FieldCudaVersion) { + fields = append(fields, ciworkflowresult.FieldCudaVersion) + } + if m.FieldCleared(ciworkflowresult.FieldComfyRunFlags) { + fields = append(fields, ciworkflowresult.FieldComfyRunFlags) + } + if m.FieldCleared(ciworkflowresult.FieldAvgVram) { + fields = append(fields, ciworkflowresult.FieldAvgVram) + } + if m.FieldCleared(ciworkflowresult.FieldPeakVram) { + fields = append(fields, ciworkflowresult.FieldPeakVram) + } + if m.FieldCleared(ciworkflowresult.FieldJobTriggerUser) { + fields = append(fields, ciworkflowresult.FieldJobTriggerUser) + } + if m.FieldCleared(ciworkflowresult.FieldMetadata) { + fields = append(fields, ciworkflowresult.FieldMetadata) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *CIWorkflowResultMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *CIWorkflowResultMutation) ClearField(name string) error { + switch name { + case ciworkflowresult.FieldWorkflowName: + m.ClearWorkflowName() + return nil + case ciworkflowresult.FieldRunID: + m.ClearRunID() + return nil + case ciworkflowresult.FieldJobID: + m.ClearJobID() + return nil + case ciworkflowresult.FieldStartTime: + m.ClearStartTime() + return nil + case ciworkflowresult.FieldEndTime: + m.ClearEndTime() + return nil + case ciworkflowresult.FieldPythonVersion: + m.ClearPythonVersion() + return nil + case ciworkflowresult.FieldPytorchVersion: + m.ClearPytorchVersion() + return nil + case ciworkflowresult.FieldCudaVersion: + m.ClearCudaVersion() + return nil + case ciworkflowresult.FieldComfyRunFlags: + m.ClearComfyRunFlags() + return nil + case ciworkflowresult.FieldAvgVram: + m.ClearAvgVram() + return nil + case ciworkflowresult.FieldPeakVram: + m.ClearPeakVram() + return nil + case ciworkflowresult.FieldJobTriggerUser: + m.ClearJobTriggerUser() + return nil + case ciworkflowresult.FieldMetadata: + m.ClearMetadata() + return nil + } + return fmt.Errorf("unknown CIWorkflowResult nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *CIWorkflowResultMutation) ResetField(name string) error { + switch name { + case ciworkflowresult.FieldCreateTime: + m.ResetCreateTime() + return nil + case ciworkflowresult.FieldUpdateTime: + m.ResetUpdateTime() + return nil + case ciworkflowresult.FieldOperatingSystem: + m.ResetOperatingSystem() + return nil + case ciworkflowresult.FieldWorkflowName: + m.ResetWorkflowName() + return nil + case ciworkflowresult.FieldRunID: + m.ResetRunID() + return nil + case ciworkflowresult.FieldJobID: + m.ResetJobID() + return nil + case ciworkflowresult.FieldStatus: + m.ResetStatus() + return nil + case ciworkflowresult.FieldStartTime: + m.ResetStartTime() + return nil + case ciworkflowresult.FieldEndTime: + m.ResetEndTime() + return nil + case ciworkflowresult.FieldPythonVersion: + m.ResetPythonVersion() + return nil + case ciworkflowresult.FieldPytorchVersion: + m.ResetPytorchVersion() + return nil + case ciworkflowresult.FieldCudaVersion: + m.ResetCudaVersion() + return nil + case ciworkflowresult.FieldComfyRunFlags: + m.ResetComfyRunFlags() + return nil + case ciworkflowresult.FieldAvgVram: + m.ResetAvgVram() + return nil + case ciworkflowresult.FieldPeakVram: + m.ResetPeakVram() + return nil + case ciworkflowresult.FieldJobTriggerUser: + m.ResetJobTriggerUser() + return nil + case ciworkflowresult.FieldMetadata: + m.ResetMetadata() + return nil + } + return fmt.Errorf("unknown CIWorkflowResult field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *CIWorkflowResultMutation) AddedEdges() []string { + edges := make([]string, 0, 2) + if m.gitcommit != nil { + edges = append(edges, ciworkflowresult.EdgeGitcommit) + } + if m.storage_file != nil { + edges = append(edges, ciworkflowresult.EdgeStorageFile) + } + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *CIWorkflowResultMutation) AddedIDs(name string) []ent.Value { + switch name { + case ciworkflowresult.EdgeGitcommit: + if id := m.gitcommit; id != nil { + return []ent.Value{*id} + } + case ciworkflowresult.EdgeStorageFile: + ids := make([]ent.Value, 0, len(m.storage_file)) + for id := range m.storage_file { + ids = append(ids, id) + } + return ids + } + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *CIWorkflowResultMutation) RemovedEdges() []string { + edges := make([]string, 0, 2) + if m.removedstorage_file != nil { + edges = append(edges, ciworkflowresult.EdgeStorageFile) + } + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *CIWorkflowResultMutation) RemovedIDs(name string) []ent.Value { + switch name { + case ciworkflowresult.EdgeStorageFile: + ids := make([]ent.Value, 0, len(m.removedstorage_file)) + for id := range m.removedstorage_file { + ids = append(ids, id) + } + return ids + } + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *CIWorkflowResultMutation) ClearedEdges() []string { + edges := make([]string, 0, 2) + if m.clearedgitcommit { + edges = append(edges, ciworkflowresult.EdgeGitcommit) + } + if m.clearedstorage_file { + edges = append(edges, ciworkflowresult.EdgeStorageFile) + } + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *CIWorkflowResultMutation) EdgeCleared(name string) bool { + switch name { + case ciworkflowresult.EdgeGitcommit: + return m.clearedgitcommit + case ciworkflowresult.EdgeStorageFile: + return m.clearedstorage_file + } + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *CIWorkflowResultMutation) ClearEdge(name string) error { + switch name { + case ciworkflowresult.EdgeGitcommit: + m.ClearGitcommit() + return nil + } + return fmt.Errorf("unknown CIWorkflowResult unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *CIWorkflowResultMutation) ResetEdge(name string) error { + switch name { + case ciworkflowresult.EdgeGitcommit: + m.ResetGitcommit() + return nil + case ciworkflowresult.EdgeStorageFile: + m.ResetStorageFile() + return nil + } + return fmt.Errorf("unknown CIWorkflowResult edge %s", name) +} + +// GitCommitMutation represents an operation that mutates the GitCommit nodes in the graph. +type GitCommitMutation struct { + config + op Op + typ string + id *uuid.UUID + create_time *time.Time + update_time *time.Time + commit_hash *string + branch_name *string + repo_name *string + commit_message *string + commit_timestamp *time.Time + author *string + timestamp *time.Time + pr_number *string + clearedFields map[string]struct{} + results map[uuid.UUID]struct{} + removedresults map[uuid.UUID]struct{} + clearedresults bool + done bool + oldValue func(context.Context) (*GitCommit, error) + predicates []predicate.GitCommit +} + +var _ ent.Mutation = (*GitCommitMutation)(nil) + +// gitcommitOption allows management of the mutation configuration using functional options. +type gitcommitOption func(*GitCommitMutation) + +// newGitCommitMutation creates new mutation for the GitCommit entity. +func newGitCommitMutation(c config, op Op, opts ...gitcommitOption) *GitCommitMutation { + m := &GitCommitMutation{ + config: c, + op: op, + typ: TypeGitCommit, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withGitCommitID sets the ID field of the mutation. +func withGitCommitID(id uuid.UUID) gitcommitOption { + return func(m *GitCommitMutation) { + var ( + err error + once sync.Once + value *GitCommit + ) + m.oldValue = func(ctx context.Context) (*GitCommit, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().GitCommit.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withGitCommit sets the old GitCommit of the mutation. +func withGitCommit(node *GitCommit) gitcommitOption { + return func(m *GitCommitMutation) { + m.oldValue = func(context.Context) (*GitCommit, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m GitCommitMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m GitCommitMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// SetID sets the value of the id field. Note that this +// operation is only accepted on creation of GitCommit entities. +func (m *GitCommitMutation) SetID(id uuid.UUID) { + m.id = &id +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *GitCommitMutation) ID() (id uuid.UUID, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *GitCommitMutation) IDs(ctx context.Context) ([]uuid.UUID, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []uuid.UUID{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().GitCommit.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreateTime sets the "create_time" field. +func (m *GitCommitMutation) SetCreateTime(t time.Time) { + m.create_time = &t +} + +// CreateTime returns the value of the "create_time" field in the mutation. +func (m *GitCommitMutation) CreateTime() (r time.Time, exists bool) { + v := m.create_time + if v == nil { + return + } + return *v, true +} + +// OldCreateTime returns the old "create_time" field's value of the GitCommit entity. +// If the GitCommit object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GitCommitMutation) OldCreateTime(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreateTime is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreateTime requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreateTime: %w", err) + } + return oldValue.CreateTime, nil +} + +// ResetCreateTime resets all changes to the "create_time" field. +func (m *GitCommitMutation) ResetCreateTime() { + m.create_time = nil +} + +// SetUpdateTime sets the "update_time" field. +func (m *GitCommitMutation) SetUpdateTime(t time.Time) { + m.update_time = &t +} + +// UpdateTime returns the value of the "update_time" field in the mutation. +func (m *GitCommitMutation) UpdateTime() (r time.Time, exists bool) { + v := m.update_time + if v == nil { + return + } + return *v, true +} + +// OldUpdateTime returns the old "update_time" field's value of the GitCommit entity. +// If the GitCommit object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GitCommitMutation) OldUpdateTime(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdateTime is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdateTime requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdateTime: %w", err) + } + return oldValue.UpdateTime, nil +} + +// ResetUpdateTime resets all changes to the "update_time" field. +func (m *GitCommitMutation) ResetUpdateTime() { + m.update_time = nil +} + +// SetCommitHash sets the "commit_hash" field. +func (m *GitCommitMutation) SetCommitHash(s string) { + m.commit_hash = &s +} + +// CommitHash returns the value of the "commit_hash" field in the mutation. +func (m *GitCommitMutation) CommitHash() (r string, exists bool) { + v := m.commit_hash + if v == nil { + return + } + return *v, true +} + +// OldCommitHash returns the old "commit_hash" field's value of the GitCommit entity. +// If the GitCommit object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GitCommitMutation) OldCommitHash(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCommitHash is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCommitHash requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCommitHash: %w", err) + } + return oldValue.CommitHash, nil +} + +// ResetCommitHash resets all changes to the "commit_hash" field. +func (m *GitCommitMutation) ResetCommitHash() { + m.commit_hash = nil +} + +// SetBranchName sets the "branch_name" field. +func (m *GitCommitMutation) SetBranchName(s string) { + m.branch_name = &s +} + +// BranchName returns the value of the "branch_name" field in the mutation. +func (m *GitCommitMutation) BranchName() (r string, exists bool) { + v := m.branch_name + if v == nil { + return + } + return *v, true +} + +// OldBranchName returns the old "branch_name" field's value of the GitCommit entity. +// If the GitCommit object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GitCommitMutation) OldBranchName(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldBranchName is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldBranchName requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldBranchName: %w", err) + } + return oldValue.BranchName, nil +} + +// ResetBranchName resets all changes to the "branch_name" field. +func (m *GitCommitMutation) ResetBranchName() { + m.branch_name = nil +} + +// SetRepoName sets the "repo_name" field. +func (m *GitCommitMutation) SetRepoName(s string) { + m.repo_name = &s +} + +// RepoName returns the value of the "repo_name" field in the mutation. +func (m *GitCommitMutation) RepoName() (r string, exists bool) { + v := m.repo_name + if v == nil { + return + } + return *v, true +} + +// OldRepoName returns the old "repo_name" field's value of the GitCommit entity. +// If the GitCommit object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GitCommitMutation) OldRepoName(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRepoName is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRepoName requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRepoName: %w", err) + } + return oldValue.RepoName, nil +} + +// ResetRepoName resets all changes to the "repo_name" field. +func (m *GitCommitMutation) ResetRepoName() { + m.repo_name = nil +} + +// SetCommitMessage sets the "commit_message" field. +func (m *GitCommitMutation) SetCommitMessage(s string) { + m.commit_message = &s +} + +// CommitMessage returns the value of the "commit_message" field in the mutation. +func (m *GitCommitMutation) CommitMessage() (r string, exists bool) { + v := m.commit_message + if v == nil { + return + } + return *v, true +} + +// OldCommitMessage returns the old "commit_message" field's value of the GitCommit entity. +// If the GitCommit object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GitCommitMutation) OldCommitMessage(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCommitMessage is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCommitMessage requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCommitMessage: %w", err) + } + return oldValue.CommitMessage, nil +} + +// ResetCommitMessage resets all changes to the "commit_message" field. +func (m *GitCommitMutation) ResetCommitMessage() { + m.commit_message = nil +} + +// SetCommitTimestamp sets the "commit_timestamp" field. +func (m *GitCommitMutation) SetCommitTimestamp(t time.Time) { + m.commit_timestamp = &t +} + +// CommitTimestamp returns the value of the "commit_timestamp" field in the mutation. +func (m *GitCommitMutation) CommitTimestamp() (r time.Time, exists bool) { + v := m.commit_timestamp + if v == nil { + return + } + return *v, true +} + +// OldCommitTimestamp returns the old "commit_timestamp" field's value of the GitCommit entity. +// If the GitCommit object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GitCommitMutation) OldCommitTimestamp(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCommitTimestamp is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCommitTimestamp requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCommitTimestamp: %w", err) + } + return oldValue.CommitTimestamp, nil +} + +// ResetCommitTimestamp resets all changes to the "commit_timestamp" field. +func (m *GitCommitMutation) ResetCommitTimestamp() { + m.commit_timestamp = nil +} + +// SetAuthor sets the "author" field. +func (m *GitCommitMutation) SetAuthor(s string) { + m.author = &s +} + +// Author returns the value of the "author" field in the mutation. +func (m *GitCommitMutation) Author() (r string, exists bool) { + v := m.author + if v == nil { + return + } + return *v, true +} + +// OldAuthor returns the old "author" field's value of the GitCommit entity. +// If the GitCommit object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GitCommitMutation) OldAuthor(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAuthor is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAuthor requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAuthor: %w", err) + } + return oldValue.Author, nil +} + +// ClearAuthor clears the value of the "author" field. +func (m *GitCommitMutation) ClearAuthor() { + m.author = nil + m.clearedFields[gitcommit.FieldAuthor] = struct{}{} +} + +// AuthorCleared returns if the "author" field was cleared in this mutation. +func (m *GitCommitMutation) AuthorCleared() bool { + _, ok := m.clearedFields[gitcommit.FieldAuthor] + return ok +} + +// ResetAuthor resets all changes to the "author" field. +func (m *GitCommitMutation) ResetAuthor() { + m.author = nil + delete(m.clearedFields, gitcommit.FieldAuthor) +} + +// SetTimestamp sets the "timestamp" field. +func (m *GitCommitMutation) SetTimestamp(t time.Time) { + m.timestamp = &t +} + +// Timestamp returns the value of the "timestamp" field in the mutation. +func (m *GitCommitMutation) Timestamp() (r time.Time, exists bool) { + v := m.timestamp + if v == nil { + return + } + return *v, true +} + +// OldTimestamp returns the old "timestamp" field's value of the GitCommit entity. +// If the GitCommit object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GitCommitMutation) OldTimestamp(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldTimestamp is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldTimestamp requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldTimestamp: %w", err) + } + return oldValue.Timestamp, nil +} + +// ClearTimestamp clears the value of the "timestamp" field. +func (m *GitCommitMutation) ClearTimestamp() { + m.timestamp = nil + m.clearedFields[gitcommit.FieldTimestamp] = struct{}{} +} + +// TimestampCleared returns if the "timestamp" field was cleared in this mutation. +func (m *GitCommitMutation) TimestampCleared() bool { + _, ok := m.clearedFields[gitcommit.FieldTimestamp] + return ok +} + +// ResetTimestamp resets all changes to the "timestamp" field. +func (m *GitCommitMutation) ResetTimestamp() { + m.timestamp = nil + delete(m.clearedFields, gitcommit.FieldTimestamp) +} + +// SetPrNumber sets the "pr_number" field. +func (m *GitCommitMutation) SetPrNumber(s string) { + m.pr_number = &s +} + +// PrNumber returns the value of the "pr_number" field in the mutation. +func (m *GitCommitMutation) PrNumber() (r string, exists bool) { + v := m.pr_number + if v == nil { + return + } + return *v, true +} + +// OldPrNumber returns the old "pr_number" field's value of the GitCommit entity. +// If the GitCommit object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GitCommitMutation) OldPrNumber(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPrNumber is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPrNumber requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPrNumber: %w", err) + } + return oldValue.PrNumber, nil +} + +// ClearPrNumber clears the value of the "pr_number" field. +func (m *GitCommitMutation) ClearPrNumber() { + m.pr_number = nil + m.clearedFields[gitcommit.FieldPrNumber] = struct{}{} +} + +// PrNumberCleared returns if the "pr_number" field was cleared in this mutation. +func (m *GitCommitMutation) PrNumberCleared() bool { + _, ok := m.clearedFields[gitcommit.FieldPrNumber] + return ok +} + +// ResetPrNumber resets all changes to the "pr_number" field. +func (m *GitCommitMutation) ResetPrNumber() { + m.pr_number = nil + delete(m.clearedFields, gitcommit.FieldPrNumber) +} + +// AddResultIDs adds the "results" edge to the CIWorkflowResult entity by ids. +func (m *GitCommitMutation) AddResultIDs(ids ...uuid.UUID) { + if m.results == nil { + m.results = make(map[uuid.UUID]struct{}) + } + for i := range ids { + m.results[ids[i]] = struct{}{} + } +} + +// ClearResults clears the "results" edge to the CIWorkflowResult entity. +func (m *GitCommitMutation) ClearResults() { + m.clearedresults = true +} + +// ResultsCleared reports if the "results" edge to the CIWorkflowResult entity was cleared. +func (m *GitCommitMutation) ResultsCleared() bool { + return m.clearedresults +} + +// RemoveResultIDs removes the "results" edge to the CIWorkflowResult entity by IDs. +func (m *GitCommitMutation) RemoveResultIDs(ids ...uuid.UUID) { + if m.removedresults == nil { + m.removedresults = make(map[uuid.UUID]struct{}) + } + for i := range ids { + delete(m.results, ids[i]) + m.removedresults[ids[i]] = struct{}{} + } +} + +// RemovedResults returns the removed IDs of the "results" edge to the CIWorkflowResult entity. +func (m *GitCommitMutation) RemovedResultsIDs() (ids []uuid.UUID) { + for id := range m.removedresults { + ids = append(ids, id) + } + return +} + +// ResultsIDs returns the "results" edge IDs in the mutation. +func (m *GitCommitMutation) ResultsIDs() (ids []uuid.UUID) { + for id := range m.results { + ids = append(ids, id) + } + return +} + +// ResetResults resets all changes to the "results" edge. +func (m *GitCommitMutation) ResetResults() { + m.results = nil + m.clearedresults = false + m.removedresults = nil +} + +// Where appends a list predicates to the GitCommitMutation builder. +func (m *GitCommitMutation) Where(ps ...predicate.GitCommit) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the GitCommitMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *GitCommitMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.GitCommit, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *GitCommitMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *GitCommitMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (GitCommit). +func (m *GitCommitMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *GitCommitMutation) Fields() []string { + fields := make([]string, 0, 10) + if m.create_time != nil { + fields = append(fields, gitcommit.FieldCreateTime) + } + if m.update_time != nil { + fields = append(fields, gitcommit.FieldUpdateTime) + } + if m.commit_hash != nil { + fields = append(fields, gitcommit.FieldCommitHash) + } + if m.branch_name != nil { + fields = append(fields, gitcommit.FieldBranchName) + } + if m.repo_name != nil { + fields = append(fields, gitcommit.FieldRepoName) + } + if m.commit_message != nil { + fields = append(fields, gitcommit.FieldCommitMessage) + } + if m.commit_timestamp != nil { + fields = append(fields, gitcommit.FieldCommitTimestamp) + } + if m.author != nil { + fields = append(fields, gitcommit.FieldAuthor) + } + if m.timestamp != nil { + fields = append(fields, gitcommit.FieldTimestamp) + } + if m.pr_number != nil { + fields = append(fields, gitcommit.FieldPrNumber) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *GitCommitMutation) Field(name string) (ent.Value, bool) { + switch name { + case gitcommit.FieldCreateTime: + return m.CreateTime() + case gitcommit.FieldUpdateTime: + return m.UpdateTime() + case gitcommit.FieldCommitHash: + return m.CommitHash() + case gitcommit.FieldBranchName: + return m.BranchName() + case gitcommit.FieldRepoName: + return m.RepoName() + case gitcommit.FieldCommitMessage: + return m.CommitMessage() + case gitcommit.FieldCommitTimestamp: + return m.CommitTimestamp() + case gitcommit.FieldAuthor: + return m.Author() + case gitcommit.FieldTimestamp: + return m.Timestamp() + case gitcommit.FieldPrNumber: + return m.PrNumber() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *GitCommitMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case gitcommit.FieldCreateTime: + return m.OldCreateTime(ctx) + case gitcommit.FieldUpdateTime: + return m.OldUpdateTime(ctx) + case gitcommit.FieldCommitHash: + return m.OldCommitHash(ctx) + case gitcommit.FieldBranchName: + return m.OldBranchName(ctx) + case gitcommit.FieldRepoName: + return m.OldRepoName(ctx) + case gitcommit.FieldCommitMessage: + return m.OldCommitMessage(ctx) + case gitcommit.FieldCommitTimestamp: + return m.OldCommitTimestamp(ctx) + case gitcommit.FieldAuthor: + return m.OldAuthor(ctx) + case gitcommit.FieldTimestamp: + return m.OldTimestamp(ctx) + case gitcommit.FieldPrNumber: + return m.OldPrNumber(ctx) + } + return nil, fmt.Errorf("unknown GitCommit field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *GitCommitMutation) SetField(name string, value ent.Value) error { + switch name { + case gitcommit.FieldCreateTime: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreateTime(v) + return nil + case gitcommit.FieldUpdateTime: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdateTime(v) + return nil + case gitcommit.FieldCommitHash: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCommitHash(v) + return nil + case gitcommit.FieldBranchName: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetBranchName(v) + return nil + case gitcommit.FieldRepoName: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRepoName(v) + return nil + case gitcommit.FieldCommitMessage: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCommitMessage(v) + return nil + case gitcommit.FieldCommitTimestamp: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCommitTimestamp(v) + return nil + case gitcommit.FieldAuthor: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAuthor(v) + return nil + case gitcommit.FieldTimestamp: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetTimestamp(v) + return nil + case gitcommit.FieldPrNumber: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPrNumber(v) + return nil + } + return fmt.Errorf("unknown GitCommit field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *GitCommitMutation) AddedFields() []string { + return nil +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *GitCommitMutation) AddedField(name string) (ent.Value, bool) { + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *GitCommitMutation) AddField(name string, value ent.Value) error { + switch name { + } + return fmt.Errorf("unknown GitCommit numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *GitCommitMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(gitcommit.FieldAuthor) { + fields = append(fields, gitcommit.FieldAuthor) + } + if m.FieldCleared(gitcommit.FieldTimestamp) { + fields = append(fields, gitcommit.FieldTimestamp) + } + if m.FieldCleared(gitcommit.FieldPrNumber) { + fields = append(fields, gitcommit.FieldPrNumber) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *GitCommitMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *GitCommitMutation) ClearField(name string) error { + switch name { + case gitcommit.FieldAuthor: + m.ClearAuthor() + return nil + case gitcommit.FieldTimestamp: + m.ClearTimestamp() + return nil + case gitcommit.FieldPrNumber: + m.ClearPrNumber() + return nil + } + return fmt.Errorf("unknown GitCommit nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *GitCommitMutation) ResetField(name string) error { + switch name { + case gitcommit.FieldCreateTime: + m.ResetCreateTime() + return nil + case gitcommit.FieldUpdateTime: + m.ResetUpdateTime() + return nil + case gitcommit.FieldCommitHash: + m.ResetCommitHash() + return nil + case gitcommit.FieldBranchName: + m.ResetBranchName() + return nil + case gitcommit.FieldRepoName: + m.ResetRepoName() + return nil + case gitcommit.FieldCommitMessage: + m.ResetCommitMessage() + return nil + case gitcommit.FieldCommitTimestamp: + m.ResetCommitTimestamp() + return nil + case gitcommit.FieldAuthor: + m.ResetAuthor() + return nil + case gitcommit.FieldTimestamp: + m.ResetTimestamp() + return nil + case gitcommit.FieldPrNumber: + m.ResetPrNumber() + return nil + } + return fmt.Errorf("unknown GitCommit field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *GitCommitMutation) AddedEdges() []string { + edges := make([]string, 0, 1) + if m.results != nil { + edges = append(edges, gitcommit.EdgeResults) + } + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *GitCommitMutation) AddedIDs(name string) []ent.Value { + switch name { + case gitcommit.EdgeResults: + ids := make([]ent.Value, 0, len(m.results)) + for id := range m.results { + ids = append(ids, id) + } + return ids + } + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *GitCommitMutation) RemovedEdges() []string { + edges := make([]string, 0, 1) + if m.removedresults != nil { + edges = append(edges, gitcommit.EdgeResults) + } + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *GitCommitMutation) RemovedIDs(name string) []ent.Value { + switch name { + case gitcommit.EdgeResults: + ids := make([]ent.Value, 0, len(m.removedresults)) + for id := range m.removedresults { + ids = append(ids, id) + } + return ids + } + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *GitCommitMutation) ClearedEdges() []string { + edges := make([]string, 0, 1) + if m.clearedresults { + edges = append(edges, gitcommit.EdgeResults) + } + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *GitCommitMutation) EdgeCleared(name string) bool { + switch name { + case gitcommit.EdgeResults: + return m.clearedresults + } + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *GitCommitMutation) ClearEdge(name string) error { + switch name { + } + return fmt.Errorf("unknown GitCommit unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *GitCommitMutation) ResetEdge(name string) error { + switch name { + case gitcommit.EdgeResults: + m.ResetResults() + return nil + } + return fmt.Errorf("unknown GitCommit edge %s", name) +} + +// NodeMutation represents an operation that mutates the Node nodes in the graph. +type NodeMutation struct { + config + op Op + typ string + id *string + create_time *time.Time + update_time *time.Time + name *string + description *string + category *string + author *string + license *string + repository_url *string + icon_url *string + tags *[]string + appendtags []string + total_install *int64 + addtotal_install *int64 + total_star *int64 + addtotal_star *int64 + total_review *int64 + addtotal_review *int64 + status *schema.NodeStatus + status_detail *string + clearedFields map[string]struct{} + publisher *string + clearedpublisher bool + versions map[uuid.UUID]struct{} + removedversions map[uuid.UUID]struct{} + clearedversions bool + reviews map[uuid.UUID]struct{} + removedreviews map[uuid.UUID]struct{} + clearedreviews bool + done bool + oldValue func(context.Context) (*Node, error) + predicates []predicate.Node +} + +var _ ent.Mutation = (*NodeMutation)(nil) + +// nodeOption allows management of the mutation configuration using functional options. +type nodeOption func(*NodeMutation) + +// newNodeMutation creates new mutation for the Node entity. +func newNodeMutation(c config, op Op, opts ...nodeOption) *NodeMutation { + m := &NodeMutation{ + config: c, + op: op, + typ: TypeNode, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withNodeID sets the ID field of the mutation. +func withNodeID(id string) nodeOption { + return func(m *NodeMutation) { + var ( + err error + once sync.Once + value *Node + ) + m.oldValue = func(ctx context.Context) (*Node, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().Node.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withNode sets the old Node of the mutation. +func withNode(node *Node) nodeOption { + return func(m *NodeMutation) { + m.oldValue = func(context.Context) (*Node, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m NodeMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m NodeMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// SetID sets the value of the id field. Note that this +// operation is only accepted on creation of Node entities. +func (m *NodeMutation) SetID(id string) { + m.id = &id +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *NodeMutation) ID() (id string, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *NodeMutation) IDs(ctx context.Context) ([]string, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []string{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().Node.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreateTime sets the "create_time" field. +func (m *NodeMutation) SetCreateTime(t time.Time) { + m.create_time = &t +} + +// CreateTime returns the value of the "create_time" field in the mutation. +func (m *NodeMutation) CreateTime() (r time.Time, exists bool) { + v := m.create_time + if v == nil { + return + } + return *v, true +} + +// OldCreateTime returns the old "create_time" field's value of the Node entity. +// If the Node object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *NodeMutation) OldCreateTime(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreateTime is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreateTime requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreateTime: %w", err) + } + return oldValue.CreateTime, nil +} + +// ResetCreateTime resets all changes to the "create_time" field. +func (m *NodeMutation) ResetCreateTime() { + m.create_time = nil +} + +// SetUpdateTime sets the "update_time" field. +func (m *NodeMutation) SetUpdateTime(t time.Time) { + m.update_time = &t +} + +// UpdateTime returns the value of the "update_time" field in the mutation. +func (m *NodeMutation) UpdateTime() (r time.Time, exists bool) { + v := m.update_time + if v == nil { + return + } + return *v, true +} + +// OldUpdateTime returns the old "update_time" field's value of the Node entity. +// If the Node object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *NodeMutation) OldUpdateTime(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdateTime is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdateTime requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdateTime: %w", err) + } + return oldValue.UpdateTime, nil +} + +// ResetUpdateTime resets all changes to the "update_time" field. +func (m *NodeMutation) ResetUpdateTime() { + m.update_time = nil +} + +// SetPublisherID sets the "publisher_id" field. +func (m *NodeMutation) SetPublisherID(s string) { + m.publisher = &s +} + +// PublisherID returns the value of the "publisher_id" field in the mutation. +func (m *NodeMutation) PublisherID() (r string, exists bool) { + v := m.publisher + if v == nil { + return + } + return *v, true +} + +// OldPublisherID returns the old "publisher_id" field's value of the Node entity. +// If the Node object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *NodeMutation) OldPublisherID(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPublisherID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPublisherID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPublisherID: %w", err) + } + return oldValue.PublisherID, nil +} + +// ResetPublisherID resets all changes to the "publisher_id" field. +func (m *NodeMutation) ResetPublisherID() { + m.publisher = nil +} + +// SetName sets the "name" field. +func (m *NodeMutation) SetName(s string) { + m.name = &s +} + +// Name returns the value of the "name" field in the mutation. +func (m *NodeMutation) Name() (r string, exists bool) { + v := m.name + if v == nil { + return + } + return *v, true +} + +// OldName returns the old "name" field's value of the Node entity. +// If the Node object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *NodeMutation) OldName(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldName is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldName requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldCommitMessage: %w", err) + return v, fmt.Errorf("querying old value for OldName: %w", err) } - return oldValue.CommitMessage, nil + return oldValue.Name, nil } -// ResetCommitMessage resets all changes to the "commit_message" field. -func (m *GitCommitMutation) ResetCommitMessage() { - m.commit_message = nil +// ResetName resets all changes to the "name" field. +func (m *NodeMutation) ResetName() { + m.name = nil } -// SetCommitTimestamp sets the "commit_timestamp" field. -func (m *GitCommitMutation) SetCommitTimestamp(t time.Time) { - m.commit_timestamp = &t +// SetDescription sets the "description" field. +func (m *NodeMutation) SetDescription(s string) { + m.description = &s } -// CommitTimestamp returns the value of the "commit_timestamp" field in the mutation. -func (m *GitCommitMutation) CommitTimestamp() (r time.Time, exists bool) { - v := m.commit_timestamp +// Description returns the value of the "description" field in the mutation. +func (m *NodeMutation) Description() (r string, exists bool) { + v := m.description if v == nil { return } return *v, true } -// OldCommitTimestamp returns the old "commit_timestamp" field's value of the GitCommit entity. -// If the GitCommit object wasn't provided to the builder, the object is fetched from the database. +// OldDescription returns the old "description" field's value of the Node entity. +// If the Node object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GitCommitMutation) OldCommitTimestamp(ctx context.Context) (v time.Time, err error) { +func (m *NodeMutation) OldDescription(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldCommitTimestamp is only allowed on UpdateOne operations") + return v, errors.New("OldDescription is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDescription requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDescription: %w", err) + } + return oldValue.Description, nil +} + +// ClearDescription clears the value of the "description" field. +func (m *NodeMutation) ClearDescription() { + m.description = nil + m.clearedFields[node.FieldDescription] = struct{}{} +} + +// DescriptionCleared returns if the "description" field was cleared in this mutation. +func (m *NodeMutation) DescriptionCleared() bool { + _, ok := m.clearedFields[node.FieldDescription] + return ok +} + +// ResetDescription resets all changes to the "description" field. +func (m *NodeMutation) ResetDescription() { + m.description = nil + delete(m.clearedFields, node.FieldDescription) +} + +// SetCategory sets the "category" field. +func (m *NodeMutation) SetCategory(s string) { + m.category = &s +} + +// Category returns the value of the "category" field in the mutation. +func (m *NodeMutation) Category() (r string, exists bool) { + v := m.category + if v == nil { + return + } + return *v, true +} + +// OldCategory returns the old "category" field's value of the Node entity. +// If the Node object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *NodeMutation) OldCategory(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCategory is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCategory requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCategory: %w", err) + } + return oldValue.Category, nil +} + +// ClearCategory clears the value of the "category" field. +func (m *NodeMutation) ClearCategory() { + m.category = nil + m.clearedFields[node.FieldCategory] = struct{}{} +} + +// CategoryCleared returns if the "category" field was cleared in this mutation. +func (m *NodeMutation) CategoryCleared() bool { + _, ok := m.clearedFields[node.FieldCategory] + return ok +} + +// ResetCategory resets all changes to the "category" field. +func (m *NodeMutation) ResetCategory() { + m.category = nil + delete(m.clearedFields, node.FieldCategory) +} + +// SetAuthor sets the "author" field. +func (m *NodeMutation) SetAuthor(s string) { + m.author = &s +} + +// Author returns the value of the "author" field in the mutation. +func (m *NodeMutation) Author() (r string, exists bool) { + v := m.author + if v == nil { + return + } + return *v, true +} + +// OldAuthor returns the old "author" field's value of the Node entity. +// If the Node object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *NodeMutation) OldAuthor(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAuthor is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAuthor requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAuthor: %w", err) + } + return oldValue.Author, nil +} + +// ClearAuthor clears the value of the "author" field. +func (m *NodeMutation) ClearAuthor() { + m.author = nil + m.clearedFields[node.FieldAuthor] = struct{}{} +} + +// AuthorCleared returns if the "author" field was cleared in this mutation. +func (m *NodeMutation) AuthorCleared() bool { + _, ok := m.clearedFields[node.FieldAuthor] + return ok +} + +// ResetAuthor resets all changes to the "author" field. +func (m *NodeMutation) ResetAuthor() { + m.author = nil + delete(m.clearedFields, node.FieldAuthor) +} + +// SetLicense sets the "license" field. +func (m *NodeMutation) SetLicense(s string) { + m.license = &s +} + +// License returns the value of the "license" field in the mutation. +func (m *NodeMutation) License() (r string, exists bool) { + v := m.license + if v == nil { + return + } + return *v, true +} + +// OldLicense returns the old "license" field's value of the Node entity. +// If the Node object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *NodeMutation) OldLicense(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldLicense is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldLicense requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldLicense: %w", err) + } + return oldValue.License, nil +} + +// ResetLicense resets all changes to the "license" field. +func (m *NodeMutation) ResetLicense() { + m.license = nil +} + +// SetRepositoryURL sets the "repository_url" field. +func (m *NodeMutation) SetRepositoryURL(s string) { + m.repository_url = &s +} + +// RepositoryURL returns the value of the "repository_url" field in the mutation. +func (m *NodeMutation) RepositoryURL() (r string, exists bool) { + v := m.repository_url + if v == nil { + return + } + return *v, true +} + +// OldRepositoryURL returns the old "repository_url" field's value of the Node entity. +// If the Node object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *NodeMutation) OldRepositoryURL(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRepositoryURL is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRepositoryURL requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRepositoryURL: %w", err) + } + return oldValue.RepositoryURL, nil +} + +// ResetRepositoryURL resets all changes to the "repository_url" field. +func (m *NodeMutation) ResetRepositoryURL() { + m.repository_url = nil +} + +// SetIconURL sets the "icon_url" field. +func (m *NodeMutation) SetIconURL(s string) { + m.icon_url = &s +} + +// IconURL returns the value of the "icon_url" field in the mutation. +func (m *NodeMutation) IconURL() (r string, exists bool) { + v := m.icon_url + if v == nil { + return + } + return *v, true +} + +// OldIconURL returns the old "icon_url" field's value of the Node entity. +// If the Node object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *NodeMutation) OldIconURL(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldIconURL is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldCommitTimestamp requires an ID field in the mutation") + return v, errors.New("OldIconURL requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldCommitTimestamp: %w", err) + return v, fmt.Errorf("querying old value for OldIconURL: %w", err) } - return oldValue.CommitTimestamp, nil + return oldValue.IconURL, nil } -// ResetCommitTimestamp resets all changes to the "commit_timestamp" field. -func (m *GitCommitMutation) ResetCommitTimestamp() { - m.commit_timestamp = nil +// ClearIconURL clears the value of the "icon_url" field. +func (m *NodeMutation) ClearIconURL() { + m.icon_url = nil + m.clearedFields[node.FieldIconURL] = struct{}{} } -// SetAuthor sets the "author" field. -func (m *GitCommitMutation) SetAuthor(s string) { - m.author = &s +// IconURLCleared returns if the "icon_url" field was cleared in this mutation. +func (m *NodeMutation) IconURLCleared() bool { + _, ok := m.clearedFields[node.FieldIconURL] + return ok } -// Author returns the value of the "author" field in the mutation. -func (m *GitCommitMutation) Author() (r string, exists bool) { - v := m.author +// ResetIconURL resets all changes to the "icon_url" field. +func (m *NodeMutation) ResetIconURL() { + m.icon_url = nil + delete(m.clearedFields, node.FieldIconURL) +} + +// SetTags sets the "tags" field. +func (m *NodeMutation) SetTags(s []string) { + m.tags = &s + m.appendtags = nil +} + +// Tags returns the value of the "tags" field in the mutation. +func (m *NodeMutation) Tags() (r []string, exists bool) { + v := m.tags if v == nil { return } return *v, true } -// OldAuthor returns the old "author" field's value of the GitCommit entity. -// If the GitCommit object wasn't provided to the builder, the object is fetched from the database. +// OldTags returns the old "tags" field's value of the Node entity. +// If the Node object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GitCommitMutation) OldAuthor(ctx context.Context) (v string, err error) { +func (m *NodeMutation) OldTags(ctx context.Context) (v []string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldAuthor is only allowed on UpdateOne operations") + return v, errors.New("OldTags is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldAuthor requires an ID field in the mutation") + return v, errors.New("OldTags requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldAuthor: %w", err) + return v, fmt.Errorf("querying old value for OldTags: %w", err) } - return oldValue.Author, nil + return oldValue.Tags, nil } -// ClearAuthor clears the value of the "author" field. -func (m *GitCommitMutation) ClearAuthor() { - m.author = nil - m.clearedFields[gitcommit.FieldAuthor] = struct{}{} +// AppendTags adds s to the "tags" field. +func (m *NodeMutation) AppendTags(s []string) { + m.appendtags = append(m.appendtags, s...) } -// AuthorCleared returns if the "author" field was cleared in this mutation. -func (m *GitCommitMutation) AuthorCleared() bool { - _, ok := m.clearedFields[gitcommit.FieldAuthor] - return ok +// AppendedTags returns the list of values that were appended to the "tags" field in this mutation. +func (m *NodeMutation) AppendedTags() ([]string, bool) { + if len(m.appendtags) == 0 { + return nil, false + } + return m.appendtags, true } -// ResetAuthor resets all changes to the "author" field. -func (m *GitCommitMutation) ResetAuthor() { - m.author = nil - delete(m.clearedFields, gitcommit.FieldAuthor) +// ResetTags resets all changes to the "tags" field. +func (m *NodeMutation) ResetTags() { + m.tags = nil + m.appendtags = nil } -// SetTimestamp sets the "timestamp" field. -func (m *GitCommitMutation) SetTimestamp(t time.Time) { - m.timestamp = &t +// SetTotalInstall sets the "total_install" field. +func (m *NodeMutation) SetTotalInstall(i int64) { + m.total_install = &i + m.addtotal_install = nil } -// Timestamp returns the value of the "timestamp" field in the mutation. -func (m *GitCommitMutation) Timestamp() (r time.Time, exists bool) { - v := m.timestamp +// TotalInstall returns the value of the "total_install" field in the mutation. +func (m *NodeMutation) TotalInstall() (r int64, exists bool) { + v := m.total_install if v == nil { return } return *v, true } -// OldTimestamp returns the old "timestamp" field's value of the GitCommit entity. -// If the GitCommit object wasn't provided to the builder, the object is fetched from the database. +// OldTotalInstall returns the old "total_install" field's value of the Node entity. +// If the Node object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *GitCommitMutation) OldTimestamp(ctx context.Context) (v time.Time, err error) { +func (m *NodeMutation) OldTotalInstall(ctx context.Context) (v int64, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldTimestamp is only allowed on UpdateOne operations") + return v, errors.New("OldTotalInstall is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldTimestamp requires an ID field in the mutation") + return v, errors.New("OldTotalInstall requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldTimestamp: %w", err) + return v, fmt.Errorf("querying old value for OldTotalInstall: %w", err) } - return oldValue.Timestamp, nil -} - -// ClearTimestamp clears the value of the "timestamp" field. -func (m *GitCommitMutation) ClearTimestamp() { - m.timestamp = nil - m.clearedFields[gitcommit.FieldTimestamp] = struct{}{} -} - -// TimestampCleared returns if the "timestamp" field was cleared in this mutation. -func (m *GitCommitMutation) TimestampCleared() bool { - _, ok := m.clearedFields[gitcommit.FieldTimestamp] - return ok + return oldValue.TotalInstall, nil } -// ResetTimestamp resets all changes to the "timestamp" field. -func (m *GitCommitMutation) ResetTimestamp() { - m.timestamp = nil - delete(m.clearedFields, gitcommit.FieldTimestamp) +// AddTotalInstall adds i to the "total_install" field. +func (m *NodeMutation) AddTotalInstall(i int64) { + if m.addtotal_install != nil { + *m.addtotal_install += i + } else { + m.addtotal_install = &i + } } -// AddResultIDs adds the "results" edge to the CIWorkflowResult entity by ids. -func (m *GitCommitMutation) AddResultIDs(ids ...uuid.UUID) { - if m.results == nil { - m.results = make(map[uuid.UUID]struct{}) - } - for i := range ids { - m.results[ids[i]] = struct{}{} +// AddedTotalInstall returns the value that was added to the "total_install" field in this mutation. +func (m *NodeMutation) AddedTotalInstall() (r int64, exists bool) { + v := m.addtotal_install + if v == nil { + return } + return *v, true } -// ClearResults clears the "results" edge to the CIWorkflowResult entity. -func (m *GitCommitMutation) ClearResults() { - m.clearedresults = true +// ResetTotalInstall resets all changes to the "total_install" field. +func (m *NodeMutation) ResetTotalInstall() { + m.total_install = nil + m.addtotal_install = nil } -// ResultsCleared reports if the "results" edge to the CIWorkflowResult entity was cleared. -func (m *GitCommitMutation) ResultsCleared() bool { - return m.clearedresults +// SetTotalStar sets the "total_star" field. +func (m *NodeMutation) SetTotalStar(i int64) { + m.total_star = &i + m.addtotal_star = nil } -// RemoveResultIDs removes the "results" edge to the CIWorkflowResult entity by IDs. -func (m *GitCommitMutation) RemoveResultIDs(ids ...uuid.UUID) { - if m.removedresults == nil { - m.removedresults = make(map[uuid.UUID]struct{}) - } - for i := range ids { - delete(m.results, ids[i]) - m.removedresults[ids[i]] = struct{}{} +// TotalStar returns the value of the "total_star" field in the mutation. +func (m *NodeMutation) TotalStar() (r int64, exists bool) { + v := m.total_star + if v == nil { + return } + return *v, true } -// RemovedResults returns the removed IDs of the "results" edge to the CIWorkflowResult entity. -func (m *GitCommitMutation) RemovedResultsIDs() (ids []uuid.UUID) { - for id := range m.removedresults { - ids = append(ids, id) +// OldTotalStar returns the old "total_star" field's value of the Node entity. +// If the Node object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *NodeMutation) OldTotalStar(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldTotalStar is only allowed on UpdateOne operations") } - return -} - -// ResultsIDs returns the "results" edge IDs in the mutation. -func (m *GitCommitMutation) ResultsIDs() (ids []uuid.UUID) { - for id := range m.results { - ids = append(ids, id) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldTotalStar requires an ID field in the mutation") } - return -} - -// ResetResults resets all changes to the "results" edge. -func (m *GitCommitMutation) ResetResults() { - m.results = nil - m.clearedresults = false - m.removedresults = nil + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldTotalStar: %w", err) + } + return oldValue.TotalStar, nil } -// Where appends a list predicates to the GitCommitMutation builder. -func (m *GitCommitMutation) Where(ps ...predicate.GitCommit) { - m.predicates = append(m.predicates, ps...) +// AddTotalStar adds i to the "total_star" field. +func (m *NodeMutation) AddTotalStar(i int64) { + if m.addtotal_star != nil { + *m.addtotal_star += i + } else { + m.addtotal_star = &i + } } -// WhereP appends storage-level predicates to the GitCommitMutation builder. Using this method, -// users can use type-assertion to append predicates that do not depend on any generated package. -func (m *GitCommitMutation) WhereP(ps ...func(*sql.Selector)) { - p := make([]predicate.GitCommit, len(ps)) - for i := range ps { - p[i] = ps[i] +// AddedTotalStar returns the value that was added to the "total_star" field in this mutation. +func (m *NodeMutation) AddedTotalStar() (r int64, exists bool) { + v := m.addtotal_star + if v == nil { + return } - m.Where(p...) + return *v, true } -// Op returns the operation name. -func (m *GitCommitMutation) Op() Op { - return m.op +// ResetTotalStar resets all changes to the "total_star" field. +func (m *NodeMutation) ResetTotalStar() { + m.total_star = nil + m.addtotal_star = nil } -// SetOp allows setting the mutation operation. -func (m *GitCommitMutation) SetOp(op Op) { - m.op = op +// SetTotalReview sets the "total_review" field. +func (m *NodeMutation) SetTotalReview(i int64) { + m.total_review = &i + m.addtotal_review = nil } -// Type returns the node type of this mutation (GitCommit). -func (m *GitCommitMutation) Type() string { - return m.typ +// TotalReview returns the value of the "total_review" field in the mutation. +func (m *NodeMutation) TotalReview() (r int64, exists bool) { + v := m.total_review + if v == nil { + return + } + return *v, true } -// Fields returns all fields that were changed during this mutation. Note that in -// order to get all numeric fields that were incremented/decremented, call -// AddedFields(). -func (m *GitCommitMutation) Fields() []string { - fields := make([]string, 0, 9) - if m.create_time != nil { - fields = append(fields, gitcommit.FieldCreateTime) - } - if m.update_time != nil { - fields = append(fields, gitcommit.FieldUpdateTime) - } - if m.commit_hash != nil { - fields = append(fields, gitcommit.FieldCommitHash) - } - if m.branch_name != nil { - fields = append(fields, gitcommit.FieldBranchName) - } - if m.repo_name != nil { - fields = append(fields, gitcommit.FieldRepoName) - } - if m.commit_message != nil { - fields = append(fields, gitcommit.FieldCommitMessage) +// OldTotalReview returns the old "total_review" field's value of the Node entity. +// If the Node object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *NodeMutation) OldTotalReview(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldTotalReview is only allowed on UpdateOne operations") } - if m.commit_timestamp != nil { - fields = append(fields, gitcommit.FieldCommitTimestamp) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldTotalReview requires an ID field in the mutation") } - if m.author != nil { - fields = append(fields, gitcommit.FieldAuthor) + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldTotalReview: %w", err) } - if m.timestamp != nil { - fields = append(fields, gitcommit.FieldTimestamp) + return oldValue.TotalReview, nil +} + +// AddTotalReview adds i to the "total_review" field. +func (m *NodeMutation) AddTotalReview(i int64) { + if m.addtotal_review != nil { + *m.addtotal_review += i + } else { + m.addtotal_review = &i } - return fields } -// Field returns the value of a field with the given name. The second boolean -// return value indicates that this field was not set, or was not defined in the -// schema. -func (m *GitCommitMutation) Field(name string) (ent.Value, bool) { - switch name { - case gitcommit.FieldCreateTime: - return m.CreateTime() - case gitcommit.FieldUpdateTime: - return m.UpdateTime() - case gitcommit.FieldCommitHash: - return m.CommitHash() - case gitcommit.FieldBranchName: - return m.BranchName() - case gitcommit.FieldRepoName: - return m.RepoName() - case gitcommit.FieldCommitMessage: - return m.CommitMessage() - case gitcommit.FieldCommitTimestamp: - return m.CommitTimestamp() - case gitcommit.FieldAuthor: - return m.Author() - case gitcommit.FieldTimestamp: - return m.Timestamp() +// AddedTotalReview returns the value that was added to the "total_review" field in this mutation. +func (m *NodeMutation) AddedTotalReview() (r int64, exists bool) { + v := m.addtotal_review + if v == nil { + return } - return nil, false + return *v, true } -// OldField returns the old value of the field from the database. An error is -// returned if the mutation operation is not UpdateOne, or the query to the -// database failed. -func (m *GitCommitMutation) OldField(ctx context.Context, name string) (ent.Value, error) { - switch name { - case gitcommit.FieldCreateTime: - return m.OldCreateTime(ctx) - case gitcommit.FieldUpdateTime: - return m.OldUpdateTime(ctx) - case gitcommit.FieldCommitHash: - return m.OldCommitHash(ctx) - case gitcommit.FieldBranchName: - return m.OldBranchName(ctx) - case gitcommit.FieldRepoName: - return m.OldRepoName(ctx) - case gitcommit.FieldCommitMessage: - return m.OldCommitMessage(ctx) - case gitcommit.FieldCommitTimestamp: - return m.OldCommitTimestamp(ctx) - case gitcommit.FieldAuthor: - return m.OldAuthor(ctx) - case gitcommit.FieldTimestamp: - return m.OldTimestamp(ctx) +// ResetTotalReview resets all changes to the "total_review" field. +func (m *NodeMutation) ResetTotalReview() { + m.total_review = nil + m.addtotal_review = nil +} + +// SetStatus sets the "status" field. +func (m *NodeMutation) SetStatus(ss schema.NodeStatus) { + m.status = &ss +} + +// Status returns the value of the "status" field in the mutation. +func (m *NodeMutation) Status() (r schema.NodeStatus, exists bool) { + v := m.status + if v == nil { + return } - return nil, fmt.Errorf("unknown GitCommit field %s", name) + return *v, true } -// SetField sets the value of a field with the given name. It returns an error if -// the field is not defined in the schema, or if the type mismatched the field -// type. -func (m *GitCommitMutation) SetField(name string, value ent.Value) error { - switch name { - case gitcommit.FieldCreateTime: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetCreateTime(v) - return nil - case gitcommit.FieldUpdateTime: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetUpdateTime(v) - return nil - case gitcommit.FieldCommitHash: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetCommitHash(v) - return nil - case gitcommit.FieldBranchName: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetBranchName(v) - return nil - case gitcommit.FieldRepoName: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetRepoName(v) - return nil - case gitcommit.FieldCommitMessage: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetCommitMessage(v) - return nil - case gitcommit.FieldCommitTimestamp: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetCommitTimestamp(v) - return nil - case gitcommit.FieldAuthor: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetAuthor(v) - return nil - case gitcommit.FieldTimestamp: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetTimestamp(v) - return nil +// OldStatus returns the old "status" field's value of the Node entity. +// If the Node object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *NodeMutation) OldStatus(ctx context.Context) (v schema.NodeStatus, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldStatus is only allowed on UpdateOne operations") } - return fmt.Errorf("unknown GitCommit field %s", name) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldStatus requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldStatus: %w", err) + } + return oldValue.Status, nil } -// AddedFields returns all numeric fields that were incremented/decremented during -// this mutation. -func (m *GitCommitMutation) AddedFields() []string { - return nil +// ResetStatus resets all changes to the "status" field. +func (m *NodeMutation) ResetStatus() { + m.status = nil } -// AddedField returns the numeric value that was incremented/decremented on a field -// with the given name. The second boolean return value indicates that this field -// was not set, or was not defined in the schema. -func (m *GitCommitMutation) AddedField(name string) (ent.Value, bool) { - return nil, false +// SetStatusDetail sets the "status_detail" field. +func (m *NodeMutation) SetStatusDetail(s string) { + m.status_detail = &s } -// AddField adds the value to the field with the given name. It returns an error if -// the field is not defined in the schema, or if the type mismatched the field -// type. -func (m *GitCommitMutation) AddField(name string, value ent.Value) error { - switch name { +// StatusDetail returns the value of the "status_detail" field in the mutation. +func (m *NodeMutation) StatusDetail() (r string, exists bool) { + v := m.status_detail + if v == nil { + return } - return fmt.Errorf("unknown GitCommit numeric field %s", name) + return *v, true } -// ClearedFields returns all nullable fields that were cleared during this -// mutation. -func (m *GitCommitMutation) ClearedFields() []string { - var fields []string - if m.FieldCleared(gitcommit.FieldAuthor) { - fields = append(fields, gitcommit.FieldAuthor) +// OldStatusDetail returns the old "status_detail" field's value of the Node entity. +// If the Node object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *NodeMutation) OldStatusDetail(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldStatusDetail is only allowed on UpdateOne operations") } - if m.FieldCleared(gitcommit.FieldTimestamp) { - fields = append(fields, gitcommit.FieldTimestamp) + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldStatusDetail requires an ID field in the mutation") } - return fields + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldStatusDetail: %w", err) + } + return oldValue.StatusDetail, nil } -// FieldCleared returns a boolean indicating if a field with the given name was -// cleared in this mutation. -func (m *GitCommitMutation) FieldCleared(name string) bool { - _, ok := m.clearedFields[name] +// ClearStatusDetail clears the value of the "status_detail" field. +func (m *NodeMutation) ClearStatusDetail() { + m.status_detail = nil + m.clearedFields[node.FieldStatusDetail] = struct{}{} +} + +// StatusDetailCleared returns if the "status_detail" field was cleared in this mutation. +func (m *NodeMutation) StatusDetailCleared() bool { + _, ok := m.clearedFields[node.FieldStatusDetail] return ok } -// ClearField clears the value of the field with the given name. It returns an -// error if the field is not defined in the schema. -func (m *GitCommitMutation) ClearField(name string) error { - switch name { - case gitcommit.FieldAuthor: - m.ClearAuthor() - return nil - case gitcommit.FieldTimestamp: - m.ClearTimestamp() - return nil - } - return fmt.Errorf("unknown GitCommit nullable field %s", name) +// ResetStatusDetail resets all changes to the "status_detail" field. +func (m *NodeMutation) ResetStatusDetail() { + m.status_detail = nil + delete(m.clearedFields, node.FieldStatusDetail) } -// ResetField resets all changes in the mutation for the field with the given name. -// It returns an error if the field is not defined in the schema. -func (m *GitCommitMutation) ResetField(name string) error { - switch name { - case gitcommit.FieldCreateTime: - m.ResetCreateTime() - return nil - case gitcommit.FieldUpdateTime: - m.ResetUpdateTime() - return nil - case gitcommit.FieldCommitHash: - m.ResetCommitHash() - return nil - case gitcommit.FieldBranchName: - m.ResetBranchName() - return nil - case gitcommit.FieldRepoName: - m.ResetRepoName() - return nil - case gitcommit.FieldCommitMessage: - m.ResetCommitMessage() - return nil - case gitcommit.FieldCommitTimestamp: - m.ResetCommitTimestamp() - return nil - case gitcommit.FieldAuthor: - m.ResetAuthor() - return nil - case gitcommit.FieldTimestamp: - m.ResetTimestamp() - return nil - } - return fmt.Errorf("unknown GitCommit field %s", name) +// ClearPublisher clears the "publisher" edge to the Publisher entity. +func (m *NodeMutation) ClearPublisher() { + m.clearedpublisher = true + m.clearedFields[node.FieldPublisherID] = struct{}{} } -// AddedEdges returns all edge names that were set/added in this mutation. -func (m *GitCommitMutation) AddedEdges() []string { - edges := make([]string, 0, 1) - if m.results != nil { - edges = append(edges, gitcommit.EdgeResults) +// PublisherCleared reports if the "publisher" edge to the Publisher entity was cleared. +func (m *NodeMutation) PublisherCleared() bool { + return m.clearedpublisher +} + +// PublisherIDs returns the "publisher" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// PublisherID instead. It exists only for internal usage by the builders. +func (m *NodeMutation) PublisherIDs() (ids []string) { + if id := m.publisher; id != nil { + ids = append(ids, *id) } - return edges + return } -// AddedIDs returns all IDs (to other nodes) that were added for the given edge -// name in this mutation. -func (m *GitCommitMutation) AddedIDs(name string) []ent.Value { - switch name { - case gitcommit.EdgeResults: - ids := make([]ent.Value, 0, len(m.results)) - for id := range m.results { - ids = append(ids, id) - } - return ids +// ResetPublisher resets all changes to the "publisher" edge. +func (m *NodeMutation) ResetPublisher() { + m.publisher = nil + m.clearedpublisher = false +} + +// AddVersionIDs adds the "versions" edge to the NodeVersion entity by ids. +func (m *NodeMutation) AddVersionIDs(ids ...uuid.UUID) { + if m.versions == nil { + m.versions = make(map[uuid.UUID]struct{}) } - return nil + for i := range ids { + m.versions[ids[i]] = struct{}{} + } +} + +// ClearVersions clears the "versions" edge to the NodeVersion entity. +func (m *NodeMutation) ClearVersions() { + m.clearedversions = true +} + +// VersionsCleared reports if the "versions" edge to the NodeVersion entity was cleared. +func (m *NodeMutation) VersionsCleared() bool { + return m.clearedversions } -// RemovedEdges returns all edge names that were removed in this mutation. -func (m *GitCommitMutation) RemovedEdges() []string { - edges := make([]string, 0, 1) - if m.removedresults != nil { - edges = append(edges, gitcommit.EdgeResults) +// RemoveVersionIDs removes the "versions" edge to the NodeVersion entity by IDs. +func (m *NodeMutation) RemoveVersionIDs(ids ...uuid.UUID) { + if m.removedversions == nil { + m.removedversions = make(map[uuid.UUID]struct{}) } - return edges -} - -// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with -// the given name in this mutation. -func (m *GitCommitMutation) RemovedIDs(name string) []ent.Value { - switch name { - case gitcommit.EdgeResults: - ids := make([]ent.Value, 0, len(m.removedresults)) - for id := range m.removedresults { - ids = append(ids, id) - } - return ids + for i := range ids { + delete(m.versions, ids[i]) + m.removedversions[ids[i]] = struct{}{} } - return nil } -// ClearedEdges returns all edge names that were cleared in this mutation. -func (m *GitCommitMutation) ClearedEdges() []string { - edges := make([]string, 0, 1) - if m.clearedresults { - edges = append(edges, gitcommit.EdgeResults) +// RemovedVersions returns the removed IDs of the "versions" edge to the NodeVersion entity. +func (m *NodeMutation) RemovedVersionsIDs() (ids []uuid.UUID) { + for id := range m.removedversions { + ids = append(ids, id) } - return edges + return } -// EdgeCleared returns a boolean which indicates if the edge with the given name -// was cleared in this mutation. -func (m *GitCommitMutation) EdgeCleared(name string) bool { - switch name { - case gitcommit.EdgeResults: - return m.clearedresults +// VersionsIDs returns the "versions" edge IDs in the mutation. +func (m *NodeMutation) VersionsIDs() (ids []uuid.UUID) { + for id := range m.versions { + ids = append(ids, id) } - return false + return } -// ClearEdge clears the value of the edge with the given name. It returns an error -// if that edge is not defined in the schema. -func (m *GitCommitMutation) ClearEdge(name string) error { - switch name { - } - return fmt.Errorf("unknown GitCommit unique edge %s", name) +// ResetVersions resets all changes to the "versions" edge. +func (m *NodeMutation) ResetVersions() { + m.versions = nil + m.clearedversions = false + m.removedversions = nil } -// ResetEdge resets all changes to the edge with the given name in this mutation. -// It returns an error if the edge is not defined in the schema. -func (m *GitCommitMutation) ResetEdge(name string) error { - switch name { - case gitcommit.EdgeResults: - m.ResetResults() - return nil +// AddReviewIDs adds the "reviews" edge to the NodeReview entity by ids. +func (m *NodeMutation) AddReviewIDs(ids ...uuid.UUID) { + if m.reviews == nil { + m.reviews = make(map[uuid.UUID]struct{}) + } + for i := range ids { + m.reviews[ids[i]] = struct{}{} } - return fmt.Errorf("unknown GitCommit edge %s", name) } -// NodeMutation represents an operation that mutates the Node nodes in the graph. -type NodeMutation struct { - config - op Op - typ string - id *string - create_time *time.Time - update_time *time.Time - name *string - description *string - author *string - license *string - repository_url *string - icon_url *string - tags *[]string - appendtags []string - clearedFields map[string]struct{} - publisher *string - clearedpublisher bool - versions map[uuid.UUID]struct{} - removedversions map[uuid.UUID]struct{} - clearedversions bool - done bool - oldValue func(context.Context) (*Node, error) - predicates []predicate.Node +// ClearReviews clears the "reviews" edge to the NodeReview entity. +func (m *NodeMutation) ClearReviews() { + m.clearedreviews = true } -var _ ent.Mutation = (*NodeMutation)(nil) - -// nodeOption allows management of the mutation configuration using functional options. -type nodeOption func(*NodeMutation) +// ReviewsCleared reports if the "reviews" edge to the NodeReview entity was cleared. +func (m *NodeMutation) ReviewsCleared() bool { + return m.clearedreviews +} -// newNodeMutation creates new mutation for the Node entity. -func newNodeMutation(c config, op Op, opts ...nodeOption) *NodeMutation { - m := &NodeMutation{ - config: c, - op: op, - typ: TypeNode, - clearedFields: make(map[string]struct{}), +// RemoveReviewIDs removes the "reviews" edge to the NodeReview entity by IDs. +func (m *NodeMutation) RemoveReviewIDs(ids ...uuid.UUID) { + if m.removedreviews == nil { + m.removedreviews = make(map[uuid.UUID]struct{}) } - for _, opt := range opts { - opt(m) + for i := range ids { + delete(m.reviews, ids[i]) + m.removedreviews[ids[i]] = struct{}{} } - return m } -// withNodeID sets the ID field of the mutation. -func withNodeID(id string) nodeOption { - return func(m *NodeMutation) { - var ( - err error - once sync.Once - value *Node - ) - m.oldValue = func(ctx context.Context) (*Node, error) { - once.Do(func() { - if m.done { - err = errors.New("querying old values post mutation is not allowed") - } else { - value, err = m.Client().Node.Get(ctx, id) - } - }) - return value, err - } - m.id = &id +// RemovedReviews returns the removed IDs of the "reviews" edge to the NodeReview entity. +func (m *NodeMutation) RemovedReviewsIDs() (ids []uuid.UUID) { + for id := range m.removedreviews { + ids = append(ids, id) } + return } -// withNode sets the old Node of the mutation. -func withNode(node *Node) nodeOption { - return func(m *NodeMutation) { - m.oldValue = func(context.Context) (*Node, error) { - return node, nil - } - m.id = &node.ID +// ReviewsIDs returns the "reviews" edge IDs in the mutation. +func (m *NodeMutation) ReviewsIDs() (ids []uuid.UUID) { + for id := range m.reviews { + ids = append(ids, id) } + return } -// Client returns a new `ent.Client` from the mutation. If the mutation was -// executed in a transaction (ent.Tx), a transactional client is returned. -func (m NodeMutation) Client() *Client { - client := &Client{config: m.config} - client.init() - return client -} - -// Tx returns an `ent.Tx` for mutations that were executed in transactions; -// it returns an error otherwise. -func (m NodeMutation) Tx() (*Tx, error) { - if _, ok := m.driver.(*txDriver); !ok { - return nil, errors.New("ent: mutation is not running in a transaction") - } - tx := &Tx{config: m.config} - tx.init() - return tx, nil +// ResetReviews resets all changes to the "reviews" edge. +func (m *NodeMutation) ResetReviews() { + m.reviews = nil + m.clearedreviews = false + m.removedreviews = nil } -// SetID sets the value of the id field. Note that this -// operation is only accepted on creation of Node entities. -func (m *NodeMutation) SetID(id string) { - m.id = &id +// Where appends a list predicates to the NodeMutation builder. +func (m *NodeMutation) Where(ps ...predicate.Node) { + m.predicates = append(m.predicates, ps...) } -// ID returns the ID value in the mutation. Note that the ID is only available -// if it was provided to the builder or after it was returned from the database. -func (m *NodeMutation) ID() (id string, exists bool) { - if m.id == nil { - return +// WhereP appends storage-level predicates to the NodeMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *NodeMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.Node, len(ps)) + for i := range ps { + p[i] = ps[i] } - return *m.id, true + m.Where(p...) } -// IDs queries the database and returns the entity ids that match the mutation's predicate. -// That means, if the mutation is applied within a transaction with an isolation level such -// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated -// or updated by the mutation. -func (m *NodeMutation) IDs(ctx context.Context) ([]string, error) { - switch { - case m.op.Is(OpUpdateOne | OpDeleteOne): - id, exists := m.ID() - if exists { - return []string{id}, nil - } - fallthrough - case m.op.Is(OpUpdate | OpDelete): - return m.Client().Node.Query().Where(m.predicates...).IDs(ctx) - default: - return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) - } +// Op returns the operation name. +func (m *NodeMutation) Op() Op { + return m.op } -// SetCreateTime sets the "create_time" field. -func (m *NodeMutation) SetCreateTime(t time.Time) { - m.create_time = &t +// SetOp allows setting the mutation operation. +func (m *NodeMutation) SetOp(op Op) { + m.op = op } -// CreateTime returns the value of the "create_time" field in the mutation. -func (m *NodeMutation) CreateTime() (r time.Time, exists bool) { - v := m.create_time - if v == nil { - return - } - return *v, true +// Type returns the node type of this mutation (Node). +func (m *NodeMutation) Type() string { + return m.typ } -// OldCreateTime returns the old "create_time" field's value of the Node entity. -// If the Node object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *NodeMutation) OldCreateTime(ctx context.Context) (v time.Time, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldCreateTime is only allowed on UpdateOne operations") +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *NodeMutation) Fields() []string { + fields := make([]string, 0, 16) + if m.create_time != nil { + fields = append(fields, node.FieldCreateTime) + } + if m.update_time != nil { + fields = append(fields, node.FieldUpdateTime) + } + if m.publisher != nil { + fields = append(fields, node.FieldPublisherID) + } + if m.name != nil { + fields = append(fields, node.FieldName) + } + if m.description != nil { + fields = append(fields, node.FieldDescription) + } + if m.category != nil { + fields = append(fields, node.FieldCategory) + } + if m.author != nil { + fields = append(fields, node.FieldAuthor) + } + if m.license != nil { + fields = append(fields, node.FieldLicense) + } + if m.repository_url != nil { + fields = append(fields, node.FieldRepositoryURL) + } + if m.icon_url != nil { + fields = append(fields, node.FieldIconURL) + } + if m.tags != nil { + fields = append(fields, node.FieldTags) } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldCreateTime requires an ID field in the mutation") + if m.total_install != nil { + fields = append(fields, node.FieldTotalInstall) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldCreateTime: %w", err) + if m.total_star != nil { + fields = append(fields, node.FieldTotalStar) } - return oldValue.CreateTime, nil + if m.total_review != nil { + fields = append(fields, node.FieldTotalReview) + } + if m.status != nil { + fields = append(fields, node.FieldStatus) + } + if m.status_detail != nil { + fields = append(fields, node.FieldStatusDetail) + } + return fields } -// ResetCreateTime resets all changes to the "create_time" field. -func (m *NodeMutation) ResetCreateTime() { - m.create_time = nil +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *NodeMutation) Field(name string) (ent.Value, bool) { + switch name { + case node.FieldCreateTime: + return m.CreateTime() + case node.FieldUpdateTime: + return m.UpdateTime() + case node.FieldPublisherID: + return m.PublisherID() + case node.FieldName: + return m.Name() + case node.FieldDescription: + return m.Description() + case node.FieldCategory: + return m.Category() + case node.FieldAuthor: + return m.Author() + case node.FieldLicense: + return m.License() + case node.FieldRepositoryURL: + return m.RepositoryURL() + case node.FieldIconURL: + return m.IconURL() + case node.FieldTags: + return m.Tags() + case node.FieldTotalInstall: + return m.TotalInstall() + case node.FieldTotalStar: + return m.TotalStar() + case node.FieldTotalReview: + return m.TotalReview() + case node.FieldStatus: + return m.Status() + case node.FieldStatusDetail: + return m.StatusDetail() + } + return nil, false } -// SetUpdateTime sets the "update_time" field. -func (m *NodeMutation) SetUpdateTime(t time.Time) { - m.update_time = &t +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *NodeMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case node.FieldCreateTime: + return m.OldCreateTime(ctx) + case node.FieldUpdateTime: + return m.OldUpdateTime(ctx) + case node.FieldPublisherID: + return m.OldPublisherID(ctx) + case node.FieldName: + return m.OldName(ctx) + case node.FieldDescription: + return m.OldDescription(ctx) + case node.FieldCategory: + return m.OldCategory(ctx) + case node.FieldAuthor: + return m.OldAuthor(ctx) + case node.FieldLicense: + return m.OldLicense(ctx) + case node.FieldRepositoryURL: + return m.OldRepositoryURL(ctx) + case node.FieldIconURL: + return m.OldIconURL(ctx) + case node.FieldTags: + return m.OldTags(ctx) + case node.FieldTotalInstall: + return m.OldTotalInstall(ctx) + case node.FieldTotalStar: + return m.OldTotalStar(ctx) + case node.FieldTotalReview: + return m.OldTotalReview(ctx) + case node.FieldStatus: + return m.OldStatus(ctx) + case node.FieldStatusDetail: + return m.OldStatusDetail(ctx) + } + return nil, fmt.Errorf("unknown Node field %s", name) } -// UpdateTime returns the value of the "update_time" field in the mutation. -func (m *NodeMutation) UpdateTime() (r time.Time, exists bool) { - v := m.update_time - if v == nil { - return +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *NodeMutation) SetField(name string, value ent.Value) error { + switch name { + case node.FieldCreateTime: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreateTime(v) + return nil + case node.FieldUpdateTime: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdateTime(v) + return nil + case node.FieldPublisherID: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPublisherID(v) + return nil + case node.FieldName: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetName(v) + return nil + case node.FieldDescription: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDescription(v) + return nil + case node.FieldCategory: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCategory(v) + return nil + case node.FieldAuthor: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAuthor(v) + return nil + case node.FieldLicense: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetLicense(v) + return nil + case node.FieldRepositoryURL: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRepositoryURL(v) + return nil + case node.FieldIconURL: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetIconURL(v) + return nil + case node.FieldTags: + v, ok := value.([]string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetTags(v) + return nil + case node.FieldTotalInstall: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetTotalInstall(v) + return nil + case node.FieldTotalStar: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetTotalStar(v) + return nil + case node.FieldTotalReview: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetTotalReview(v) + return nil + case node.FieldStatus: + v, ok := value.(schema.NodeStatus) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetStatus(v) + return nil + case node.FieldStatusDetail: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetStatusDetail(v) + return nil } - return *v, true + return fmt.Errorf("unknown Node field %s", name) } -// OldUpdateTime returns the old "update_time" field's value of the Node entity. -// If the Node object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *NodeMutation) OldUpdateTime(ctx context.Context) (v time.Time, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldUpdateTime is only allowed on UpdateOne operations") +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *NodeMutation) AddedFields() []string { + var fields []string + if m.addtotal_install != nil { + fields = append(fields, node.FieldTotalInstall) } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldUpdateTime requires an ID field in the mutation") + if m.addtotal_star != nil { + fields = append(fields, node.FieldTotalStar) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldUpdateTime: %w", err) + if m.addtotal_review != nil { + fields = append(fields, node.FieldTotalReview) } - return oldValue.UpdateTime, nil -} - -// ResetUpdateTime resets all changes to the "update_time" field. -func (m *NodeMutation) ResetUpdateTime() { - m.update_time = nil + return fields } -// SetPublisherID sets the "publisher_id" field. -func (m *NodeMutation) SetPublisherID(s string) { - m.publisher = &s +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *NodeMutation) AddedField(name string) (ent.Value, bool) { + switch name { + case node.FieldTotalInstall: + return m.AddedTotalInstall() + case node.FieldTotalStar: + return m.AddedTotalStar() + case node.FieldTotalReview: + return m.AddedTotalReview() + } + return nil, false } -// PublisherID returns the value of the "publisher_id" field in the mutation. -func (m *NodeMutation) PublisherID() (r string, exists bool) { - v := m.publisher - if v == nil { - return +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *NodeMutation) AddField(name string, value ent.Value) error { + switch name { + case node.FieldTotalInstall: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddTotalInstall(v) + return nil + case node.FieldTotalStar: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddTotalStar(v) + return nil + case node.FieldTotalReview: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddTotalReview(v) + return nil } - return *v, true + return fmt.Errorf("unknown Node numeric field %s", name) } -// OldPublisherID returns the old "publisher_id" field's value of the Node entity. -// If the Node object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *NodeMutation) OldPublisherID(ctx context.Context) (v string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldPublisherID is only allowed on UpdateOne operations") +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *NodeMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(node.FieldDescription) { + fields = append(fields, node.FieldDescription) } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldPublisherID requires an ID field in the mutation") + if m.FieldCleared(node.FieldCategory) { + fields = append(fields, node.FieldCategory) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldPublisherID: %w", err) + if m.FieldCleared(node.FieldAuthor) { + fields = append(fields, node.FieldAuthor) } - return oldValue.PublisherID, nil + if m.FieldCleared(node.FieldIconURL) { + fields = append(fields, node.FieldIconURL) + } + if m.FieldCleared(node.FieldStatusDetail) { + fields = append(fields, node.FieldStatusDetail) + } + return fields } -// ResetPublisherID resets all changes to the "publisher_id" field. -func (m *NodeMutation) ResetPublisherID() { - m.publisher = nil +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *NodeMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok } -// SetName sets the "name" field. -func (m *NodeMutation) SetName(s string) { - m.name = &s +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *NodeMutation) ClearField(name string) error { + switch name { + case node.FieldDescription: + m.ClearDescription() + return nil + case node.FieldCategory: + m.ClearCategory() + return nil + case node.FieldAuthor: + m.ClearAuthor() + return nil + case node.FieldIconURL: + m.ClearIconURL() + return nil + case node.FieldStatusDetail: + m.ClearStatusDetail() + return nil + } + return fmt.Errorf("unknown Node nullable field %s", name) } -// Name returns the value of the "name" field in the mutation. -func (m *NodeMutation) Name() (r string, exists bool) { - v := m.name - if v == nil { - return +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *NodeMutation) ResetField(name string) error { + switch name { + case node.FieldCreateTime: + m.ResetCreateTime() + return nil + case node.FieldUpdateTime: + m.ResetUpdateTime() + return nil + case node.FieldPublisherID: + m.ResetPublisherID() + return nil + case node.FieldName: + m.ResetName() + return nil + case node.FieldDescription: + m.ResetDescription() + return nil + case node.FieldCategory: + m.ResetCategory() + return nil + case node.FieldAuthor: + m.ResetAuthor() + return nil + case node.FieldLicense: + m.ResetLicense() + return nil + case node.FieldRepositoryURL: + m.ResetRepositoryURL() + return nil + case node.FieldIconURL: + m.ResetIconURL() + return nil + case node.FieldTags: + m.ResetTags() + return nil + case node.FieldTotalInstall: + m.ResetTotalInstall() + return nil + case node.FieldTotalStar: + m.ResetTotalStar() + return nil + case node.FieldTotalReview: + m.ResetTotalReview() + return nil + case node.FieldStatus: + m.ResetStatus() + return nil + case node.FieldStatusDetail: + m.ResetStatusDetail() + return nil } - return *v, true + return fmt.Errorf("unknown Node field %s", name) } -// OldName returns the old "name" field's value of the Node entity. -// If the Node object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *NodeMutation) OldName(ctx context.Context) (v string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldName is only allowed on UpdateOne operations") +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *NodeMutation) AddedEdges() []string { + edges := make([]string, 0, 3) + if m.publisher != nil { + edges = append(edges, node.EdgePublisher) } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldName requires an ID field in the mutation") + if m.versions != nil { + edges = append(edges, node.EdgeVersions) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldName: %w", err) + if m.reviews != nil { + edges = append(edges, node.EdgeReviews) } - return oldValue.Name, nil + return edges } -// ResetName resets all changes to the "name" field. -func (m *NodeMutation) ResetName() { - m.name = nil +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *NodeMutation) AddedIDs(name string) []ent.Value { + switch name { + case node.EdgePublisher: + if id := m.publisher; id != nil { + return []ent.Value{*id} + } + case node.EdgeVersions: + ids := make([]ent.Value, 0, len(m.versions)) + for id := range m.versions { + ids = append(ids, id) + } + return ids + case node.EdgeReviews: + ids := make([]ent.Value, 0, len(m.reviews)) + for id := range m.reviews { + ids = append(ids, id) + } + return ids + } + return nil } -// SetDescription sets the "description" field. -func (m *NodeMutation) SetDescription(s string) { - m.description = &s +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *NodeMutation) RemovedEdges() []string { + edges := make([]string, 0, 3) + if m.removedversions != nil { + edges = append(edges, node.EdgeVersions) + } + if m.removedreviews != nil { + edges = append(edges, node.EdgeReviews) + } + return edges } -// Description returns the value of the "description" field in the mutation. -func (m *NodeMutation) Description() (r string, exists bool) { - v := m.description - if v == nil { - return +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *NodeMutation) RemovedIDs(name string) []ent.Value { + switch name { + case node.EdgeVersions: + ids := make([]ent.Value, 0, len(m.removedversions)) + for id := range m.removedversions { + ids = append(ids, id) + } + return ids + case node.EdgeReviews: + ids := make([]ent.Value, 0, len(m.removedreviews)) + for id := range m.removedreviews { + ids = append(ids, id) + } + return ids } - return *v, true + return nil } -// OldDescription returns the old "description" field's value of the Node entity. -// If the Node object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *NodeMutation) OldDescription(ctx context.Context) (v string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldDescription is only allowed on UpdateOne operations") +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *NodeMutation) ClearedEdges() []string { + edges := make([]string, 0, 3) + if m.clearedpublisher { + edges = append(edges, node.EdgePublisher) } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldDescription requires an ID field in the mutation") + if m.clearedversions { + edges = append(edges, node.EdgeVersions) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldDescription: %w", err) + if m.clearedreviews { + edges = append(edges, node.EdgeReviews) } - return oldValue.Description, nil + return edges } -// ClearDescription clears the value of the "description" field. -func (m *NodeMutation) ClearDescription() { - m.description = nil - m.clearedFields[node.FieldDescription] = struct{}{} +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *NodeMutation) EdgeCleared(name string) bool { + switch name { + case node.EdgePublisher: + return m.clearedpublisher + case node.EdgeVersions: + return m.clearedversions + case node.EdgeReviews: + return m.clearedreviews + } + return false } -// DescriptionCleared returns if the "description" field was cleared in this mutation. -func (m *NodeMutation) DescriptionCleared() bool { - _, ok := m.clearedFields[node.FieldDescription] - return ok +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *NodeMutation) ClearEdge(name string) error { + switch name { + case node.EdgePublisher: + m.ClearPublisher() + return nil + } + return fmt.Errorf("unknown Node unique edge %s", name) } -// ResetDescription resets all changes to the "description" field. -func (m *NodeMutation) ResetDescription() { - m.description = nil - delete(m.clearedFields, node.FieldDescription) +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *NodeMutation) ResetEdge(name string) error { + switch name { + case node.EdgePublisher: + m.ResetPublisher() + return nil + case node.EdgeVersions: + m.ResetVersions() + return nil + case node.EdgeReviews: + m.ResetReviews() + return nil + } + return fmt.Errorf("unknown Node edge %s", name) } -// SetAuthor sets the "author" field. -func (m *NodeMutation) SetAuthor(s string) { - m.author = &s +// NodeReviewMutation represents an operation that mutates the NodeReview nodes in the graph. +type NodeReviewMutation struct { + config + op Op + typ string + id *uuid.UUID + star *int + addstar *int + clearedFields map[string]struct{} + user *string + cleareduser bool + node *string + clearednode bool + done bool + oldValue func(context.Context) (*NodeReview, error) + predicates []predicate.NodeReview } -// Author returns the value of the "author" field in the mutation. -func (m *NodeMutation) Author() (r string, exists bool) { - v := m.author - if v == nil { - return - } - return *v, true -} +var _ ent.Mutation = (*NodeReviewMutation)(nil) -// OldAuthor returns the old "author" field's value of the Node entity. -// If the Node object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *NodeMutation) OldAuthor(ctx context.Context) (v string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldAuthor is only allowed on UpdateOne operations") +// nodereviewOption allows management of the mutation configuration using functional options. +type nodereviewOption func(*NodeReviewMutation) + +// newNodeReviewMutation creates new mutation for the NodeReview entity. +func newNodeReviewMutation(c config, op Op, opts ...nodereviewOption) *NodeReviewMutation { + m := &NodeReviewMutation{ + config: c, + op: op, + typ: TypeNodeReview, + clearedFields: make(map[string]struct{}), } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldAuthor requires an ID field in the mutation") + for _, opt := range opts { + opt(m) } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldAuthor: %w", err) + return m +} + +// withNodeReviewID sets the ID field of the mutation. +func withNodeReviewID(id uuid.UUID) nodereviewOption { + return func(m *NodeReviewMutation) { + var ( + err error + once sync.Once + value *NodeReview + ) + m.oldValue = func(ctx context.Context) (*NodeReview, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().NodeReview.Get(ctx, id) + } + }) + return value, err + } + m.id = &id } - return oldValue.Author, nil } -// ClearAuthor clears the value of the "author" field. -func (m *NodeMutation) ClearAuthor() { - m.author = nil - m.clearedFields[node.FieldAuthor] = struct{}{} +// withNodeReview sets the old NodeReview of the mutation. +func withNodeReview(node *NodeReview) nodereviewOption { + return func(m *NodeReviewMutation) { + m.oldValue = func(context.Context) (*NodeReview, error) { + return node, nil + } + m.id = &node.ID + } } -// AuthorCleared returns if the "author" field was cleared in this mutation. -func (m *NodeMutation) AuthorCleared() bool { - _, ok := m.clearedFields[node.FieldAuthor] - return ok +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m NodeReviewMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client } -// ResetAuthor resets all changes to the "author" field. -func (m *NodeMutation) ResetAuthor() { - m.author = nil - delete(m.clearedFields, node.FieldAuthor) +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m NodeReviewMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil } -// SetLicense sets the "license" field. -func (m *NodeMutation) SetLicense(s string) { - m.license = &s +// SetID sets the value of the id field. Note that this +// operation is only accepted on creation of NodeReview entities. +func (m *NodeReviewMutation) SetID(id uuid.UUID) { + m.id = &id } -// License returns the value of the "license" field in the mutation. -func (m *NodeMutation) License() (r string, exists bool) { - v := m.license - if v == nil { +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *NodeReviewMutation) ID() (id uuid.UUID, exists bool) { + if m.id == nil { return } - return *v, true + return *m.id, true } -// OldLicense returns the old "license" field's value of the Node entity. -// If the Node object wasn't provided to the builder, the object is fetched from the database. -// An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *NodeMutation) OldLicense(ctx context.Context) (v string, err error) { - if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldLicense is only allowed on UpdateOne operations") - } - if m.id == nil || m.oldValue == nil { - return v, errors.New("OldLicense requires an ID field in the mutation") - } - oldValue, err := m.oldValue(ctx) - if err != nil { - return v, fmt.Errorf("querying old value for OldLicense: %w", err) +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *NodeReviewMutation) IDs(ctx context.Context) ([]uuid.UUID, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []uuid.UUID{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().NodeReview.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) } - return oldValue.License, nil -} - -// ResetLicense resets all changes to the "license" field. -func (m *NodeMutation) ResetLicense() { - m.license = nil } -// SetRepositoryURL sets the "repository_url" field. -func (m *NodeMutation) SetRepositoryURL(s string) { - m.repository_url = &s +// SetNodeID sets the "node_id" field. +func (m *NodeReviewMutation) SetNodeID(s string) { + m.node = &s } -// RepositoryURL returns the value of the "repository_url" field in the mutation. -func (m *NodeMutation) RepositoryURL() (r string, exists bool) { - v := m.repository_url +// NodeID returns the value of the "node_id" field in the mutation. +func (m *NodeReviewMutation) NodeID() (r string, exists bool) { + v := m.node if v == nil { return } return *v, true } -// OldRepositoryURL returns the old "repository_url" field's value of the Node entity. -// If the Node object wasn't provided to the builder, the object is fetched from the database. +// OldNodeID returns the old "node_id" field's value of the NodeReview entity. +// If the NodeReview object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *NodeMutation) OldRepositoryURL(ctx context.Context) (v string, err error) { +func (m *NodeReviewMutation) OldNodeID(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldRepositoryURL is only allowed on UpdateOne operations") + return v, errors.New("OldNodeID is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldRepositoryURL requires an ID field in the mutation") + return v, errors.New("OldNodeID requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldRepositoryURL: %w", err) + return v, fmt.Errorf("querying old value for OldNodeID: %w", err) } - return oldValue.RepositoryURL, nil + return oldValue.NodeID, nil } -// ResetRepositoryURL resets all changes to the "repository_url" field. -func (m *NodeMutation) ResetRepositoryURL() { - m.repository_url = nil +// ResetNodeID resets all changes to the "node_id" field. +func (m *NodeReviewMutation) ResetNodeID() { + m.node = nil } -// SetIconURL sets the "icon_url" field. -func (m *NodeMutation) SetIconURL(s string) { - m.icon_url = &s +// SetUserID sets the "user_id" field. +func (m *NodeReviewMutation) SetUserID(s string) { + m.user = &s } -// IconURL returns the value of the "icon_url" field in the mutation. -func (m *NodeMutation) IconURL() (r string, exists bool) { - v := m.icon_url +// UserID returns the value of the "user_id" field in the mutation. +func (m *NodeReviewMutation) UserID() (r string, exists bool) { + v := m.user if v == nil { return } return *v, true } -// OldIconURL returns the old "icon_url" field's value of the Node entity. -// If the Node object wasn't provided to the builder, the object is fetched from the database. +// OldUserID returns the old "user_id" field's value of the NodeReview entity. +// If the NodeReview object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *NodeMutation) OldIconURL(ctx context.Context) (v string, err error) { +func (m *NodeReviewMutation) OldUserID(ctx context.Context) (v string, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldIconURL is only allowed on UpdateOne operations") + return v, errors.New("OldUserID is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldIconURL requires an ID field in the mutation") + return v, errors.New("OldUserID requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldIconURL: %w", err) + return v, fmt.Errorf("querying old value for OldUserID: %w", err) } - return oldValue.IconURL, nil -} - -// ClearIconURL clears the value of the "icon_url" field. -func (m *NodeMutation) ClearIconURL() { - m.icon_url = nil - m.clearedFields[node.FieldIconURL] = struct{}{} -} - -// IconURLCleared returns if the "icon_url" field was cleared in this mutation. -func (m *NodeMutation) IconURLCleared() bool { - _, ok := m.clearedFields[node.FieldIconURL] - return ok + return oldValue.UserID, nil } -// ResetIconURL resets all changes to the "icon_url" field. -func (m *NodeMutation) ResetIconURL() { - m.icon_url = nil - delete(m.clearedFields, node.FieldIconURL) +// ResetUserID resets all changes to the "user_id" field. +func (m *NodeReviewMutation) ResetUserID() { + m.user = nil } -// SetTags sets the "tags" field. -func (m *NodeMutation) SetTags(s []string) { - m.tags = &s - m.appendtags = nil +// SetStar sets the "star" field. +func (m *NodeReviewMutation) SetStar(i int) { + m.star = &i + m.addstar = nil } -// Tags returns the value of the "tags" field in the mutation. -func (m *NodeMutation) Tags() (r []string, exists bool) { - v := m.tags +// Star returns the value of the "star" field in the mutation. +func (m *NodeReviewMutation) Star() (r int, exists bool) { + v := m.star if v == nil { return } return *v, true } -// OldTags returns the old "tags" field's value of the Node entity. -// If the Node object wasn't provided to the builder, the object is fetched from the database. +// OldStar returns the old "star" field's value of the NodeReview entity. +// If the NodeReview object wasn't provided to the builder, the object is fetched from the database. // An error is returned if the mutation operation is not UpdateOne, or the database query fails. -func (m *NodeMutation) OldTags(ctx context.Context) (v []string, err error) { +func (m *NodeReviewMutation) OldStar(ctx context.Context) (v int, err error) { if !m.op.Is(OpUpdateOne) { - return v, errors.New("OldTags is only allowed on UpdateOne operations") + return v, errors.New("OldStar is only allowed on UpdateOne operations") } if m.id == nil || m.oldValue == nil { - return v, errors.New("OldTags requires an ID field in the mutation") + return v, errors.New("OldStar requires an ID field in the mutation") } oldValue, err := m.oldValue(ctx) if err != nil { - return v, fmt.Errorf("querying old value for OldTags: %w", err) - } - return oldValue.Tags, nil -} - -// AppendTags adds s to the "tags" field. -func (m *NodeMutation) AppendTags(s []string) { - m.appendtags = append(m.appendtags, s...) -} - -// AppendedTags returns the list of values that were appended to the "tags" field in this mutation. -func (m *NodeMutation) AppendedTags() ([]string, bool) { - if len(m.appendtags) == 0 { - return nil, false + return v, fmt.Errorf("querying old value for OldStar: %w", err) } - return m.appendtags, true -} - -// ResetTags resets all changes to the "tags" field. -func (m *NodeMutation) ResetTags() { - m.tags = nil - m.appendtags = nil -} - -// ClearPublisher clears the "publisher" edge to the Publisher entity. -func (m *NodeMutation) ClearPublisher() { - m.clearedpublisher = true - m.clearedFields[node.FieldPublisherID] = struct{}{} -} - -// PublisherCleared reports if the "publisher" edge to the Publisher entity was cleared. -func (m *NodeMutation) PublisherCleared() bool { - return m.clearedpublisher + return oldValue.Star, nil } -// PublisherIDs returns the "publisher" edge IDs in the mutation. -// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use -// PublisherID instead. It exists only for internal usage by the builders. -func (m *NodeMutation) PublisherIDs() (ids []string) { - if id := m.publisher; id != nil { - ids = append(ids, *id) +// AddStar adds i to the "star" field. +func (m *NodeReviewMutation) AddStar(i int) { + if m.addstar != nil { + *m.addstar += i + } else { + m.addstar = &i } - return -} - -// ResetPublisher resets all changes to the "publisher" edge. -func (m *NodeMutation) ResetPublisher() { - m.publisher = nil - m.clearedpublisher = false } -// AddVersionIDs adds the "versions" edge to the NodeVersion entity by ids. -func (m *NodeMutation) AddVersionIDs(ids ...uuid.UUID) { - if m.versions == nil { - m.versions = make(map[uuid.UUID]struct{}) - } - for i := range ids { - m.versions[ids[i]] = struct{}{} +// AddedStar returns the value that was added to the "star" field in this mutation. +func (m *NodeReviewMutation) AddedStar() (r int, exists bool) { + v := m.addstar + if v == nil { + return } + return *v, true } -// ClearVersions clears the "versions" edge to the NodeVersion entity. -func (m *NodeMutation) ClearVersions() { - m.clearedversions = true +// ResetStar resets all changes to the "star" field. +func (m *NodeReviewMutation) ResetStar() { + m.star = nil + m.addstar = nil } -// VersionsCleared reports if the "versions" edge to the NodeVersion entity was cleared. -func (m *NodeMutation) VersionsCleared() bool { - return m.clearedversions +// ClearUser clears the "user" edge to the User entity. +func (m *NodeReviewMutation) ClearUser() { + m.cleareduser = true + m.clearedFields[nodereview.FieldUserID] = struct{}{} } -// RemoveVersionIDs removes the "versions" edge to the NodeVersion entity by IDs. -func (m *NodeMutation) RemoveVersionIDs(ids ...uuid.UUID) { - if m.removedversions == nil { - m.removedversions = make(map[uuid.UUID]struct{}) - } - for i := range ids { - delete(m.versions, ids[i]) - m.removedversions[ids[i]] = struct{}{} - } +// UserCleared reports if the "user" edge to the User entity was cleared. +func (m *NodeReviewMutation) UserCleared() bool { + return m.cleareduser } -// RemovedVersions returns the removed IDs of the "versions" edge to the NodeVersion entity. -func (m *NodeMutation) RemovedVersionsIDs() (ids []uuid.UUID) { - for id := range m.removedversions { - ids = append(ids, id) +// UserIDs returns the "user" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// UserID instead. It exists only for internal usage by the builders. +func (m *NodeReviewMutation) UserIDs() (ids []string) { + if id := m.user; id != nil { + ids = append(ids, *id) } return } -// VersionsIDs returns the "versions" edge IDs in the mutation. -func (m *NodeMutation) VersionsIDs() (ids []uuid.UUID) { - for id := range m.versions { - ids = append(ids, id) +// ResetUser resets all changes to the "user" edge. +func (m *NodeReviewMutation) ResetUser() { + m.user = nil + m.cleareduser = false +} + +// ClearNode clears the "node" edge to the Node entity. +func (m *NodeReviewMutation) ClearNode() { + m.clearednode = true + m.clearedFields[nodereview.FieldNodeID] = struct{}{} +} + +// NodeCleared reports if the "node" edge to the Node entity was cleared. +func (m *NodeReviewMutation) NodeCleared() bool { + return m.clearednode +} + +// NodeIDs returns the "node" edge IDs in the mutation. +// Note that IDs always returns len(IDs) <= 1 for unique edges, and you should use +// NodeID instead. It exists only for internal usage by the builders. +func (m *NodeReviewMutation) NodeIDs() (ids []string) { + if id := m.node; id != nil { + ids = append(ids, *id) } return } -// ResetVersions resets all changes to the "versions" edge. -func (m *NodeMutation) ResetVersions() { - m.versions = nil - m.clearedversions = false - m.removedversions = nil +// ResetNode resets all changes to the "node" edge. +func (m *NodeReviewMutation) ResetNode() { + m.node = nil + m.clearednode = false } -// Where appends a list predicates to the NodeMutation builder. -func (m *NodeMutation) Where(ps ...predicate.Node) { +// Where appends a list predicates to the NodeReviewMutation builder. +func (m *NodeReviewMutation) Where(ps ...predicate.NodeReview) { m.predicates = append(m.predicates, ps...) } -// WhereP appends storage-level predicates to the NodeMutation builder. Using this method, +// WhereP appends storage-level predicates to the NodeReviewMutation builder. Using this method, // users can use type-assertion to append predicates that do not depend on any generated package. -func (m *NodeMutation) WhereP(ps ...func(*sql.Selector)) { - p := make([]predicate.Node, len(ps)) +func (m *NodeReviewMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.NodeReview, len(ps)) for i := range ps { p[i] = ps[i] } @@ -2736,54 +4655,33 @@ func (m *NodeMutation) WhereP(ps ...func(*sql.Selector)) { } // Op returns the operation name. -func (m *NodeMutation) Op() Op { +func (m *NodeReviewMutation) Op() Op { return m.op } // SetOp allows setting the mutation operation. -func (m *NodeMutation) SetOp(op Op) { +func (m *NodeReviewMutation) SetOp(op Op) { m.op = op } -// Type returns the node type of this mutation (Node). -func (m *NodeMutation) Type() string { +// Type returns the node type of this mutation (NodeReview). +func (m *NodeReviewMutation) Type() string { return m.typ } // Fields returns all fields that were changed during this mutation. Note that in // order to get all numeric fields that were incremented/decremented, call // AddedFields(). -func (m *NodeMutation) Fields() []string { - fields := make([]string, 0, 10) - if m.create_time != nil { - fields = append(fields, node.FieldCreateTime) - } - if m.update_time != nil { - fields = append(fields, node.FieldUpdateTime) - } - if m.publisher != nil { - fields = append(fields, node.FieldPublisherID) - } - if m.name != nil { - fields = append(fields, node.FieldName) - } - if m.description != nil { - fields = append(fields, node.FieldDescription) - } - if m.author != nil { - fields = append(fields, node.FieldAuthor) - } - if m.license != nil { - fields = append(fields, node.FieldLicense) - } - if m.repository_url != nil { - fields = append(fields, node.FieldRepositoryURL) +func (m *NodeReviewMutation) Fields() []string { + fields := make([]string, 0, 3) + if m.node != nil { + fields = append(fields, nodereview.FieldNodeID) } - if m.icon_url != nil { - fields = append(fields, node.FieldIconURL) + if m.user != nil { + fields = append(fields, nodereview.FieldUserID) } - if m.tags != nil { - fields = append(fields, node.FieldTags) + if m.star != nil { + fields = append(fields, nodereview.FieldStar) } return fields } @@ -2791,28 +4689,14 @@ func (m *NodeMutation) Fields() []string { // Field returns the value of a field with the given name. The second boolean // return value indicates that this field was not set, or was not defined in the // schema. -func (m *NodeMutation) Field(name string) (ent.Value, bool) { +func (m *NodeReviewMutation) Field(name string) (ent.Value, bool) { switch name { - case node.FieldCreateTime: - return m.CreateTime() - case node.FieldUpdateTime: - return m.UpdateTime() - case node.FieldPublisherID: - return m.PublisherID() - case node.FieldName: - return m.Name() - case node.FieldDescription: - return m.Description() - case node.FieldAuthor: - return m.Author() - case node.FieldLicense: - return m.License() - case node.FieldRepositoryURL: - return m.RepositoryURL() - case node.FieldIconURL: - return m.IconURL() - case node.FieldTags: - return m.Tags() + case nodereview.FieldNodeID: + return m.NodeID() + case nodereview.FieldUserID: + return m.UserID() + case nodereview.FieldStar: + return m.Star() } return nil, false } @@ -2820,311 +4704,211 @@ func (m *NodeMutation) Field(name string) (ent.Value, bool) { // OldField returns the old value of the field from the database. An error is // returned if the mutation operation is not UpdateOne, or the query to the // database failed. -func (m *NodeMutation) OldField(ctx context.Context, name string) (ent.Value, error) { +func (m *NodeReviewMutation) OldField(ctx context.Context, name string) (ent.Value, error) { switch name { - case node.FieldCreateTime: - return m.OldCreateTime(ctx) - case node.FieldUpdateTime: - return m.OldUpdateTime(ctx) - case node.FieldPublisherID: - return m.OldPublisherID(ctx) - case node.FieldName: - return m.OldName(ctx) - case node.FieldDescription: - return m.OldDescription(ctx) - case node.FieldAuthor: - return m.OldAuthor(ctx) - case node.FieldLicense: - return m.OldLicense(ctx) - case node.FieldRepositoryURL: - return m.OldRepositoryURL(ctx) - case node.FieldIconURL: - return m.OldIconURL(ctx) - case node.FieldTags: - return m.OldTags(ctx) + case nodereview.FieldNodeID: + return m.OldNodeID(ctx) + case nodereview.FieldUserID: + return m.OldUserID(ctx) + case nodereview.FieldStar: + return m.OldStar(ctx) } - return nil, fmt.Errorf("unknown Node field %s", name) + return nil, fmt.Errorf("unknown NodeReview field %s", name) } // SetField sets the value of a field with the given name. It returns an error if // the field is not defined in the schema, or if the type mismatched the field // type. -func (m *NodeMutation) SetField(name string, value ent.Value) error { +func (m *NodeReviewMutation) SetField(name string, value ent.Value) error { switch name { - case node.FieldCreateTime: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetCreateTime(v) - return nil - case node.FieldUpdateTime: - v, ok := value.(time.Time) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetUpdateTime(v) - return nil - case node.FieldPublisherID: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetPublisherID(v) - return nil - case node.FieldName: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetName(v) - return nil - case node.FieldDescription: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetDescription(v) - return nil - case node.FieldAuthor: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetAuthor(v) - return nil - case node.FieldLicense: - v, ok := value.(string) - if !ok { - return fmt.Errorf("unexpected type %T for field %s", value, name) - } - m.SetLicense(v) - return nil - case node.FieldRepositoryURL: + case nodereview.FieldNodeID: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetRepositoryURL(v) + m.SetNodeID(v) return nil - case node.FieldIconURL: + case nodereview.FieldUserID: v, ok := value.(string) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetIconURL(v) + m.SetUserID(v) return nil - case node.FieldTags: - v, ok := value.([]string) + case nodereview.FieldStar: + v, ok := value.(int) if !ok { return fmt.Errorf("unexpected type %T for field %s", value, name) } - m.SetTags(v) + m.SetStar(v) return nil } - return fmt.Errorf("unknown Node field %s", name) + return fmt.Errorf("unknown NodeReview field %s", name) } // AddedFields returns all numeric fields that were incremented/decremented during // this mutation. -func (m *NodeMutation) AddedFields() []string { - return nil +func (m *NodeReviewMutation) AddedFields() []string { + var fields []string + if m.addstar != nil { + fields = append(fields, nodereview.FieldStar) + } + return fields } // AddedField returns the numeric value that was incremented/decremented on a field // with the given name. The second boolean return value indicates that this field // was not set, or was not defined in the schema. -func (m *NodeMutation) AddedField(name string) (ent.Value, bool) { +func (m *NodeReviewMutation) AddedField(name string) (ent.Value, bool) { + switch name { + case nodereview.FieldStar: + return m.AddedStar() + } return nil, false } // AddField adds the value to the field with the given name. It returns an error if // the field is not defined in the schema, or if the type mismatched the field // type. -func (m *NodeMutation) AddField(name string, value ent.Value) error { +func (m *NodeReviewMutation) AddField(name string, value ent.Value) error { switch name { + case nodereview.FieldStar: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddStar(v) + return nil } - return fmt.Errorf("unknown Node numeric field %s", name) + return fmt.Errorf("unknown NodeReview numeric field %s", name) } // ClearedFields returns all nullable fields that were cleared during this // mutation. -func (m *NodeMutation) ClearedFields() []string { - var fields []string - if m.FieldCleared(node.FieldDescription) { - fields = append(fields, node.FieldDescription) - } - if m.FieldCleared(node.FieldAuthor) { - fields = append(fields, node.FieldAuthor) - } - if m.FieldCleared(node.FieldIconURL) { - fields = append(fields, node.FieldIconURL) - } - return fields +func (m *NodeReviewMutation) ClearedFields() []string { + return nil } // FieldCleared returns a boolean indicating if a field with the given name was // cleared in this mutation. -func (m *NodeMutation) FieldCleared(name string) bool { +func (m *NodeReviewMutation) FieldCleared(name string) bool { _, ok := m.clearedFields[name] return ok } // ClearField clears the value of the field with the given name. It returns an // error if the field is not defined in the schema. -func (m *NodeMutation) ClearField(name string) error { - switch name { - case node.FieldDescription: - m.ClearDescription() - return nil - case node.FieldAuthor: - m.ClearAuthor() - return nil - case node.FieldIconURL: - m.ClearIconURL() - return nil - } - return fmt.Errorf("unknown Node nullable field %s", name) +func (m *NodeReviewMutation) ClearField(name string) error { + return fmt.Errorf("unknown NodeReview nullable field %s", name) } // ResetField resets all changes in the mutation for the field with the given name. // It returns an error if the field is not defined in the schema. -func (m *NodeMutation) ResetField(name string) error { +func (m *NodeReviewMutation) ResetField(name string) error { switch name { - case node.FieldCreateTime: - m.ResetCreateTime() - return nil - case node.FieldUpdateTime: - m.ResetUpdateTime() - return nil - case node.FieldPublisherID: - m.ResetPublisherID() - return nil - case node.FieldName: - m.ResetName() - return nil - case node.FieldDescription: - m.ResetDescription() - return nil - case node.FieldAuthor: - m.ResetAuthor() - return nil - case node.FieldLicense: - m.ResetLicense() - return nil - case node.FieldRepositoryURL: - m.ResetRepositoryURL() + case nodereview.FieldNodeID: + m.ResetNodeID() return nil - case node.FieldIconURL: - m.ResetIconURL() + case nodereview.FieldUserID: + m.ResetUserID() return nil - case node.FieldTags: - m.ResetTags() + case nodereview.FieldStar: + m.ResetStar() return nil } - return fmt.Errorf("unknown Node field %s", name) + return fmt.Errorf("unknown NodeReview field %s", name) } // AddedEdges returns all edge names that were set/added in this mutation. -func (m *NodeMutation) AddedEdges() []string { - edges := make([]string, 0, 2) - if m.publisher != nil { - edges = append(edges, node.EdgePublisher) +func (m *NodeReviewMutation) AddedEdges() []string { + edges := make([]string, 0, 2) + if m.user != nil { + edges = append(edges, nodereview.EdgeUser) } - if m.versions != nil { - edges = append(edges, node.EdgeVersions) + if m.node != nil { + edges = append(edges, nodereview.EdgeNode) } return edges } // AddedIDs returns all IDs (to other nodes) that were added for the given edge // name in this mutation. -func (m *NodeMutation) AddedIDs(name string) []ent.Value { +func (m *NodeReviewMutation) AddedIDs(name string) []ent.Value { switch name { - case node.EdgePublisher: - if id := m.publisher; id != nil { + case nodereview.EdgeUser: + if id := m.user; id != nil { return []ent.Value{*id} } - case node.EdgeVersions: - ids := make([]ent.Value, 0, len(m.versions)) - for id := range m.versions { - ids = append(ids, id) + case nodereview.EdgeNode: + if id := m.node; id != nil { + return []ent.Value{*id} } - return ids } return nil } // RemovedEdges returns all edge names that were removed in this mutation. -func (m *NodeMutation) RemovedEdges() []string { +func (m *NodeReviewMutation) RemovedEdges() []string { edges := make([]string, 0, 2) - if m.removedversions != nil { - edges = append(edges, node.EdgeVersions) - } return edges } // RemovedIDs returns all IDs (to other nodes) that were removed for the edge with // the given name in this mutation. -func (m *NodeMutation) RemovedIDs(name string) []ent.Value { - switch name { - case node.EdgeVersions: - ids := make([]ent.Value, 0, len(m.removedversions)) - for id := range m.removedversions { - ids = append(ids, id) - } - return ids - } +func (m *NodeReviewMutation) RemovedIDs(name string) []ent.Value { return nil } // ClearedEdges returns all edge names that were cleared in this mutation. -func (m *NodeMutation) ClearedEdges() []string { +func (m *NodeReviewMutation) ClearedEdges() []string { edges := make([]string, 0, 2) - if m.clearedpublisher { - edges = append(edges, node.EdgePublisher) + if m.cleareduser { + edges = append(edges, nodereview.EdgeUser) } - if m.clearedversions { - edges = append(edges, node.EdgeVersions) + if m.clearednode { + edges = append(edges, nodereview.EdgeNode) } return edges } // EdgeCleared returns a boolean which indicates if the edge with the given name // was cleared in this mutation. -func (m *NodeMutation) EdgeCleared(name string) bool { +func (m *NodeReviewMutation) EdgeCleared(name string) bool { switch name { - case node.EdgePublisher: - return m.clearedpublisher - case node.EdgeVersions: - return m.clearedversions + case nodereview.EdgeUser: + return m.cleareduser + case nodereview.EdgeNode: + return m.clearednode } return false } // ClearEdge clears the value of the edge with the given name. It returns an error // if that edge is not defined in the schema. -func (m *NodeMutation) ClearEdge(name string) error { +func (m *NodeReviewMutation) ClearEdge(name string) error { switch name { - case node.EdgePublisher: - m.ClearPublisher() + case nodereview.EdgeUser: + m.ClearUser() + return nil + case nodereview.EdgeNode: + m.ClearNode() return nil } - return fmt.Errorf("unknown Node unique edge %s", name) + return fmt.Errorf("unknown NodeReview unique edge %s", name) } // ResetEdge resets all changes to the edge with the given name in this mutation. // It returns an error if the edge is not defined in the schema. -func (m *NodeMutation) ResetEdge(name string) error { +func (m *NodeReviewMutation) ResetEdge(name string) error { switch name { - case node.EdgePublisher: - m.ResetPublisher() + case nodereview.EdgeUser: + m.ResetUser() return nil - case node.EdgeVersions: - m.ResetVersions() + case nodereview.EdgeNode: + m.ResetNode() return nil } - return fmt.Errorf("unknown Node edge %s", name) + return fmt.Errorf("unknown NodeReview edge %s", name) } // NodeVersionMutation represents an operation that mutates the NodeVersion nodes in the graph. @@ -3140,6 +4924,8 @@ type NodeVersionMutation struct { pip_dependencies *[]string appendpip_dependencies []string deprecated *bool + status *schema.NodeVersionStatus + status_reason *string clearedFields map[string]struct{} node *string clearednode bool @@ -3534,6 +5320,78 @@ func (m *NodeVersionMutation) ResetDeprecated() { m.deprecated = nil } +// SetStatus sets the "status" field. +func (m *NodeVersionMutation) SetStatus(svs schema.NodeVersionStatus) { + m.status = &svs +} + +// Status returns the value of the "status" field in the mutation. +func (m *NodeVersionMutation) Status() (r schema.NodeVersionStatus, exists bool) { + v := m.status + if v == nil { + return + } + return *v, true +} + +// OldStatus returns the old "status" field's value of the NodeVersion entity. +// If the NodeVersion object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *NodeVersionMutation) OldStatus(ctx context.Context) (v schema.NodeVersionStatus, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldStatus is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldStatus requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldStatus: %w", err) + } + return oldValue.Status, nil +} + +// ResetStatus resets all changes to the "status" field. +func (m *NodeVersionMutation) ResetStatus() { + m.status = nil +} + +// SetStatusReason sets the "status_reason" field. +func (m *NodeVersionMutation) SetStatusReason(s string) { + m.status_reason = &s +} + +// StatusReason returns the value of the "status_reason" field in the mutation. +func (m *NodeVersionMutation) StatusReason() (r string, exists bool) { + v := m.status_reason + if v == nil { + return + } + return *v, true +} + +// OldStatusReason returns the old "status_reason" field's value of the NodeVersion entity. +// If the NodeVersion object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *NodeVersionMutation) OldStatusReason(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldStatusReason is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldStatusReason requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldStatusReason: %w", err) + } + return oldValue.StatusReason, nil +} + +// ResetStatusReason resets all changes to the "status_reason" field. +func (m *NodeVersionMutation) ResetStatusReason() { + m.status_reason = nil +} + // ClearNode clears the "node" edge to the Node entity. func (m *NodeVersionMutation) ClearNode() { m.clearednode = true @@ -3634,7 +5492,7 @@ func (m *NodeVersionMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *NodeVersionMutation) Fields() []string { - fields := make([]string, 0, 7) + fields := make([]string, 0, 9) if m.create_time != nil { fields = append(fields, nodeversion.FieldCreateTime) } @@ -3656,6 +5514,12 @@ func (m *NodeVersionMutation) Fields() []string { if m.deprecated != nil { fields = append(fields, nodeversion.FieldDeprecated) } + if m.status != nil { + fields = append(fields, nodeversion.FieldStatus) + } + if m.status_reason != nil { + fields = append(fields, nodeversion.FieldStatusReason) + } return fields } @@ -3678,6 +5542,10 @@ func (m *NodeVersionMutation) Field(name string) (ent.Value, bool) { return m.PipDependencies() case nodeversion.FieldDeprecated: return m.Deprecated() + case nodeversion.FieldStatus: + return m.Status() + case nodeversion.FieldStatusReason: + return m.StatusReason() } return nil, false } @@ -3701,6 +5569,10 @@ func (m *NodeVersionMutation) OldField(ctx context.Context, name string) (ent.Va return m.OldPipDependencies(ctx) case nodeversion.FieldDeprecated: return m.OldDeprecated(ctx) + case nodeversion.FieldStatus: + return m.OldStatus(ctx) + case nodeversion.FieldStatusReason: + return m.OldStatusReason(ctx) } return nil, fmt.Errorf("unknown NodeVersion field %s", name) } @@ -3759,6 +5631,20 @@ func (m *NodeVersionMutation) SetField(name string, value ent.Value) error { } m.SetDeprecated(v) return nil + case nodeversion.FieldStatus: + v, ok := value.(schema.NodeVersionStatus) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetStatus(v) + return nil + case nodeversion.FieldStatusReason: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetStatusReason(v) + return nil } return fmt.Errorf("unknown NodeVersion field %s", name) } @@ -3838,6 +5724,12 @@ func (m *NodeVersionMutation) ResetField(name string) error { case nodeversion.FieldDeprecated: m.ResetDeprecated() return nil + case nodeversion.FieldStatus: + m.ResetStatus() + return nil + case nodeversion.FieldStatusReason: + m.ResetStatusReason() + return nil } return fmt.Errorf("unknown NodeVersion field %s", name) } @@ -4604,6 +6496,7 @@ type PublisherMutation struct { support_email *string source_code_repo *string logo_url *string + status *schema.PublisherStatusType clearedFields map[string]struct{} publisher_permissions map[int]struct{} removedpublisher_permissions map[int]struct{} @@ -5076,6 +6969,42 @@ func (m *PublisherMutation) ResetLogoURL() { delete(m.clearedFields, publisher.FieldLogoURL) } +// SetStatus sets the "status" field. +func (m *PublisherMutation) SetStatus(sst schema.PublisherStatusType) { + m.status = &sst +} + +// Status returns the value of the "status" field in the mutation. +func (m *PublisherMutation) Status() (r schema.PublisherStatusType, exists bool) { + v := m.status + if v == nil { + return + } + return *v, true +} + +// OldStatus returns the old "status" field's value of the Publisher entity. +// If the Publisher object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *PublisherMutation) OldStatus(ctx context.Context) (v schema.PublisherStatusType, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldStatus is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldStatus requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldStatus: %w", err) + } + return oldValue.Status, nil +} + +// ResetStatus resets all changes to the "status" field. +func (m *PublisherMutation) ResetStatus() { + m.status = nil +} + // AddPublisherPermissionIDs adds the "publisher_permissions" edge to the PublisherPermission entity by ids. func (m *PublisherMutation) AddPublisherPermissionIDs(ids ...int) { if m.publisher_permissions == nil { @@ -5272,7 +7201,7 @@ func (m *PublisherMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *PublisherMutation) Fields() []string { - fields := make([]string, 0, 8) + fields := make([]string, 0, 9) if m.create_time != nil { fields = append(fields, publisher.FieldCreateTime) } @@ -5297,6 +7226,9 @@ func (m *PublisherMutation) Fields() []string { if m.logo_url != nil { fields = append(fields, publisher.FieldLogoURL) } + if m.status != nil { + fields = append(fields, publisher.FieldStatus) + } return fields } @@ -5321,6 +7253,8 @@ func (m *PublisherMutation) Field(name string) (ent.Value, bool) { return m.SourceCodeRepo() case publisher.FieldLogoURL: return m.LogoURL() + case publisher.FieldStatus: + return m.Status() } return nil, false } @@ -5346,6 +7280,8 @@ func (m *PublisherMutation) OldField(ctx context.Context, name string) (ent.Valu return m.OldSourceCodeRepo(ctx) case publisher.FieldLogoURL: return m.OldLogoURL(ctx) + case publisher.FieldStatus: + return m.OldStatus(ctx) } return nil, fmt.Errorf("unknown Publisher field %s", name) } @@ -5411,6 +7347,13 @@ func (m *PublisherMutation) SetField(name string, value ent.Value) error { } m.SetLogoURL(v) return nil + case publisher.FieldStatus: + v, ok := value.(schema.PublisherStatusType) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetStatus(v) + return nil } return fmt.Errorf("unknown Publisher field %s", name) } @@ -5517,6 +7460,9 @@ func (m *PublisherMutation) ResetField(name string) error { case publisher.FieldLogoURL: m.ResetLogoURL() return nil + case publisher.FieldStatus: + m.ResetStatus() + return nil } return fmt.Errorf("unknown Publisher field %s", name) } @@ -6900,10 +8846,14 @@ type UserMutation struct { name *string is_approved *bool is_admin *bool + status *schema.UserStatusType clearedFields map[string]struct{} publisher_permissions map[int]struct{} removedpublisher_permissions map[int]struct{} clearedpublisher_permissions bool + reviews map[uuid.UUID]struct{} + removedreviews map[uuid.UUID]struct{} + clearedreviews bool done bool oldValue func(context.Context) (*User, error) predicates []predicate.User @@ -7255,6 +9205,42 @@ func (m *UserMutation) ResetIsAdmin() { m.is_admin = nil } +// SetStatus sets the "status" field. +func (m *UserMutation) SetStatus(sst schema.UserStatusType) { + m.status = &sst +} + +// Status returns the value of the "status" field in the mutation. +func (m *UserMutation) Status() (r schema.UserStatusType, exists bool) { + v := m.status + if v == nil { + return + } + return *v, true +} + +// OldStatus returns the old "status" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldStatus(ctx context.Context) (v schema.UserStatusType, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldStatus is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldStatus requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldStatus: %w", err) + } + return oldValue.Status, nil +} + +// ResetStatus resets all changes to the "status" field. +func (m *UserMutation) ResetStatus() { + m.status = nil +} + // AddPublisherPermissionIDs adds the "publisher_permissions" edge to the PublisherPermission entity by ids. func (m *UserMutation) AddPublisherPermissionIDs(ids ...int) { if m.publisher_permissions == nil { @@ -7309,6 +9295,60 @@ func (m *UserMutation) ResetPublisherPermissions() { m.removedpublisher_permissions = nil } +// AddReviewIDs adds the "reviews" edge to the NodeReview entity by ids. +func (m *UserMutation) AddReviewIDs(ids ...uuid.UUID) { + if m.reviews == nil { + m.reviews = make(map[uuid.UUID]struct{}) + } + for i := range ids { + m.reviews[ids[i]] = struct{}{} + } +} + +// ClearReviews clears the "reviews" edge to the NodeReview entity. +func (m *UserMutation) ClearReviews() { + m.clearedreviews = true +} + +// ReviewsCleared reports if the "reviews" edge to the NodeReview entity was cleared. +func (m *UserMutation) ReviewsCleared() bool { + return m.clearedreviews +} + +// RemoveReviewIDs removes the "reviews" edge to the NodeReview entity by IDs. +func (m *UserMutation) RemoveReviewIDs(ids ...uuid.UUID) { + if m.removedreviews == nil { + m.removedreviews = make(map[uuid.UUID]struct{}) + } + for i := range ids { + delete(m.reviews, ids[i]) + m.removedreviews[ids[i]] = struct{}{} + } +} + +// RemovedReviews returns the removed IDs of the "reviews" edge to the NodeReview entity. +func (m *UserMutation) RemovedReviewsIDs() (ids []uuid.UUID) { + for id := range m.removedreviews { + ids = append(ids, id) + } + return +} + +// ReviewsIDs returns the "reviews" edge IDs in the mutation. +func (m *UserMutation) ReviewsIDs() (ids []uuid.UUID) { + for id := range m.reviews { + ids = append(ids, id) + } + return +} + +// ResetReviews resets all changes to the "reviews" edge. +func (m *UserMutation) ResetReviews() { + m.reviews = nil + m.clearedreviews = false + m.removedreviews = nil +} + // Where appends a list predicates to the UserMutation builder. func (m *UserMutation) Where(ps ...predicate.User) { m.predicates = append(m.predicates, ps...) @@ -7343,7 +9383,7 @@ func (m *UserMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *UserMutation) Fields() []string { - fields := make([]string, 0, 6) + fields := make([]string, 0, 7) if m.create_time != nil { fields = append(fields, user.FieldCreateTime) } @@ -7362,6 +9402,9 @@ func (m *UserMutation) Fields() []string { if m.is_admin != nil { fields = append(fields, user.FieldIsAdmin) } + if m.status != nil { + fields = append(fields, user.FieldStatus) + } return fields } @@ -7382,6 +9425,8 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) { return m.IsApproved() case user.FieldIsAdmin: return m.IsAdmin() + case user.FieldStatus: + return m.Status() } return nil, false } @@ -7403,6 +9448,8 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er return m.OldIsApproved(ctx) case user.FieldIsAdmin: return m.OldIsAdmin(ctx) + case user.FieldStatus: + return m.OldStatus(ctx) } return nil, fmt.Errorf("unknown User field %s", name) } @@ -7454,6 +9501,13 @@ func (m *UserMutation) SetField(name string, value ent.Value) error { } m.SetIsAdmin(v) return nil + case user.FieldStatus: + v, ok := value.(schema.UserStatusType) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetStatus(v) + return nil } return fmt.Errorf("unknown User field %s", name) } @@ -7536,16 +9590,22 @@ func (m *UserMutation) ResetField(name string) error { case user.FieldIsAdmin: m.ResetIsAdmin() return nil + case user.FieldStatus: + m.ResetStatus() + return nil } return fmt.Errorf("unknown User field %s", name) } // AddedEdges returns all edge names that were set/added in this mutation. func (m *UserMutation) AddedEdges() []string { - edges := make([]string, 0, 1) + edges := make([]string, 0, 2) if m.publisher_permissions != nil { edges = append(edges, user.EdgePublisherPermissions) } + if m.reviews != nil { + edges = append(edges, user.EdgeReviews) + } return edges } @@ -7559,16 +9619,25 @@ func (m *UserMutation) AddedIDs(name string) []ent.Value { ids = append(ids, id) } return ids + case user.EdgeReviews: + ids := make([]ent.Value, 0, len(m.reviews)) + for id := range m.reviews { + ids = append(ids, id) + } + return ids } return nil } // RemovedEdges returns all edge names that were removed in this mutation. func (m *UserMutation) RemovedEdges() []string { - edges := make([]string, 0, 1) + edges := make([]string, 0, 2) if m.removedpublisher_permissions != nil { edges = append(edges, user.EdgePublisherPermissions) } + if m.removedreviews != nil { + edges = append(edges, user.EdgeReviews) + } return edges } @@ -7582,16 +9651,25 @@ func (m *UserMutation) RemovedIDs(name string) []ent.Value { ids = append(ids, id) } return ids + case user.EdgeReviews: + ids := make([]ent.Value, 0, len(m.removedreviews)) + for id := range m.removedreviews { + ids = append(ids, id) + } + return ids } return nil } // ClearedEdges returns all edge names that were cleared in this mutation. func (m *UserMutation) ClearedEdges() []string { - edges := make([]string, 0, 1) + edges := make([]string, 0, 2) if m.clearedpublisher_permissions { edges = append(edges, user.EdgePublisherPermissions) } + if m.clearedreviews { + edges = append(edges, user.EdgeReviews) + } return edges } @@ -7601,6 +9679,8 @@ func (m *UserMutation) EdgeCleared(name string) bool { switch name { case user.EdgePublisherPermissions: return m.clearedpublisher_permissions + case user.EdgeReviews: + return m.clearedreviews } return false } @@ -7620,6 +9700,9 @@ func (m *UserMutation) ResetEdge(name string) error { case user.EdgePublisherPermissions: m.ResetPublisherPermissions() return nil + case user.EdgeReviews: + m.ResetReviews() + return nil } return fmt.Errorf("unknown User edge %s", name) } diff --git a/ent/node.go b/ent/node.go index 00f0a94..7fd2634 100644 --- a/ent/node.go +++ b/ent/node.go @@ -7,6 +7,7 @@ import ( "fmt" "registry-backend/ent/node" "registry-backend/ent/publisher" + "registry-backend/ent/schema" "strings" "time" @@ -29,6 +30,8 @@ type Node struct { Name string `json:"name,omitempty"` // Description holds the value of the "description" field. Description string `json:"description,omitempty"` + // Category holds the value of the "category" field. + Category string `json:"category,omitempty"` // Author holds the value of the "author" field. Author string `json:"author,omitempty"` // License holds the value of the "license" field. @@ -39,6 +42,16 @@ type Node struct { IconURL string `json:"icon_url,omitempty"` // Tags holds the value of the "tags" field. Tags []string `json:"tags,omitempty"` + // TotalInstall holds the value of the "total_install" field. + TotalInstall int64 `json:"total_install,omitempty"` + // TotalStar holds the value of the "total_star" field. + TotalStar int64 `json:"total_star,omitempty"` + // TotalReview holds the value of the "total_review" field. + TotalReview int64 `json:"total_review,omitempty"` + // Status holds the value of the "status" field. + Status schema.NodeStatus `json:"status,omitempty"` + // StatusDetail holds the value of the "status_detail" field. + StatusDetail string `json:"status_detail,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the NodeQuery when eager-loading is set. Edges NodeEdges `json:"edges"` @@ -51,9 +64,11 @@ type NodeEdges struct { Publisher *Publisher `json:"publisher,omitempty"` // Versions holds the value of the versions edge. Versions []*NodeVersion `json:"versions,omitempty"` + // Reviews holds the value of the reviews edge. + Reviews []*NodeReview `json:"reviews,omitempty"` // loadedTypes holds the information for reporting if a // type was loaded (or requested) in eager-loading or not. - loadedTypes [2]bool + loadedTypes [3]bool } // PublisherOrErr returns the Publisher value or an error if the edge @@ -76,6 +91,15 @@ func (e NodeEdges) VersionsOrErr() ([]*NodeVersion, error) { return nil, &NotLoadedError{edge: "versions"} } +// ReviewsOrErr returns the Reviews value or an error if the edge +// was not loaded in eager-loading. +func (e NodeEdges) ReviewsOrErr() ([]*NodeReview, error) { + if e.loadedTypes[2] { + return e.Reviews, nil + } + return nil, &NotLoadedError{edge: "reviews"} +} + // scanValues returns the types for scanning values from sql.Rows. func (*Node) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) @@ -83,7 +107,9 @@ func (*Node) scanValues(columns []string) ([]any, error) { switch columns[i] { case node.FieldTags: values[i] = new([]byte) - case node.FieldID, node.FieldPublisherID, node.FieldName, node.FieldDescription, node.FieldAuthor, node.FieldLicense, node.FieldRepositoryURL, node.FieldIconURL: + case node.FieldTotalInstall, node.FieldTotalStar, node.FieldTotalReview: + values[i] = new(sql.NullInt64) + case node.FieldID, node.FieldPublisherID, node.FieldName, node.FieldDescription, node.FieldCategory, node.FieldAuthor, node.FieldLicense, node.FieldRepositoryURL, node.FieldIconURL, node.FieldStatus, node.FieldStatusDetail: values[i] = new(sql.NullString) case node.FieldCreateTime, node.FieldUpdateTime: values[i] = new(sql.NullTime) @@ -138,6 +164,12 @@ func (n *Node) assignValues(columns []string, values []any) error { } else if value.Valid { n.Description = value.String } + case node.FieldCategory: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field category", values[i]) + } else if value.Valid { + n.Category = value.String + } case node.FieldAuthor: if value, ok := values[i].(*sql.NullString); !ok { return fmt.Errorf("unexpected type %T for field author", values[i]) @@ -170,6 +202,36 @@ func (n *Node) assignValues(columns []string, values []any) error { return fmt.Errorf("unmarshal field tags: %w", err) } } + case node.FieldTotalInstall: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field total_install", values[i]) + } else if value.Valid { + n.TotalInstall = value.Int64 + } + case node.FieldTotalStar: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field total_star", values[i]) + } else if value.Valid { + n.TotalStar = value.Int64 + } + case node.FieldTotalReview: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field total_review", values[i]) + } else if value.Valid { + n.TotalReview = value.Int64 + } + case node.FieldStatus: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field status", values[i]) + } else if value.Valid { + n.Status = schema.NodeStatus(value.String) + } + case node.FieldStatusDetail: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field status_detail", values[i]) + } else if value.Valid { + n.StatusDetail = value.String + } default: n.selectValues.Set(columns[i], values[i]) } @@ -193,6 +255,11 @@ func (n *Node) QueryVersions() *NodeVersionQuery { return NewNodeClient(n.config).QueryVersions(n) } +// QueryReviews queries the "reviews" edge of the Node entity. +func (n *Node) QueryReviews() *NodeReviewQuery { + return NewNodeClient(n.config).QueryReviews(n) +} + // Update returns a builder for updating this Node. // Note that you need to call Node.Unwrap() before calling this method if this Node // was returned from a transaction, and the transaction was committed or rolled back. @@ -231,6 +298,9 @@ func (n *Node) String() string { builder.WriteString("description=") builder.WriteString(n.Description) builder.WriteString(", ") + builder.WriteString("category=") + builder.WriteString(n.Category) + builder.WriteString(", ") builder.WriteString("author=") builder.WriteString(n.Author) builder.WriteString(", ") @@ -245,6 +315,21 @@ func (n *Node) String() string { builder.WriteString(", ") builder.WriteString("tags=") builder.WriteString(fmt.Sprintf("%v", n.Tags)) + builder.WriteString(", ") + builder.WriteString("total_install=") + builder.WriteString(fmt.Sprintf("%v", n.TotalInstall)) + builder.WriteString(", ") + builder.WriteString("total_star=") + builder.WriteString(fmt.Sprintf("%v", n.TotalStar)) + builder.WriteString(", ") + builder.WriteString("total_review=") + builder.WriteString(fmt.Sprintf("%v", n.TotalReview)) + builder.WriteString(", ") + builder.WriteString("status=") + builder.WriteString(fmt.Sprintf("%v", n.Status)) + builder.WriteString(", ") + builder.WriteString("status_detail=") + builder.WriteString(n.StatusDetail) builder.WriteByte(')') return builder.String() } diff --git a/ent/node/node.go b/ent/node/node.go index 2fe6325..b835a98 100644 --- a/ent/node/node.go +++ b/ent/node/node.go @@ -3,6 +3,8 @@ package node import ( + "fmt" + "registry-backend/ent/schema" "time" "entgo.io/ent/dialect/sql" @@ -24,6 +26,8 @@ const ( FieldName = "name" // FieldDescription holds the string denoting the description field in the database. FieldDescription = "description" + // FieldCategory holds the string denoting the category field in the database. + FieldCategory = "category" // FieldAuthor holds the string denoting the author field in the database. FieldAuthor = "author" // FieldLicense holds the string denoting the license field in the database. @@ -34,10 +38,22 @@ const ( FieldIconURL = "icon_url" // FieldTags holds the string denoting the tags field in the database. FieldTags = "tags" + // FieldTotalInstall holds the string denoting the total_install field in the database. + FieldTotalInstall = "total_install" + // FieldTotalStar holds the string denoting the total_star field in the database. + FieldTotalStar = "total_star" + // FieldTotalReview holds the string denoting the total_review field in the database. + FieldTotalReview = "total_review" + // FieldStatus holds the string denoting the status field in the database. + FieldStatus = "status" + // FieldStatusDetail holds the string denoting the status_detail field in the database. + FieldStatusDetail = "status_detail" // EdgePublisher holds the string denoting the publisher edge name in mutations. EdgePublisher = "publisher" // EdgeVersions holds the string denoting the versions edge name in mutations. EdgeVersions = "versions" + // EdgeReviews holds the string denoting the reviews edge name in mutations. + EdgeReviews = "reviews" // Table holds the table name of the node in the database. Table = "nodes" // PublisherTable is the table that holds the publisher relation/edge. @@ -54,6 +70,13 @@ const ( VersionsInverseTable = "node_versions" // VersionsColumn is the table column denoting the versions relation/edge. VersionsColumn = "node_id" + // ReviewsTable is the table that holds the reviews relation/edge. + ReviewsTable = "node_reviews" + // ReviewsInverseTable is the table name for the NodeReview entity. + // It exists in this package in order to avoid circular dependency with the "nodereview" package. + ReviewsInverseTable = "node_reviews" + // ReviewsColumn is the table column denoting the reviews relation/edge. + ReviewsColumn = "node_id" ) // Columns holds all SQL columns for node fields. @@ -64,11 +87,17 @@ var Columns = []string{ FieldPublisherID, FieldName, FieldDescription, + FieldCategory, FieldAuthor, FieldLicense, FieldRepositoryURL, FieldIconURL, FieldTags, + FieldTotalInstall, + FieldTotalStar, + FieldTotalReview, + FieldStatus, + FieldStatusDetail, } // ValidColumn reports if the column name is valid (part of the table columns). @@ -90,8 +119,26 @@ var ( UpdateDefaultUpdateTime func() time.Time // DefaultTags holds the default value on creation for the "tags" field. DefaultTags []string + // DefaultTotalInstall holds the default value on creation for the "total_install" field. + DefaultTotalInstall int64 + // DefaultTotalStar holds the default value on creation for the "total_star" field. + DefaultTotalStar int64 + // DefaultTotalReview holds the default value on creation for the "total_review" field. + DefaultTotalReview int64 ) +const DefaultStatus schema.NodeStatus = "active" + +// StatusValidator is a validator for the "status" field enum values. It is called by the builders before save. +func StatusValidator(s schema.NodeStatus) error { + switch s { + case "active", "banned", "deleted": + return nil + default: + return fmt.Errorf("node: invalid enum value for status field: %q", s) + } +} + // OrderOption defines the ordering options for the Node queries. type OrderOption func(*sql.Selector) @@ -125,6 +172,11 @@ func ByDescription(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldDescription, opts...).ToFunc() } +// ByCategory orders the results by the category field. +func ByCategory(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCategory, opts...).ToFunc() +} + // ByAuthor orders the results by the author field. func ByAuthor(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldAuthor, opts...).ToFunc() @@ -145,6 +197,31 @@ func ByIconURL(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldIconURL, opts...).ToFunc() } +// ByTotalInstall orders the results by the total_install field. +func ByTotalInstall(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldTotalInstall, opts...).ToFunc() +} + +// ByTotalStar orders the results by the total_star field. +func ByTotalStar(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldTotalStar, opts...).ToFunc() +} + +// ByTotalReview orders the results by the total_review field. +func ByTotalReview(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldTotalReview, opts...).ToFunc() +} + +// ByStatus orders the results by the status field. +func ByStatus(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldStatus, opts...).ToFunc() +} + +// ByStatusDetail orders the results by the status_detail field. +func ByStatusDetail(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldStatusDetail, opts...).ToFunc() +} + // ByPublisherField orders the results by publisher field. func ByPublisherField(field string, opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { @@ -165,6 +242,20 @@ func ByVersions(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { sqlgraph.OrderByNeighborTerms(s, newVersionsStep(), append([]sql.OrderTerm{term}, terms...)...) } } + +// ByReviewsCount orders the results by reviews count. +func ByReviewsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newReviewsStep(), opts...) + } +} + +// ByReviews orders the results by reviews terms. +func ByReviews(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newReviewsStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} func newPublisherStep() *sqlgraph.Step { return sqlgraph.NewStep( sqlgraph.From(Table, FieldID), @@ -179,3 +270,10 @@ func newVersionsStep() *sqlgraph.Step { sqlgraph.Edge(sqlgraph.O2M, false, VersionsTable, VersionsColumn), ) } +func newReviewsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(ReviewsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, ReviewsTable, ReviewsColumn), + ) +} diff --git a/ent/node/where.go b/ent/node/where.go index 6260d08..3531dc0 100644 --- a/ent/node/where.go +++ b/ent/node/where.go @@ -4,6 +4,7 @@ package node import ( "registry-backend/ent/predicate" + "registry-backend/ent/schema" "time" "entgo.io/ent/dialect/sql" @@ -90,6 +91,11 @@ func Description(v string) predicate.Node { return predicate.Node(sql.FieldEQ(FieldDescription, v)) } +// Category applies equality check predicate on the "category" field. It's identical to CategoryEQ. +func Category(v string) predicate.Node { + return predicate.Node(sql.FieldEQ(FieldCategory, v)) +} + // Author applies equality check predicate on the "author" field. It's identical to AuthorEQ. func Author(v string) predicate.Node { return predicate.Node(sql.FieldEQ(FieldAuthor, v)) @@ -110,6 +116,26 @@ func IconURL(v string) predicate.Node { return predicate.Node(sql.FieldEQ(FieldIconURL, v)) } +// TotalInstall applies equality check predicate on the "total_install" field. It's identical to TotalInstallEQ. +func TotalInstall(v int64) predicate.Node { + return predicate.Node(sql.FieldEQ(FieldTotalInstall, v)) +} + +// TotalStar applies equality check predicate on the "total_star" field. It's identical to TotalStarEQ. +func TotalStar(v int64) predicate.Node { + return predicate.Node(sql.FieldEQ(FieldTotalStar, v)) +} + +// TotalReview applies equality check predicate on the "total_review" field. It's identical to TotalReviewEQ. +func TotalReview(v int64) predicate.Node { + return predicate.Node(sql.FieldEQ(FieldTotalReview, v)) +} + +// StatusDetail applies equality check predicate on the "status_detail" field. It's identical to StatusDetailEQ. +func StatusDetail(v string) predicate.Node { + return predicate.Node(sql.FieldEQ(FieldStatusDetail, v)) +} + // CreateTimeEQ applies the EQ predicate on the "create_time" field. func CreateTimeEQ(v time.Time) predicate.Node { return predicate.Node(sql.FieldEQ(FieldCreateTime, v)) @@ -395,6 +421,81 @@ func DescriptionContainsFold(v string) predicate.Node { return predicate.Node(sql.FieldContainsFold(FieldDescription, v)) } +// CategoryEQ applies the EQ predicate on the "category" field. +func CategoryEQ(v string) predicate.Node { + return predicate.Node(sql.FieldEQ(FieldCategory, v)) +} + +// CategoryNEQ applies the NEQ predicate on the "category" field. +func CategoryNEQ(v string) predicate.Node { + return predicate.Node(sql.FieldNEQ(FieldCategory, v)) +} + +// CategoryIn applies the In predicate on the "category" field. +func CategoryIn(vs ...string) predicate.Node { + return predicate.Node(sql.FieldIn(FieldCategory, vs...)) +} + +// CategoryNotIn applies the NotIn predicate on the "category" field. +func CategoryNotIn(vs ...string) predicate.Node { + return predicate.Node(sql.FieldNotIn(FieldCategory, vs...)) +} + +// CategoryGT applies the GT predicate on the "category" field. +func CategoryGT(v string) predicate.Node { + return predicate.Node(sql.FieldGT(FieldCategory, v)) +} + +// CategoryGTE applies the GTE predicate on the "category" field. +func CategoryGTE(v string) predicate.Node { + return predicate.Node(sql.FieldGTE(FieldCategory, v)) +} + +// CategoryLT applies the LT predicate on the "category" field. +func CategoryLT(v string) predicate.Node { + return predicate.Node(sql.FieldLT(FieldCategory, v)) +} + +// CategoryLTE applies the LTE predicate on the "category" field. +func CategoryLTE(v string) predicate.Node { + return predicate.Node(sql.FieldLTE(FieldCategory, v)) +} + +// CategoryContains applies the Contains predicate on the "category" field. +func CategoryContains(v string) predicate.Node { + return predicate.Node(sql.FieldContains(FieldCategory, v)) +} + +// CategoryHasPrefix applies the HasPrefix predicate on the "category" field. +func CategoryHasPrefix(v string) predicate.Node { + return predicate.Node(sql.FieldHasPrefix(FieldCategory, v)) +} + +// CategoryHasSuffix applies the HasSuffix predicate on the "category" field. +func CategoryHasSuffix(v string) predicate.Node { + return predicate.Node(sql.FieldHasSuffix(FieldCategory, v)) +} + +// CategoryIsNil applies the IsNil predicate on the "category" field. +func CategoryIsNil() predicate.Node { + return predicate.Node(sql.FieldIsNull(FieldCategory)) +} + +// CategoryNotNil applies the NotNil predicate on the "category" field. +func CategoryNotNil() predicate.Node { + return predicate.Node(sql.FieldNotNull(FieldCategory)) +} + +// CategoryEqualFold applies the EqualFold predicate on the "category" field. +func CategoryEqualFold(v string) predicate.Node { + return predicate.Node(sql.FieldEqualFold(FieldCategory, v)) +} + +// CategoryContainsFold applies the ContainsFold predicate on the "category" field. +func CategoryContainsFold(v string) predicate.Node { + return predicate.Node(sql.FieldContainsFold(FieldCategory, v)) +} + // AuthorEQ applies the EQ predicate on the "author" field. func AuthorEQ(v string) predicate.Node { return predicate.Node(sql.FieldEQ(FieldAuthor, v)) @@ -675,6 +776,231 @@ func IconURLContainsFold(v string) predicate.Node { return predicate.Node(sql.FieldContainsFold(FieldIconURL, v)) } +// TotalInstallEQ applies the EQ predicate on the "total_install" field. +func TotalInstallEQ(v int64) predicate.Node { + return predicate.Node(sql.FieldEQ(FieldTotalInstall, v)) +} + +// TotalInstallNEQ applies the NEQ predicate on the "total_install" field. +func TotalInstallNEQ(v int64) predicate.Node { + return predicate.Node(sql.FieldNEQ(FieldTotalInstall, v)) +} + +// TotalInstallIn applies the In predicate on the "total_install" field. +func TotalInstallIn(vs ...int64) predicate.Node { + return predicate.Node(sql.FieldIn(FieldTotalInstall, vs...)) +} + +// TotalInstallNotIn applies the NotIn predicate on the "total_install" field. +func TotalInstallNotIn(vs ...int64) predicate.Node { + return predicate.Node(sql.FieldNotIn(FieldTotalInstall, vs...)) +} + +// TotalInstallGT applies the GT predicate on the "total_install" field. +func TotalInstallGT(v int64) predicate.Node { + return predicate.Node(sql.FieldGT(FieldTotalInstall, v)) +} + +// TotalInstallGTE applies the GTE predicate on the "total_install" field. +func TotalInstallGTE(v int64) predicate.Node { + return predicate.Node(sql.FieldGTE(FieldTotalInstall, v)) +} + +// TotalInstallLT applies the LT predicate on the "total_install" field. +func TotalInstallLT(v int64) predicate.Node { + return predicate.Node(sql.FieldLT(FieldTotalInstall, v)) +} + +// TotalInstallLTE applies the LTE predicate on the "total_install" field. +func TotalInstallLTE(v int64) predicate.Node { + return predicate.Node(sql.FieldLTE(FieldTotalInstall, v)) +} + +// TotalStarEQ applies the EQ predicate on the "total_star" field. +func TotalStarEQ(v int64) predicate.Node { + return predicate.Node(sql.FieldEQ(FieldTotalStar, v)) +} + +// TotalStarNEQ applies the NEQ predicate on the "total_star" field. +func TotalStarNEQ(v int64) predicate.Node { + return predicate.Node(sql.FieldNEQ(FieldTotalStar, v)) +} + +// TotalStarIn applies the In predicate on the "total_star" field. +func TotalStarIn(vs ...int64) predicate.Node { + return predicate.Node(sql.FieldIn(FieldTotalStar, vs...)) +} + +// TotalStarNotIn applies the NotIn predicate on the "total_star" field. +func TotalStarNotIn(vs ...int64) predicate.Node { + return predicate.Node(sql.FieldNotIn(FieldTotalStar, vs...)) +} + +// TotalStarGT applies the GT predicate on the "total_star" field. +func TotalStarGT(v int64) predicate.Node { + return predicate.Node(sql.FieldGT(FieldTotalStar, v)) +} + +// TotalStarGTE applies the GTE predicate on the "total_star" field. +func TotalStarGTE(v int64) predicate.Node { + return predicate.Node(sql.FieldGTE(FieldTotalStar, v)) +} + +// TotalStarLT applies the LT predicate on the "total_star" field. +func TotalStarLT(v int64) predicate.Node { + return predicate.Node(sql.FieldLT(FieldTotalStar, v)) +} + +// TotalStarLTE applies the LTE predicate on the "total_star" field. +func TotalStarLTE(v int64) predicate.Node { + return predicate.Node(sql.FieldLTE(FieldTotalStar, v)) +} + +// TotalReviewEQ applies the EQ predicate on the "total_review" field. +func TotalReviewEQ(v int64) predicate.Node { + return predicate.Node(sql.FieldEQ(FieldTotalReview, v)) +} + +// TotalReviewNEQ applies the NEQ predicate on the "total_review" field. +func TotalReviewNEQ(v int64) predicate.Node { + return predicate.Node(sql.FieldNEQ(FieldTotalReview, v)) +} + +// TotalReviewIn applies the In predicate on the "total_review" field. +func TotalReviewIn(vs ...int64) predicate.Node { + return predicate.Node(sql.FieldIn(FieldTotalReview, vs...)) +} + +// TotalReviewNotIn applies the NotIn predicate on the "total_review" field. +func TotalReviewNotIn(vs ...int64) predicate.Node { + return predicate.Node(sql.FieldNotIn(FieldTotalReview, vs...)) +} + +// TotalReviewGT applies the GT predicate on the "total_review" field. +func TotalReviewGT(v int64) predicate.Node { + return predicate.Node(sql.FieldGT(FieldTotalReview, v)) +} + +// TotalReviewGTE applies the GTE predicate on the "total_review" field. +func TotalReviewGTE(v int64) predicate.Node { + return predicate.Node(sql.FieldGTE(FieldTotalReview, v)) +} + +// TotalReviewLT applies the LT predicate on the "total_review" field. +func TotalReviewLT(v int64) predicate.Node { + return predicate.Node(sql.FieldLT(FieldTotalReview, v)) +} + +// TotalReviewLTE applies the LTE predicate on the "total_review" field. +func TotalReviewLTE(v int64) predicate.Node { + return predicate.Node(sql.FieldLTE(FieldTotalReview, v)) +} + +// StatusEQ applies the EQ predicate on the "status" field. +func StatusEQ(v schema.NodeStatus) predicate.Node { + vc := v + return predicate.Node(sql.FieldEQ(FieldStatus, vc)) +} + +// StatusNEQ applies the NEQ predicate on the "status" field. +func StatusNEQ(v schema.NodeStatus) predicate.Node { + vc := v + return predicate.Node(sql.FieldNEQ(FieldStatus, vc)) +} + +// StatusIn applies the In predicate on the "status" field. +func StatusIn(vs ...schema.NodeStatus) predicate.Node { + v := make([]any, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.Node(sql.FieldIn(FieldStatus, v...)) +} + +// StatusNotIn applies the NotIn predicate on the "status" field. +func StatusNotIn(vs ...schema.NodeStatus) predicate.Node { + v := make([]any, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.Node(sql.FieldNotIn(FieldStatus, v...)) +} + +// StatusDetailEQ applies the EQ predicate on the "status_detail" field. +func StatusDetailEQ(v string) predicate.Node { + return predicate.Node(sql.FieldEQ(FieldStatusDetail, v)) +} + +// StatusDetailNEQ applies the NEQ predicate on the "status_detail" field. +func StatusDetailNEQ(v string) predicate.Node { + return predicate.Node(sql.FieldNEQ(FieldStatusDetail, v)) +} + +// StatusDetailIn applies the In predicate on the "status_detail" field. +func StatusDetailIn(vs ...string) predicate.Node { + return predicate.Node(sql.FieldIn(FieldStatusDetail, vs...)) +} + +// StatusDetailNotIn applies the NotIn predicate on the "status_detail" field. +func StatusDetailNotIn(vs ...string) predicate.Node { + return predicate.Node(sql.FieldNotIn(FieldStatusDetail, vs...)) +} + +// StatusDetailGT applies the GT predicate on the "status_detail" field. +func StatusDetailGT(v string) predicate.Node { + return predicate.Node(sql.FieldGT(FieldStatusDetail, v)) +} + +// StatusDetailGTE applies the GTE predicate on the "status_detail" field. +func StatusDetailGTE(v string) predicate.Node { + return predicate.Node(sql.FieldGTE(FieldStatusDetail, v)) +} + +// StatusDetailLT applies the LT predicate on the "status_detail" field. +func StatusDetailLT(v string) predicate.Node { + return predicate.Node(sql.FieldLT(FieldStatusDetail, v)) +} + +// StatusDetailLTE applies the LTE predicate on the "status_detail" field. +func StatusDetailLTE(v string) predicate.Node { + return predicate.Node(sql.FieldLTE(FieldStatusDetail, v)) +} + +// StatusDetailContains applies the Contains predicate on the "status_detail" field. +func StatusDetailContains(v string) predicate.Node { + return predicate.Node(sql.FieldContains(FieldStatusDetail, v)) +} + +// StatusDetailHasPrefix applies the HasPrefix predicate on the "status_detail" field. +func StatusDetailHasPrefix(v string) predicate.Node { + return predicate.Node(sql.FieldHasPrefix(FieldStatusDetail, v)) +} + +// StatusDetailHasSuffix applies the HasSuffix predicate on the "status_detail" field. +func StatusDetailHasSuffix(v string) predicate.Node { + return predicate.Node(sql.FieldHasSuffix(FieldStatusDetail, v)) +} + +// StatusDetailIsNil applies the IsNil predicate on the "status_detail" field. +func StatusDetailIsNil() predicate.Node { + return predicate.Node(sql.FieldIsNull(FieldStatusDetail)) +} + +// StatusDetailNotNil applies the NotNil predicate on the "status_detail" field. +func StatusDetailNotNil() predicate.Node { + return predicate.Node(sql.FieldNotNull(FieldStatusDetail)) +} + +// StatusDetailEqualFold applies the EqualFold predicate on the "status_detail" field. +func StatusDetailEqualFold(v string) predicate.Node { + return predicate.Node(sql.FieldEqualFold(FieldStatusDetail, v)) +} + +// StatusDetailContainsFold applies the ContainsFold predicate on the "status_detail" field. +func StatusDetailContainsFold(v string) predicate.Node { + return predicate.Node(sql.FieldContainsFold(FieldStatusDetail, v)) +} + // HasPublisher applies the HasEdge predicate on the "publisher" edge. func HasPublisher() predicate.Node { return predicate.Node(func(s *sql.Selector) { @@ -721,6 +1047,29 @@ func HasVersionsWith(preds ...predicate.NodeVersion) predicate.Node { }) } +// HasReviews applies the HasEdge predicate on the "reviews" edge. +func HasReviews() predicate.Node { + return predicate.Node(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, ReviewsTable, ReviewsColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasReviewsWith applies the HasEdge predicate on the "reviews" edge with a given conditions (other predicates). +func HasReviewsWith(preds ...predicate.NodeReview) predicate.Node { + return predicate.Node(func(s *sql.Selector) { + step := newReviewsStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + // And groups predicates with the AND operator between them. func And(predicates ...predicate.Node) predicate.Node { return predicate.Node(sql.AndPredicates(predicates...)) diff --git a/ent/node_create.go b/ent/node_create.go index e413ea8..06f8b57 100644 --- a/ent/node_create.go +++ b/ent/node_create.go @@ -7,8 +7,10 @@ import ( "errors" "fmt" "registry-backend/ent/node" + "registry-backend/ent/nodereview" "registry-backend/ent/nodeversion" "registry-backend/ent/publisher" + "registry-backend/ent/schema" "time" "entgo.io/ent/dialect" @@ -80,6 +82,20 @@ func (nc *NodeCreate) SetNillableDescription(s *string) *NodeCreate { return nc } +// SetCategory sets the "category" field. +func (nc *NodeCreate) SetCategory(s string) *NodeCreate { + nc.mutation.SetCategory(s) + return nc +} + +// SetNillableCategory sets the "category" field if the given value is not nil. +func (nc *NodeCreate) SetNillableCategory(s *string) *NodeCreate { + if s != nil { + nc.SetCategory(*s) + } + return nc +} + // SetAuthor sets the "author" field. func (nc *NodeCreate) SetAuthor(s string) *NodeCreate { nc.mutation.SetAuthor(s) @@ -126,6 +142,76 @@ func (nc *NodeCreate) SetTags(s []string) *NodeCreate { return nc } +// SetTotalInstall sets the "total_install" field. +func (nc *NodeCreate) SetTotalInstall(i int64) *NodeCreate { + nc.mutation.SetTotalInstall(i) + return nc +} + +// SetNillableTotalInstall sets the "total_install" field if the given value is not nil. +func (nc *NodeCreate) SetNillableTotalInstall(i *int64) *NodeCreate { + if i != nil { + nc.SetTotalInstall(*i) + } + return nc +} + +// SetTotalStar sets the "total_star" field. +func (nc *NodeCreate) SetTotalStar(i int64) *NodeCreate { + nc.mutation.SetTotalStar(i) + return nc +} + +// SetNillableTotalStar sets the "total_star" field if the given value is not nil. +func (nc *NodeCreate) SetNillableTotalStar(i *int64) *NodeCreate { + if i != nil { + nc.SetTotalStar(*i) + } + return nc +} + +// SetTotalReview sets the "total_review" field. +func (nc *NodeCreate) SetTotalReview(i int64) *NodeCreate { + nc.mutation.SetTotalReview(i) + return nc +} + +// SetNillableTotalReview sets the "total_review" field if the given value is not nil. +func (nc *NodeCreate) SetNillableTotalReview(i *int64) *NodeCreate { + if i != nil { + nc.SetTotalReview(*i) + } + return nc +} + +// SetStatus sets the "status" field. +func (nc *NodeCreate) SetStatus(ss schema.NodeStatus) *NodeCreate { + nc.mutation.SetStatus(ss) + return nc +} + +// SetNillableStatus sets the "status" field if the given value is not nil. +func (nc *NodeCreate) SetNillableStatus(ss *schema.NodeStatus) *NodeCreate { + if ss != nil { + nc.SetStatus(*ss) + } + return nc +} + +// SetStatusDetail sets the "status_detail" field. +func (nc *NodeCreate) SetStatusDetail(s string) *NodeCreate { + nc.mutation.SetStatusDetail(s) + return nc +} + +// SetNillableStatusDetail sets the "status_detail" field if the given value is not nil. +func (nc *NodeCreate) SetNillableStatusDetail(s *string) *NodeCreate { + if s != nil { + nc.SetStatusDetail(*s) + } + return nc +} + // SetID sets the "id" field. func (nc *NodeCreate) SetID(s string) *NodeCreate { nc.mutation.SetID(s) @@ -152,6 +238,21 @@ func (nc *NodeCreate) AddVersions(n ...*NodeVersion) *NodeCreate { return nc.AddVersionIDs(ids...) } +// AddReviewIDs adds the "reviews" edge to the NodeReview entity by IDs. +func (nc *NodeCreate) AddReviewIDs(ids ...uuid.UUID) *NodeCreate { + nc.mutation.AddReviewIDs(ids...) + return nc +} + +// AddReviews adds the "reviews" edges to the NodeReview entity. +func (nc *NodeCreate) AddReviews(n ...*NodeReview) *NodeCreate { + ids := make([]uuid.UUID, len(n)) + for i := range n { + ids[i] = n[i].ID + } + return nc.AddReviewIDs(ids...) +} + // Mutation returns the NodeMutation object of the builder. func (nc *NodeCreate) Mutation() *NodeMutation { return nc.mutation @@ -199,6 +300,22 @@ func (nc *NodeCreate) defaults() { v := node.DefaultTags nc.mutation.SetTags(v) } + if _, ok := nc.mutation.TotalInstall(); !ok { + v := node.DefaultTotalInstall + nc.mutation.SetTotalInstall(v) + } + if _, ok := nc.mutation.TotalStar(); !ok { + v := node.DefaultTotalStar + nc.mutation.SetTotalStar(v) + } + if _, ok := nc.mutation.TotalReview(); !ok { + v := node.DefaultTotalReview + nc.mutation.SetTotalReview(v) + } + if _, ok := nc.mutation.Status(); !ok { + v := node.DefaultStatus + nc.mutation.SetStatus(v) + } } // check runs all checks and user-defined validators on the builder. @@ -224,6 +341,23 @@ func (nc *NodeCreate) check() error { if _, ok := nc.mutation.Tags(); !ok { return &ValidationError{Name: "tags", err: errors.New(`ent: missing required field "Node.tags"`)} } + if _, ok := nc.mutation.TotalInstall(); !ok { + return &ValidationError{Name: "total_install", err: errors.New(`ent: missing required field "Node.total_install"`)} + } + if _, ok := nc.mutation.TotalStar(); !ok { + return &ValidationError{Name: "total_star", err: errors.New(`ent: missing required field "Node.total_star"`)} + } + if _, ok := nc.mutation.TotalReview(); !ok { + return &ValidationError{Name: "total_review", err: errors.New(`ent: missing required field "Node.total_review"`)} + } + if _, ok := nc.mutation.Status(); !ok { + return &ValidationError{Name: "status", err: errors.New(`ent: missing required field "Node.status"`)} + } + if v, ok := nc.mutation.Status(); ok { + if err := node.StatusValidator(v); err != nil { + return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "Node.status": %w`, err)} + } + } if _, ok := nc.mutation.PublisherID(); !ok { return &ValidationError{Name: "publisher", err: errors.New(`ent: missing required edge "Node.publisher"`)} } @@ -279,6 +413,10 @@ func (nc *NodeCreate) createSpec() (*Node, *sqlgraph.CreateSpec) { _spec.SetField(node.FieldDescription, field.TypeString, value) _node.Description = value } + if value, ok := nc.mutation.Category(); ok { + _spec.SetField(node.FieldCategory, field.TypeString, value) + _node.Category = value + } if value, ok := nc.mutation.Author(); ok { _spec.SetField(node.FieldAuthor, field.TypeString, value) _node.Author = value @@ -299,6 +437,26 @@ func (nc *NodeCreate) createSpec() (*Node, *sqlgraph.CreateSpec) { _spec.SetField(node.FieldTags, field.TypeJSON, value) _node.Tags = value } + if value, ok := nc.mutation.TotalInstall(); ok { + _spec.SetField(node.FieldTotalInstall, field.TypeInt64, value) + _node.TotalInstall = value + } + if value, ok := nc.mutation.TotalStar(); ok { + _spec.SetField(node.FieldTotalStar, field.TypeInt64, value) + _node.TotalStar = value + } + if value, ok := nc.mutation.TotalReview(); ok { + _spec.SetField(node.FieldTotalReview, field.TypeInt64, value) + _node.TotalReview = value + } + if value, ok := nc.mutation.Status(); ok { + _spec.SetField(node.FieldStatus, field.TypeEnum, value) + _node.Status = value + } + if value, ok := nc.mutation.StatusDetail(); ok { + _spec.SetField(node.FieldStatusDetail, field.TypeString, value) + _node.StatusDetail = value + } if nodes := nc.mutation.PublisherIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, @@ -332,6 +490,22 @@ func (nc *NodeCreate) createSpec() (*Node, *sqlgraph.CreateSpec) { } _spec.Edges = append(_spec.Edges, edge) } + if nodes := nc.mutation.ReviewsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: node.ReviewsTable, + Columns: []string{node.ReviewsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(nodereview.FieldID, field.TypeUUID), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } return _node, _spec } @@ -438,6 +612,24 @@ func (u *NodeUpsert) ClearDescription() *NodeUpsert { return u } +// SetCategory sets the "category" field. +func (u *NodeUpsert) SetCategory(v string) *NodeUpsert { + u.Set(node.FieldCategory, v) + return u +} + +// UpdateCategory sets the "category" field to the value that was provided on create. +func (u *NodeUpsert) UpdateCategory() *NodeUpsert { + u.SetExcluded(node.FieldCategory) + return u +} + +// ClearCategory clears the value of the "category" field. +func (u *NodeUpsert) ClearCategory() *NodeUpsert { + u.SetNull(node.FieldCategory) + return u +} + // SetAuthor sets the "author" field. func (u *NodeUpsert) SetAuthor(v string) *NodeUpsert { u.Set(node.FieldAuthor, v) @@ -510,6 +702,90 @@ func (u *NodeUpsert) UpdateTags() *NodeUpsert { return u } +// SetTotalInstall sets the "total_install" field. +func (u *NodeUpsert) SetTotalInstall(v int64) *NodeUpsert { + u.Set(node.FieldTotalInstall, v) + return u +} + +// UpdateTotalInstall sets the "total_install" field to the value that was provided on create. +func (u *NodeUpsert) UpdateTotalInstall() *NodeUpsert { + u.SetExcluded(node.FieldTotalInstall) + return u +} + +// AddTotalInstall adds v to the "total_install" field. +func (u *NodeUpsert) AddTotalInstall(v int64) *NodeUpsert { + u.Add(node.FieldTotalInstall, v) + return u +} + +// SetTotalStar sets the "total_star" field. +func (u *NodeUpsert) SetTotalStar(v int64) *NodeUpsert { + u.Set(node.FieldTotalStar, v) + return u +} + +// UpdateTotalStar sets the "total_star" field to the value that was provided on create. +func (u *NodeUpsert) UpdateTotalStar() *NodeUpsert { + u.SetExcluded(node.FieldTotalStar) + return u +} + +// AddTotalStar adds v to the "total_star" field. +func (u *NodeUpsert) AddTotalStar(v int64) *NodeUpsert { + u.Add(node.FieldTotalStar, v) + return u +} + +// SetTotalReview sets the "total_review" field. +func (u *NodeUpsert) SetTotalReview(v int64) *NodeUpsert { + u.Set(node.FieldTotalReview, v) + return u +} + +// UpdateTotalReview sets the "total_review" field to the value that was provided on create. +func (u *NodeUpsert) UpdateTotalReview() *NodeUpsert { + u.SetExcluded(node.FieldTotalReview) + return u +} + +// AddTotalReview adds v to the "total_review" field. +func (u *NodeUpsert) AddTotalReview(v int64) *NodeUpsert { + u.Add(node.FieldTotalReview, v) + return u +} + +// SetStatus sets the "status" field. +func (u *NodeUpsert) SetStatus(v schema.NodeStatus) *NodeUpsert { + u.Set(node.FieldStatus, v) + return u +} + +// UpdateStatus sets the "status" field to the value that was provided on create. +func (u *NodeUpsert) UpdateStatus() *NodeUpsert { + u.SetExcluded(node.FieldStatus) + return u +} + +// SetStatusDetail sets the "status_detail" field. +func (u *NodeUpsert) SetStatusDetail(v string) *NodeUpsert { + u.Set(node.FieldStatusDetail, v) + return u +} + +// UpdateStatusDetail sets the "status_detail" field to the value that was provided on create. +func (u *NodeUpsert) UpdateStatusDetail() *NodeUpsert { + u.SetExcluded(node.FieldStatusDetail) + return u +} + +// ClearStatusDetail clears the value of the "status_detail" field. +func (u *NodeUpsert) ClearStatusDetail() *NodeUpsert { + u.SetNull(node.FieldStatusDetail) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create except the ID field. // Using this option is equivalent to using: // @@ -624,6 +900,27 @@ func (u *NodeUpsertOne) ClearDescription() *NodeUpsertOne { }) } +// SetCategory sets the "category" field. +func (u *NodeUpsertOne) SetCategory(v string) *NodeUpsertOne { + return u.Update(func(s *NodeUpsert) { + s.SetCategory(v) + }) +} + +// UpdateCategory sets the "category" field to the value that was provided on create. +func (u *NodeUpsertOne) UpdateCategory() *NodeUpsertOne { + return u.Update(func(s *NodeUpsert) { + s.UpdateCategory() + }) +} + +// ClearCategory clears the value of the "category" field. +func (u *NodeUpsertOne) ClearCategory() *NodeUpsertOne { + return u.Update(func(s *NodeUpsert) { + s.ClearCategory() + }) +} + // SetAuthor sets the "author" field. func (u *NodeUpsertOne) SetAuthor(v string) *NodeUpsertOne { return u.Update(func(s *NodeUpsert) { @@ -708,6 +1005,104 @@ func (u *NodeUpsertOne) UpdateTags() *NodeUpsertOne { }) } +// SetTotalInstall sets the "total_install" field. +func (u *NodeUpsertOne) SetTotalInstall(v int64) *NodeUpsertOne { + return u.Update(func(s *NodeUpsert) { + s.SetTotalInstall(v) + }) +} + +// AddTotalInstall adds v to the "total_install" field. +func (u *NodeUpsertOne) AddTotalInstall(v int64) *NodeUpsertOne { + return u.Update(func(s *NodeUpsert) { + s.AddTotalInstall(v) + }) +} + +// UpdateTotalInstall sets the "total_install" field to the value that was provided on create. +func (u *NodeUpsertOne) UpdateTotalInstall() *NodeUpsertOne { + return u.Update(func(s *NodeUpsert) { + s.UpdateTotalInstall() + }) +} + +// SetTotalStar sets the "total_star" field. +func (u *NodeUpsertOne) SetTotalStar(v int64) *NodeUpsertOne { + return u.Update(func(s *NodeUpsert) { + s.SetTotalStar(v) + }) +} + +// AddTotalStar adds v to the "total_star" field. +func (u *NodeUpsertOne) AddTotalStar(v int64) *NodeUpsertOne { + return u.Update(func(s *NodeUpsert) { + s.AddTotalStar(v) + }) +} + +// UpdateTotalStar sets the "total_star" field to the value that was provided on create. +func (u *NodeUpsertOne) UpdateTotalStar() *NodeUpsertOne { + return u.Update(func(s *NodeUpsert) { + s.UpdateTotalStar() + }) +} + +// SetTotalReview sets the "total_review" field. +func (u *NodeUpsertOne) SetTotalReview(v int64) *NodeUpsertOne { + return u.Update(func(s *NodeUpsert) { + s.SetTotalReview(v) + }) +} + +// AddTotalReview adds v to the "total_review" field. +func (u *NodeUpsertOne) AddTotalReview(v int64) *NodeUpsertOne { + return u.Update(func(s *NodeUpsert) { + s.AddTotalReview(v) + }) +} + +// UpdateTotalReview sets the "total_review" field to the value that was provided on create. +func (u *NodeUpsertOne) UpdateTotalReview() *NodeUpsertOne { + return u.Update(func(s *NodeUpsert) { + s.UpdateTotalReview() + }) +} + +// SetStatus sets the "status" field. +func (u *NodeUpsertOne) SetStatus(v schema.NodeStatus) *NodeUpsertOne { + return u.Update(func(s *NodeUpsert) { + s.SetStatus(v) + }) +} + +// UpdateStatus sets the "status" field to the value that was provided on create. +func (u *NodeUpsertOne) UpdateStatus() *NodeUpsertOne { + return u.Update(func(s *NodeUpsert) { + s.UpdateStatus() + }) +} + +// SetStatusDetail sets the "status_detail" field. +func (u *NodeUpsertOne) SetStatusDetail(v string) *NodeUpsertOne { + return u.Update(func(s *NodeUpsert) { + s.SetStatusDetail(v) + }) +} + +// UpdateStatusDetail sets the "status_detail" field to the value that was provided on create. +func (u *NodeUpsertOne) UpdateStatusDetail() *NodeUpsertOne { + return u.Update(func(s *NodeUpsert) { + s.UpdateStatusDetail() + }) +} + +// ClearStatusDetail clears the value of the "status_detail" field. +func (u *NodeUpsertOne) ClearStatusDetail() *NodeUpsertOne { + return u.Update(func(s *NodeUpsert) { + s.ClearStatusDetail() + }) +} + // Exec executes the query. func (u *NodeUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -989,6 +1384,27 @@ func (u *NodeUpsertBulk) ClearDescription() *NodeUpsertBulk { }) } +// SetCategory sets the "category" field. +func (u *NodeUpsertBulk) SetCategory(v string) *NodeUpsertBulk { + return u.Update(func(s *NodeUpsert) { + s.SetCategory(v) + }) +} + +// UpdateCategory sets the "category" field to the value that was provided on create. +func (u *NodeUpsertBulk) UpdateCategory() *NodeUpsertBulk { + return u.Update(func(s *NodeUpsert) { + s.UpdateCategory() + }) +} + +// ClearCategory clears the value of the "category" field. +func (u *NodeUpsertBulk) ClearCategory() *NodeUpsertBulk { + return u.Update(func(s *NodeUpsert) { + s.ClearCategory() + }) +} + // SetAuthor sets the "author" field. func (u *NodeUpsertBulk) SetAuthor(v string) *NodeUpsertBulk { return u.Update(func(s *NodeUpsert) { @@ -1073,6 +1489,104 @@ func (u *NodeUpsertBulk) UpdateTags() *NodeUpsertBulk { }) } +// SetTotalInstall sets the "total_install" field. +func (u *NodeUpsertBulk) SetTotalInstall(v int64) *NodeUpsertBulk { + return u.Update(func(s *NodeUpsert) { + s.SetTotalInstall(v) + }) +} + +// AddTotalInstall adds v to the "total_install" field. +func (u *NodeUpsertBulk) AddTotalInstall(v int64) *NodeUpsertBulk { + return u.Update(func(s *NodeUpsert) { + s.AddTotalInstall(v) + }) +} + +// UpdateTotalInstall sets the "total_install" field to the value that was provided on create. +func (u *NodeUpsertBulk) UpdateTotalInstall() *NodeUpsertBulk { + return u.Update(func(s *NodeUpsert) { + s.UpdateTotalInstall() + }) +} + +// SetTotalStar sets the "total_star" field. +func (u *NodeUpsertBulk) SetTotalStar(v int64) *NodeUpsertBulk { + return u.Update(func(s *NodeUpsert) { + s.SetTotalStar(v) + }) +} + +// AddTotalStar adds v to the "total_star" field. +func (u *NodeUpsertBulk) AddTotalStar(v int64) *NodeUpsertBulk { + return u.Update(func(s *NodeUpsert) { + s.AddTotalStar(v) + }) +} + +// UpdateTotalStar sets the "total_star" field to the value that was provided on create. +func (u *NodeUpsertBulk) UpdateTotalStar() *NodeUpsertBulk { + return u.Update(func(s *NodeUpsert) { + s.UpdateTotalStar() + }) +} + +// SetTotalReview sets the "total_review" field. +func (u *NodeUpsertBulk) SetTotalReview(v int64) *NodeUpsertBulk { + return u.Update(func(s *NodeUpsert) { + s.SetTotalReview(v) + }) +} + +// AddTotalReview adds v to the "total_review" field. +func (u *NodeUpsertBulk) AddTotalReview(v int64) *NodeUpsertBulk { + return u.Update(func(s *NodeUpsert) { + s.AddTotalReview(v) + }) +} + +// UpdateTotalReview sets the "total_review" field to the value that was provided on create. +func (u *NodeUpsertBulk) UpdateTotalReview() *NodeUpsertBulk { + return u.Update(func(s *NodeUpsert) { + s.UpdateTotalReview() + }) +} + +// SetStatus sets the "status" field. +func (u *NodeUpsertBulk) SetStatus(v schema.NodeStatus) *NodeUpsertBulk { + return u.Update(func(s *NodeUpsert) { + s.SetStatus(v) + }) +} + +// UpdateStatus sets the "status" field to the value that was provided on create. +func (u *NodeUpsertBulk) UpdateStatus() *NodeUpsertBulk { + return u.Update(func(s *NodeUpsert) { + s.UpdateStatus() + }) +} + +// SetStatusDetail sets the "status_detail" field. +func (u *NodeUpsertBulk) SetStatusDetail(v string) *NodeUpsertBulk { + return u.Update(func(s *NodeUpsert) { + s.SetStatusDetail(v) + }) +} + +// UpdateStatusDetail sets the "status_detail" field to the value that was provided on create. +func (u *NodeUpsertBulk) UpdateStatusDetail() *NodeUpsertBulk { + return u.Update(func(s *NodeUpsert) { + s.UpdateStatusDetail() + }) +} + +// ClearStatusDetail clears the value of the "status_detail" field. +func (u *NodeUpsertBulk) ClearStatusDetail() *NodeUpsertBulk { + return u.Update(func(s *NodeUpsert) { + s.ClearStatusDetail() + }) +} + // Exec executes the query. func (u *NodeUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/ent/node_query.go b/ent/node_query.go index 89eec59..87dd563 100644 --- a/ent/node_query.go +++ b/ent/node_query.go @@ -8,6 +8,7 @@ import ( "fmt" "math" "registry-backend/ent/node" + "registry-backend/ent/nodereview" "registry-backend/ent/nodeversion" "registry-backend/ent/predicate" "registry-backend/ent/publisher" @@ -27,6 +28,7 @@ type NodeQuery struct { predicates []predicate.Node withPublisher *PublisherQuery withVersions *NodeVersionQuery + withReviews *NodeReviewQuery modifiers []func(*sql.Selector) // intermediate query (i.e. traversal path). sql *sql.Selector @@ -108,6 +110,28 @@ func (nq *NodeQuery) QueryVersions() *NodeVersionQuery { return query } +// QueryReviews chains the current query on the "reviews" edge. +func (nq *NodeQuery) QueryReviews() *NodeReviewQuery { + query := (&NodeReviewClient{config: nq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := nq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := nq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(node.Table, node.FieldID, selector), + sqlgraph.To(nodereview.Table, nodereview.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, node.ReviewsTable, node.ReviewsColumn), + ) + fromU = sqlgraph.SetNeighbors(nq.driver.Dialect(), step) + return fromU, nil + } + return query +} + // First returns the first Node entity from the query. // Returns a *NotFoundError when no Node was found. func (nq *NodeQuery) First(ctx context.Context) (*Node, error) { @@ -302,6 +326,7 @@ func (nq *NodeQuery) Clone() *NodeQuery { predicates: append([]predicate.Node{}, nq.predicates...), withPublisher: nq.withPublisher.Clone(), withVersions: nq.withVersions.Clone(), + withReviews: nq.withReviews.Clone(), // clone intermediate query. sql: nq.sql.Clone(), path: nq.path, @@ -330,6 +355,17 @@ func (nq *NodeQuery) WithVersions(opts ...func(*NodeVersionQuery)) *NodeQuery { return nq } +// WithReviews tells the query-builder to eager-load the nodes that are connected to +// the "reviews" edge. The optional arguments are used to configure the query builder of the edge. +func (nq *NodeQuery) WithReviews(opts ...func(*NodeReviewQuery)) *NodeQuery { + query := (&NodeReviewClient{config: nq.config}).Query() + for _, opt := range opts { + opt(query) + } + nq.withReviews = query + return nq +} + // GroupBy is used to group vertices by one or more fields/columns. // It is often used with aggregate functions, like: count, max, mean, min, sum. // @@ -408,9 +444,10 @@ func (nq *NodeQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Node, e var ( nodes = []*Node{} _spec = nq.querySpec() - loadedTypes = [2]bool{ + loadedTypes = [3]bool{ nq.withPublisher != nil, nq.withVersions != nil, + nq.withReviews != nil, } ) _spec.ScanValues = func(columns []string) ([]any, error) { @@ -447,6 +484,13 @@ func (nq *NodeQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Node, e return nil, err } } + if query := nq.withReviews; query != nil { + if err := nq.loadReviews(ctx, query, nodes, + func(n *Node) { n.Edges.Reviews = []*NodeReview{} }, + func(n *Node, e *NodeReview) { n.Edges.Reviews = append(n.Edges.Reviews, e) }); err != nil { + return nil, err + } + } return nodes, nil } @@ -510,6 +554,36 @@ func (nq *NodeQuery) loadVersions(ctx context.Context, query *NodeVersionQuery, } return nil } +func (nq *NodeQuery) loadReviews(ctx context.Context, query *NodeReviewQuery, nodes []*Node, init func(*Node), assign func(*Node, *NodeReview)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[string]*Node) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(nodereview.FieldNodeID) + } + query.Where(predicate.NodeReview(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(node.ReviewsColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.NodeID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "node_id" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} func (nq *NodeQuery) sqlCount(ctx context.Context) (int, error) { _spec := nq.querySpec() @@ -627,6 +701,12 @@ func (nq *NodeQuery) ForShare(opts ...sql.LockOption) *NodeQuery { return nq } +// Modify adds a query modifier for attaching custom logic to queries. +func (nq *NodeQuery) Modify(modifiers ...func(s *sql.Selector)) *NodeSelect { + nq.modifiers = append(nq.modifiers, modifiers...) + return nq.Select() +} + // NodeGroupBy is the group-by builder for Node entities. type NodeGroupBy struct { selector @@ -716,3 +796,9 @@ func (ns *NodeSelect) sqlScan(ctx context.Context, root *NodeQuery, v any) error defer rows.Close() return sql.ScanSlice(rows, v) } + +// Modify adds a query modifier for attaching custom logic to queries. +func (ns *NodeSelect) Modify(modifiers ...func(s *sql.Selector)) *NodeSelect { + ns.modifiers = append(ns.modifiers, modifiers...) + return ns +} diff --git a/ent/node_update.go b/ent/node_update.go index 9fb6016..afcdc11 100644 --- a/ent/node_update.go +++ b/ent/node_update.go @@ -7,9 +7,11 @@ import ( "errors" "fmt" "registry-backend/ent/node" + "registry-backend/ent/nodereview" "registry-backend/ent/nodeversion" "registry-backend/ent/predicate" "registry-backend/ent/publisher" + "registry-backend/ent/schema" "time" "entgo.io/ent/dialect/sql" @@ -22,8 +24,9 @@ import ( // NodeUpdate is the builder for updating Node entities. type NodeUpdate struct { config - hooks []Hook - mutation *NodeMutation + hooks []Hook + mutation *NodeMutation + modifiers []func(*sql.UpdateBuilder) } // Where appends a list predicates to the NodeUpdate builder. @@ -86,6 +89,26 @@ func (nu *NodeUpdate) ClearDescription() *NodeUpdate { return nu } +// SetCategory sets the "category" field. +func (nu *NodeUpdate) SetCategory(s string) *NodeUpdate { + nu.mutation.SetCategory(s) + return nu +} + +// SetNillableCategory sets the "category" field if the given value is not nil. +func (nu *NodeUpdate) SetNillableCategory(s *string) *NodeUpdate { + if s != nil { + nu.SetCategory(*s) + } + return nu +} + +// ClearCategory clears the value of the "category" field. +func (nu *NodeUpdate) ClearCategory() *NodeUpdate { + nu.mutation.ClearCategory() + return nu +} + // SetAuthor sets the "author" field. func (nu *NodeUpdate) SetAuthor(s string) *NodeUpdate { nu.mutation.SetAuthor(s) @@ -166,6 +189,103 @@ func (nu *NodeUpdate) AppendTags(s []string) *NodeUpdate { return nu } +// SetTotalInstall sets the "total_install" field. +func (nu *NodeUpdate) SetTotalInstall(i int64) *NodeUpdate { + nu.mutation.ResetTotalInstall() + nu.mutation.SetTotalInstall(i) + return nu +} + +// SetNillableTotalInstall sets the "total_install" field if the given value is not nil. +func (nu *NodeUpdate) SetNillableTotalInstall(i *int64) *NodeUpdate { + if i != nil { + nu.SetTotalInstall(*i) + } + return nu +} + +// AddTotalInstall adds i to the "total_install" field. +func (nu *NodeUpdate) AddTotalInstall(i int64) *NodeUpdate { + nu.mutation.AddTotalInstall(i) + return nu +} + +// SetTotalStar sets the "total_star" field. +func (nu *NodeUpdate) SetTotalStar(i int64) *NodeUpdate { + nu.mutation.ResetTotalStar() + nu.mutation.SetTotalStar(i) + return nu +} + +// SetNillableTotalStar sets the "total_star" field if the given value is not nil. +func (nu *NodeUpdate) SetNillableTotalStar(i *int64) *NodeUpdate { + if i != nil { + nu.SetTotalStar(*i) + } + return nu +} + +// AddTotalStar adds i to the "total_star" field. +func (nu *NodeUpdate) AddTotalStar(i int64) *NodeUpdate { + nu.mutation.AddTotalStar(i) + return nu +} + +// SetTotalReview sets the "total_review" field. +func (nu *NodeUpdate) SetTotalReview(i int64) *NodeUpdate { + nu.mutation.ResetTotalReview() + nu.mutation.SetTotalReview(i) + return nu +} + +// SetNillableTotalReview sets the "total_review" field if the given value is not nil. +func (nu *NodeUpdate) SetNillableTotalReview(i *int64) *NodeUpdate { + if i != nil { + nu.SetTotalReview(*i) + } + return nu +} + +// AddTotalReview adds i to the "total_review" field. +func (nu *NodeUpdate) AddTotalReview(i int64) *NodeUpdate { + nu.mutation.AddTotalReview(i) + return nu +} + +// SetStatus sets the "status" field. +func (nu *NodeUpdate) SetStatus(ss schema.NodeStatus) *NodeUpdate { + nu.mutation.SetStatus(ss) + return nu +} + +// SetNillableStatus sets the "status" field if the given value is not nil. +func (nu *NodeUpdate) SetNillableStatus(ss *schema.NodeStatus) *NodeUpdate { + if ss != nil { + nu.SetStatus(*ss) + } + return nu +} + +// SetStatusDetail sets the "status_detail" field. +func (nu *NodeUpdate) SetStatusDetail(s string) *NodeUpdate { + nu.mutation.SetStatusDetail(s) + return nu +} + +// SetNillableStatusDetail sets the "status_detail" field if the given value is not nil. +func (nu *NodeUpdate) SetNillableStatusDetail(s *string) *NodeUpdate { + if s != nil { + nu.SetStatusDetail(*s) + } + return nu +} + +// ClearStatusDetail clears the value of the "status_detail" field. +func (nu *NodeUpdate) ClearStatusDetail() *NodeUpdate { + nu.mutation.ClearStatusDetail() + return nu +} + // SetPublisher sets the "publisher" edge to the Publisher entity. func (nu *NodeUpdate) SetPublisher(p *Publisher) *NodeUpdate { return nu.SetPublisherID(p.ID) @@ -186,6 +306,21 @@ func (nu *NodeUpdate) AddVersions(n ...*NodeVersion) *NodeUpdate { return nu.AddVersionIDs(ids...) } +// AddReviewIDs adds the "reviews" edge to the NodeReview entity by IDs. +func (nu *NodeUpdate) AddReviewIDs(ids ...uuid.UUID) *NodeUpdate { + nu.mutation.AddReviewIDs(ids...) + return nu +} + +// AddReviews adds the "reviews" edges to the NodeReview entity. +func (nu *NodeUpdate) AddReviews(n ...*NodeReview) *NodeUpdate { + ids := make([]uuid.UUID, len(n)) + for i := range n { + ids[i] = n[i].ID + } + return nu.AddReviewIDs(ids...) +} + // Mutation returns the NodeMutation object of the builder. func (nu *NodeUpdate) Mutation() *NodeMutation { return nu.mutation @@ -218,6 +353,27 @@ func (nu *NodeUpdate) RemoveVersions(n ...*NodeVersion) *NodeUpdate { return nu.RemoveVersionIDs(ids...) } +// ClearReviews clears all "reviews" edges to the NodeReview entity. +func (nu *NodeUpdate) ClearReviews() *NodeUpdate { + nu.mutation.ClearReviews() + return nu +} + +// RemoveReviewIDs removes the "reviews" edge to NodeReview entities by IDs. +func (nu *NodeUpdate) RemoveReviewIDs(ids ...uuid.UUID) *NodeUpdate { + nu.mutation.RemoveReviewIDs(ids...) + return nu +} + +// RemoveReviews removes "reviews" edges to NodeReview entities. +func (nu *NodeUpdate) RemoveReviews(n ...*NodeReview) *NodeUpdate { + ids := make([]uuid.UUID, len(n)) + for i := range n { + ids[i] = n[i].ID + } + return nu.RemoveReviewIDs(ids...) +} + // Save executes the query and returns the number of nodes affected by the update operation. func (nu *NodeUpdate) Save(ctx context.Context) (int, error) { nu.defaults() @@ -256,12 +412,23 @@ func (nu *NodeUpdate) defaults() { // check runs all checks and user-defined validators on the builder. func (nu *NodeUpdate) check() error { + if v, ok := nu.mutation.Status(); ok { + if err := node.StatusValidator(v); err != nil { + return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "Node.status": %w`, err)} + } + } if _, ok := nu.mutation.PublisherID(); nu.mutation.PublisherCleared() && !ok { return errors.New(`ent: clearing a required unique edge "Node.publisher"`) } return nil } +// Modify adds a statement modifier for attaching custom logic to the UPDATE statement. +func (nu *NodeUpdate) Modify(modifiers ...func(u *sql.UpdateBuilder)) *NodeUpdate { + nu.modifiers = append(nu.modifiers, modifiers...) + return nu +} + func (nu *NodeUpdate) sqlSave(ctx context.Context) (n int, err error) { if err := nu.check(); err != nil { return n, err @@ -286,6 +453,12 @@ func (nu *NodeUpdate) sqlSave(ctx context.Context) (n int, err error) { if nu.mutation.DescriptionCleared() { _spec.ClearField(node.FieldDescription, field.TypeString) } + if value, ok := nu.mutation.Category(); ok { + _spec.SetField(node.FieldCategory, field.TypeString, value) + } + if nu.mutation.CategoryCleared() { + _spec.ClearField(node.FieldCategory, field.TypeString) + } if value, ok := nu.mutation.Author(); ok { _spec.SetField(node.FieldAuthor, field.TypeString, value) } @@ -312,6 +485,33 @@ func (nu *NodeUpdate) sqlSave(ctx context.Context) (n int, err error) { sqljson.Append(u, node.FieldTags, value) }) } + if value, ok := nu.mutation.TotalInstall(); ok { + _spec.SetField(node.FieldTotalInstall, field.TypeInt64, value) + } + if value, ok := nu.mutation.AddedTotalInstall(); ok { + _spec.AddField(node.FieldTotalInstall, field.TypeInt64, value) + } + if value, ok := nu.mutation.TotalStar(); ok { + _spec.SetField(node.FieldTotalStar, field.TypeInt64, value) + } + if value, ok := nu.mutation.AddedTotalStar(); ok { + _spec.AddField(node.FieldTotalStar, field.TypeInt64, value) + } + if value, ok := nu.mutation.TotalReview(); ok { + _spec.SetField(node.FieldTotalReview, field.TypeInt64, value) + } + if value, ok := nu.mutation.AddedTotalReview(); ok { + _spec.AddField(node.FieldTotalReview, field.TypeInt64, value) + } + if value, ok := nu.mutation.Status(); ok { + _spec.SetField(node.FieldStatus, field.TypeEnum, value) + } + if value, ok := nu.mutation.StatusDetail(); ok { + _spec.SetField(node.FieldStatusDetail, field.TypeString, value) + } + if nu.mutation.StatusDetailCleared() { + _spec.ClearField(node.FieldStatusDetail, field.TypeString) + } if nu.mutation.PublisherCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, @@ -386,6 +586,52 @@ func (nu *NodeUpdate) sqlSave(ctx context.Context) (n int, err error) { } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if nu.mutation.ReviewsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: node.ReviewsTable, + Columns: []string{node.ReviewsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(nodereview.FieldID, field.TypeUUID), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := nu.mutation.RemovedReviewsIDs(); len(nodes) > 0 && !nu.mutation.ReviewsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: node.ReviewsTable, + Columns: []string{node.ReviewsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(nodereview.FieldID, field.TypeUUID), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := nu.mutation.ReviewsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: node.ReviewsTable, + Columns: []string{node.ReviewsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(nodereview.FieldID, field.TypeUUID), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _spec.AddModifiers(nu.modifiers...) if n, err = sqlgraph.UpdateNodes(ctx, nu.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{node.Label} @@ -401,9 +647,10 @@ func (nu *NodeUpdate) sqlSave(ctx context.Context) (n int, err error) { // NodeUpdateOne is the builder for updating a single Node entity. type NodeUpdateOne struct { config - fields []string - hooks []Hook - mutation *NodeMutation + fields []string + hooks []Hook + mutation *NodeMutation + modifiers []func(*sql.UpdateBuilder) } // SetUpdateTime sets the "update_time" field. @@ -460,6 +707,26 @@ func (nuo *NodeUpdateOne) ClearDescription() *NodeUpdateOne { return nuo } +// SetCategory sets the "category" field. +func (nuo *NodeUpdateOne) SetCategory(s string) *NodeUpdateOne { + nuo.mutation.SetCategory(s) + return nuo +} + +// SetNillableCategory sets the "category" field if the given value is not nil. +func (nuo *NodeUpdateOne) SetNillableCategory(s *string) *NodeUpdateOne { + if s != nil { + nuo.SetCategory(*s) + } + return nuo +} + +// ClearCategory clears the value of the "category" field. +func (nuo *NodeUpdateOne) ClearCategory() *NodeUpdateOne { + nuo.mutation.ClearCategory() + return nuo +} + // SetAuthor sets the "author" field. func (nuo *NodeUpdateOne) SetAuthor(s string) *NodeUpdateOne { nuo.mutation.SetAuthor(s) @@ -540,6 +807,103 @@ func (nuo *NodeUpdateOne) AppendTags(s []string) *NodeUpdateOne { return nuo } +// SetTotalInstall sets the "total_install" field. +func (nuo *NodeUpdateOne) SetTotalInstall(i int64) *NodeUpdateOne { + nuo.mutation.ResetTotalInstall() + nuo.mutation.SetTotalInstall(i) + return nuo +} + +// SetNillableTotalInstall sets the "total_install" field if the given value is not nil. +func (nuo *NodeUpdateOne) SetNillableTotalInstall(i *int64) *NodeUpdateOne { + if i != nil { + nuo.SetTotalInstall(*i) + } + return nuo +} + +// AddTotalInstall adds i to the "total_install" field. +func (nuo *NodeUpdateOne) AddTotalInstall(i int64) *NodeUpdateOne { + nuo.mutation.AddTotalInstall(i) + return nuo +} + +// SetTotalStar sets the "total_star" field. +func (nuo *NodeUpdateOne) SetTotalStar(i int64) *NodeUpdateOne { + nuo.mutation.ResetTotalStar() + nuo.mutation.SetTotalStar(i) + return nuo +} + +// SetNillableTotalStar sets the "total_star" field if the given value is not nil. +func (nuo *NodeUpdateOne) SetNillableTotalStar(i *int64) *NodeUpdateOne { + if i != nil { + nuo.SetTotalStar(*i) + } + return nuo +} + +// AddTotalStar adds i to the "total_star" field. +func (nuo *NodeUpdateOne) AddTotalStar(i int64) *NodeUpdateOne { + nuo.mutation.AddTotalStar(i) + return nuo +} + +// SetTotalReview sets the "total_review" field. +func (nuo *NodeUpdateOne) SetTotalReview(i int64) *NodeUpdateOne { + nuo.mutation.ResetTotalReview() + nuo.mutation.SetTotalReview(i) + return nuo +} + +// SetNillableTotalReview sets the "total_review" field if the given value is not nil. +func (nuo *NodeUpdateOne) SetNillableTotalReview(i *int64) *NodeUpdateOne { + if i != nil { + nuo.SetTotalReview(*i) + } + return nuo +} + +// AddTotalReview adds i to the "total_review" field. +func (nuo *NodeUpdateOne) AddTotalReview(i int64) *NodeUpdateOne { + nuo.mutation.AddTotalReview(i) + return nuo +} + +// SetStatus sets the "status" field. +func (nuo *NodeUpdateOne) SetStatus(ss schema.NodeStatus) *NodeUpdateOne { + nuo.mutation.SetStatus(ss) + return nuo +} + +// SetNillableStatus sets the "status" field if the given value is not nil. +func (nuo *NodeUpdateOne) SetNillableStatus(ss *schema.NodeStatus) *NodeUpdateOne { + if ss != nil { + nuo.SetStatus(*ss) + } + return nuo +} + +// SetStatusDetail sets the "status_detail" field. +func (nuo *NodeUpdateOne) SetStatusDetail(s string) *NodeUpdateOne { + nuo.mutation.SetStatusDetail(s) + return nuo +} + +// SetNillableStatusDetail sets the "status_detail" field if the given value is not nil. +func (nuo *NodeUpdateOne) SetNillableStatusDetail(s *string) *NodeUpdateOne { + if s != nil { + nuo.SetStatusDetail(*s) + } + return nuo +} + +// ClearStatusDetail clears the value of the "status_detail" field. +func (nuo *NodeUpdateOne) ClearStatusDetail() *NodeUpdateOne { + nuo.mutation.ClearStatusDetail() + return nuo +} + // SetPublisher sets the "publisher" edge to the Publisher entity. func (nuo *NodeUpdateOne) SetPublisher(p *Publisher) *NodeUpdateOne { return nuo.SetPublisherID(p.ID) @@ -560,6 +924,21 @@ func (nuo *NodeUpdateOne) AddVersions(n ...*NodeVersion) *NodeUpdateOne { return nuo.AddVersionIDs(ids...) } +// AddReviewIDs adds the "reviews" edge to the NodeReview entity by IDs. +func (nuo *NodeUpdateOne) AddReviewIDs(ids ...uuid.UUID) *NodeUpdateOne { + nuo.mutation.AddReviewIDs(ids...) + return nuo +} + +// AddReviews adds the "reviews" edges to the NodeReview entity. +func (nuo *NodeUpdateOne) AddReviews(n ...*NodeReview) *NodeUpdateOne { + ids := make([]uuid.UUID, len(n)) + for i := range n { + ids[i] = n[i].ID + } + return nuo.AddReviewIDs(ids...) +} + // Mutation returns the NodeMutation object of the builder. func (nuo *NodeUpdateOne) Mutation() *NodeMutation { return nuo.mutation @@ -592,6 +971,27 @@ func (nuo *NodeUpdateOne) RemoveVersions(n ...*NodeVersion) *NodeUpdateOne { return nuo.RemoveVersionIDs(ids...) } +// ClearReviews clears all "reviews" edges to the NodeReview entity. +func (nuo *NodeUpdateOne) ClearReviews() *NodeUpdateOne { + nuo.mutation.ClearReviews() + return nuo +} + +// RemoveReviewIDs removes the "reviews" edge to NodeReview entities by IDs. +func (nuo *NodeUpdateOne) RemoveReviewIDs(ids ...uuid.UUID) *NodeUpdateOne { + nuo.mutation.RemoveReviewIDs(ids...) + return nuo +} + +// RemoveReviews removes "reviews" edges to NodeReview entities. +func (nuo *NodeUpdateOne) RemoveReviews(n ...*NodeReview) *NodeUpdateOne { + ids := make([]uuid.UUID, len(n)) + for i := range n { + ids[i] = n[i].ID + } + return nuo.RemoveReviewIDs(ids...) +} + // Where appends a list predicates to the NodeUpdate builder. func (nuo *NodeUpdateOne) Where(ps ...predicate.Node) *NodeUpdateOne { nuo.mutation.Where(ps...) @@ -643,12 +1043,23 @@ func (nuo *NodeUpdateOne) defaults() { // check runs all checks and user-defined validators on the builder. func (nuo *NodeUpdateOne) check() error { + if v, ok := nuo.mutation.Status(); ok { + if err := node.StatusValidator(v); err != nil { + return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "Node.status": %w`, err)} + } + } if _, ok := nuo.mutation.PublisherID(); nuo.mutation.PublisherCleared() && !ok { return errors.New(`ent: clearing a required unique edge "Node.publisher"`) } return nil } +// Modify adds a statement modifier for attaching custom logic to the UPDATE statement. +func (nuo *NodeUpdateOne) Modify(modifiers ...func(u *sql.UpdateBuilder)) *NodeUpdateOne { + nuo.modifiers = append(nuo.modifiers, modifiers...) + return nuo +} + func (nuo *NodeUpdateOne) sqlSave(ctx context.Context) (_node *Node, err error) { if err := nuo.check(); err != nil { return _node, err @@ -690,6 +1101,12 @@ func (nuo *NodeUpdateOne) sqlSave(ctx context.Context) (_node *Node, err error) if nuo.mutation.DescriptionCleared() { _spec.ClearField(node.FieldDescription, field.TypeString) } + if value, ok := nuo.mutation.Category(); ok { + _spec.SetField(node.FieldCategory, field.TypeString, value) + } + if nuo.mutation.CategoryCleared() { + _spec.ClearField(node.FieldCategory, field.TypeString) + } if value, ok := nuo.mutation.Author(); ok { _spec.SetField(node.FieldAuthor, field.TypeString, value) } @@ -716,6 +1133,33 @@ func (nuo *NodeUpdateOne) sqlSave(ctx context.Context) (_node *Node, err error) sqljson.Append(u, node.FieldTags, value) }) } + if value, ok := nuo.mutation.TotalInstall(); ok { + _spec.SetField(node.FieldTotalInstall, field.TypeInt64, value) + } + if value, ok := nuo.mutation.AddedTotalInstall(); ok { + _spec.AddField(node.FieldTotalInstall, field.TypeInt64, value) + } + if value, ok := nuo.mutation.TotalStar(); ok { + _spec.SetField(node.FieldTotalStar, field.TypeInt64, value) + } + if value, ok := nuo.mutation.AddedTotalStar(); ok { + _spec.AddField(node.FieldTotalStar, field.TypeInt64, value) + } + if value, ok := nuo.mutation.TotalReview(); ok { + _spec.SetField(node.FieldTotalReview, field.TypeInt64, value) + } + if value, ok := nuo.mutation.AddedTotalReview(); ok { + _spec.AddField(node.FieldTotalReview, field.TypeInt64, value) + } + if value, ok := nuo.mutation.Status(); ok { + _spec.SetField(node.FieldStatus, field.TypeEnum, value) + } + if value, ok := nuo.mutation.StatusDetail(); ok { + _spec.SetField(node.FieldStatusDetail, field.TypeString, value) + } + if nuo.mutation.StatusDetailCleared() { + _spec.ClearField(node.FieldStatusDetail, field.TypeString) + } if nuo.mutation.PublisherCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, @@ -790,6 +1234,52 @@ func (nuo *NodeUpdateOne) sqlSave(ctx context.Context) (_node *Node, err error) } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if nuo.mutation.ReviewsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: node.ReviewsTable, + Columns: []string{node.ReviewsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(nodereview.FieldID, field.TypeUUID), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := nuo.mutation.RemovedReviewsIDs(); len(nodes) > 0 && !nuo.mutation.ReviewsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: node.ReviewsTable, + Columns: []string{node.ReviewsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(nodereview.FieldID, field.TypeUUID), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := nuo.mutation.ReviewsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: node.ReviewsTable, + Columns: []string{node.ReviewsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(nodereview.FieldID, field.TypeUUID), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _spec.AddModifiers(nuo.modifiers...) _node = &Node{config: nuo.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/ent/nodereview.go b/ent/nodereview.go new file mode 100644 index 0000000..b853dbd --- /dev/null +++ b/ent/nodereview.go @@ -0,0 +1,176 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "fmt" + "registry-backend/ent/node" + "registry-backend/ent/nodereview" + "registry-backend/ent/user" + "strings" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/google/uuid" +) + +// NodeReview is the model entity for the NodeReview schema. +type NodeReview struct { + config `json:"-"` + // ID of the ent. + ID uuid.UUID `json:"id,omitempty"` + // NodeID holds the value of the "node_id" field. + NodeID string `json:"node_id,omitempty"` + // UserID holds the value of the "user_id" field. + UserID string `json:"user_id,omitempty"` + // Star holds the value of the "star" field. + Star int `json:"star,omitempty"` + // Edges holds the relations/edges for other nodes in the graph. + // The values are being populated by the NodeReviewQuery when eager-loading is set. + Edges NodeReviewEdges `json:"edges"` + selectValues sql.SelectValues +} + +// NodeReviewEdges holds the relations/edges for other nodes in the graph. +type NodeReviewEdges struct { + // User holds the value of the user edge. + User *User `json:"user,omitempty"` + // Node holds the value of the node edge. + Node *Node `json:"node,omitempty"` + // loadedTypes holds the information for reporting if a + // type was loaded (or requested) in eager-loading or not. + loadedTypes [2]bool +} + +// UserOrErr returns the User value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e NodeReviewEdges) UserOrErr() (*User, error) { + if e.User != nil { + return e.User, nil + } else if e.loadedTypes[0] { + return nil, &NotFoundError{label: user.Label} + } + return nil, &NotLoadedError{edge: "user"} +} + +// NodeOrErr returns the Node value or an error if the edge +// was not loaded in eager-loading, or loaded but was not found. +func (e NodeReviewEdges) NodeOrErr() (*Node, error) { + if e.Node != nil { + return e.Node, nil + } else if e.loadedTypes[1] { + return nil, &NotFoundError{label: node.Label} + } + return nil, &NotLoadedError{edge: "node"} +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*NodeReview) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case nodereview.FieldStar: + values[i] = new(sql.NullInt64) + case nodereview.FieldNodeID, nodereview.FieldUserID: + values[i] = new(sql.NullString) + case nodereview.FieldID: + values[i] = new(uuid.UUID) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the NodeReview fields. +func (nr *NodeReview) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case nodereview.FieldID: + if value, ok := values[i].(*uuid.UUID); !ok { + return fmt.Errorf("unexpected type %T for field id", values[i]) + } else if value != nil { + nr.ID = *value + } + case nodereview.FieldNodeID: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field node_id", values[i]) + } else if value.Valid { + nr.NodeID = value.String + } + case nodereview.FieldUserID: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field user_id", values[i]) + } else if value.Valid { + nr.UserID = value.String + } + case nodereview.FieldStar: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field star", values[i]) + } else if value.Valid { + nr.Star = int(value.Int64) + } + default: + nr.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the NodeReview. +// This includes values selected through modifiers, order, etc. +func (nr *NodeReview) Value(name string) (ent.Value, error) { + return nr.selectValues.Get(name) +} + +// QueryUser queries the "user" edge of the NodeReview entity. +func (nr *NodeReview) QueryUser() *UserQuery { + return NewNodeReviewClient(nr.config).QueryUser(nr) +} + +// QueryNode queries the "node" edge of the NodeReview entity. +func (nr *NodeReview) QueryNode() *NodeQuery { + return NewNodeReviewClient(nr.config).QueryNode(nr) +} + +// Update returns a builder for updating this NodeReview. +// Note that you need to call NodeReview.Unwrap() before calling this method if this NodeReview +// was returned from a transaction, and the transaction was committed or rolled back. +func (nr *NodeReview) Update() *NodeReviewUpdateOne { + return NewNodeReviewClient(nr.config).UpdateOne(nr) +} + +// Unwrap unwraps the NodeReview entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (nr *NodeReview) Unwrap() *NodeReview { + _tx, ok := nr.config.driver.(*txDriver) + if !ok { + panic("ent: NodeReview is not a transactional entity") + } + nr.config.driver = _tx.drv + return nr +} + +// String implements the fmt.Stringer. +func (nr *NodeReview) String() string { + var builder strings.Builder + builder.WriteString("NodeReview(") + builder.WriteString(fmt.Sprintf("id=%v, ", nr.ID)) + builder.WriteString("node_id=") + builder.WriteString(nr.NodeID) + builder.WriteString(", ") + builder.WriteString("user_id=") + builder.WriteString(nr.UserID) + builder.WriteString(", ") + builder.WriteString("star=") + builder.WriteString(fmt.Sprintf("%v", nr.Star)) + builder.WriteByte(')') + return builder.String() +} + +// NodeReviews is a parsable slice of NodeReview. +type NodeReviews []*NodeReview diff --git a/ent/nodereview/nodereview.go b/ent/nodereview/nodereview.go new file mode 100644 index 0000000..b4ef126 --- /dev/null +++ b/ent/nodereview/nodereview.go @@ -0,0 +1,118 @@ +// Code generated by ent, DO NOT EDIT. + +package nodereview + +import ( + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/google/uuid" +) + +const ( + // Label holds the string label denoting the nodereview type in the database. + Label = "node_review" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldNodeID holds the string denoting the node_id field in the database. + FieldNodeID = "node_id" + // FieldUserID holds the string denoting the user_id field in the database. + FieldUserID = "user_id" + // FieldStar holds the string denoting the star field in the database. + FieldStar = "star" + // EdgeUser holds the string denoting the user edge name in mutations. + EdgeUser = "user" + // EdgeNode holds the string denoting the node edge name in mutations. + EdgeNode = "node" + // Table holds the table name of the nodereview in the database. + Table = "node_reviews" + // UserTable is the table that holds the user relation/edge. + UserTable = "node_reviews" + // UserInverseTable is the table name for the User entity. + // It exists in this package in order to avoid circular dependency with the "user" package. + UserInverseTable = "users" + // UserColumn is the table column denoting the user relation/edge. + UserColumn = "user_id" + // NodeTable is the table that holds the node relation/edge. + NodeTable = "node_reviews" + // NodeInverseTable is the table name for the Node entity. + // It exists in this package in order to avoid circular dependency with the "node" package. + NodeInverseTable = "nodes" + // NodeColumn is the table column denoting the node relation/edge. + NodeColumn = "node_id" +) + +// Columns holds all SQL columns for nodereview fields. +var Columns = []string{ + FieldID, + FieldNodeID, + FieldUserID, + FieldStar, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // DefaultStar holds the default value on creation for the "star" field. + DefaultStar int + // DefaultID holds the default value on creation for the "id" field. + DefaultID func() uuid.UUID +) + +// OrderOption defines the ordering options for the NodeReview queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByNodeID orders the results by the node_id field. +func ByNodeID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldNodeID, opts...).ToFunc() +} + +// ByUserID orders the results by the user_id field. +func ByUserID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUserID, opts...).ToFunc() +} + +// ByStar orders the results by the star field. +func ByStar(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldStar, opts...).ToFunc() +} + +// ByUserField orders the results by user field. +func ByUserField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newUserStep(), sql.OrderByField(field, opts...)) + } +} + +// ByNodeField orders the results by node field. +func ByNodeField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newNodeStep(), sql.OrderByField(field, opts...)) + } +} +func newUserStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(UserInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn), + ) +} +func newNodeStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(NodeInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, NodeTable, NodeColumn), + ) +} diff --git a/ent/nodereview/where.go b/ent/nodereview/where.go new file mode 100644 index 0000000..59b36a6 --- /dev/null +++ b/ent/nodereview/where.go @@ -0,0 +1,302 @@ +// Code generated by ent, DO NOT EDIT. + +package nodereview + +import ( + "registry-backend/ent/predicate" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/google/uuid" +) + +// ID filters vertices based on their ID field. +func ID(id uuid.UUID) predicate.NodeReview { + return predicate.NodeReview(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id uuid.UUID) predicate.NodeReview { + return predicate.NodeReview(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id uuid.UUID) predicate.NodeReview { + return predicate.NodeReview(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...uuid.UUID) predicate.NodeReview { + return predicate.NodeReview(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...uuid.UUID) predicate.NodeReview { + return predicate.NodeReview(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id uuid.UUID) predicate.NodeReview { + return predicate.NodeReview(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id uuid.UUID) predicate.NodeReview { + return predicate.NodeReview(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id uuid.UUID) predicate.NodeReview { + return predicate.NodeReview(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id uuid.UUID) predicate.NodeReview { + return predicate.NodeReview(sql.FieldLTE(FieldID, id)) +} + +// NodeID applies equality check predicate on the "node_id" field. It's identical to NodeIDEQ. +func NodeID(v string) predicate.NodeReview { + return predicate.NodeReview(sql.FieldEQ(FieldNodeID, v)) +} + +// UserID applies equality check predicate on the "user_id" field. It's identical to UserIDEQ. +func UserID(v string) predicate.NodeReview { + return predicate.NodeReview(sql.FieldEQ(FieldUserID, v)) +} + +// Star applies equality check predicate on the "star" field. It's identical to StarEQ. +func Star(v int) predicate.NodeReview { + return predicate.NodeReview(sql.FieldEQ(FieldStar, v)) +} + +// NodeIDEQ applies the EQ predicate on the "node_id" field. +func NodeIDEQ(v string) predicate.NodeReview { + return predicate.NodeReview(sql.FieldEQ(FieldNodeID, v)) +} + +// NodeIDNEQ applies the NEQ predicate on the "node_id" field. +func NodeIDNEQ(v string) predicate.NodeReview { + return predicate.NodeReview(sql.FieldNEQ(FieldNodeID, v)) +} + +// NodeIDIn applies the In predicate on the "node_id" field. +func NodeIDIn(vs ...string) predicate.NodeReview { + return predicate.NodeReview(sql.FieldIn(FieldNodeID, vs...)) +} + +// NodeIDNotIn applies the NotIn predicate on the "node_id" field. +func NodeIDNotIn(vs ...string) predicate.NodeReview { + return predicate.NodeReview(sql.FieldNotIn(FieldNodeID, vs...)) +} + +// NodeIDGT applies the GT predicate on the "node_id" field. +func NodeIDGT(v string) predicate.NodeReview { + return predicate.NodeReview(sql.FieldGT(FieldNodeID, v)) +} + +// NodeIDGTE applies the GTE predicate on the "node_id" field. +func NodeIDGTE(v string) predicate.NodeReview { + return predicate.NodeReview(sql.FieldGTE(FieldNodeID, v)) +} + +// NodeIDLT applies the LT predicate on the "node_id" field. +func NodeIDLT(v string) predicate.NodeReview { + return predicate.NodeReview(sql.FieldLT(FieldNodeID, v)) +} + +// NodeIDLTE applies the LTE predicate on the "node_id" field. +func NodeIDLTE(v string) predicate.NodeReview { + return predicate.NodeReview(sql.FieldLTE(FieldNodeID, v)) +} + +// NodeIDContains applies the Contains predicate on the "node_id" field. +func NodeIDContains(v string) predicate.NodeReview { + return predicate.NodeReview(sql.FieldContains(FieldNodeID, v)) +} + +// NodeIDHasPrefix applies the HasPrefix predicate on the "node_id" field. +func NodeIDHasPrefix(v string) predicate.NodeReview { + return predicate.NodeReview(sql.FieldHasPrefix(FieldNodeID, v)) +} + +// NodeIDHasSuffix applies the HasSuffix predicate on the "node_id" field. +func NodeIDHasSuffix(v string) predicate.NodeReview { + return predicate.NodeReview(sql.FieldHasSuffix(FieldNodeID, v)) +} + +// NodeIDEqualFold applies the EqualFold predicate on the "node_id" field. +func NodeIDEqualFold(v string) predicate.NodeReview { + return predicate.NodeReview(sql.FieldEqualFold(FieldNodeID, v)) +} + +// NodeIDContainsFold applies the ContainsFold predicate on the "node_id" field. +func NodeIDContainsFold(v string) predicate.NodeReview { + return predicate.NodeReview(sql.FieldContainsFold(FieldNodeID, v)) +} + +// UserIDEQ applies the EQ predicate on the "user_id" field. +func UserIDEQ(v string) predicate.NodeReview { + return predicate.NodeReview(sql.FieldEQ(FieldUserID, v)) +} + +// UserIDNEQ applies the NEQ predicate on the "user_id" field. +func UserIDNEQ(v string) predicate.NodeReview { + return predicate.NodeReview(sql.FieldNEQ(FieldUserID, v)) +} + +// UserIDIn applies the In predicate on the "user_id" field. +func UserIDIn(vs ...string) predicate.NodeReview { + return predicate.NodeReview(sql.FieldIn(FieldUserID, vs...)) +} + +// UserIDNotIn applies the NotIn predicate on the "user_id" field. +func UserIDNotIn(vs ...string) predicate.NodeReview { + return predicate.NodeReview(sql.FieldNotIn(FieldUserID, vs...)) +} + +// UserIDGT applies the GT predicate on the "user_id" field. +func UserIDGT(v string) predicate.NodeReview { + return predicate.NodeReview(sql.FieldGT(FieldUserID, v)) +} + +// UserIDGTE applies the GTE predicate on the "user_id" field. +func UserIDGTE(v string) predicate.NodeReview { + return predicate.NodeReview(sql.FieldGTE(FieldUserID, v)) +} + +// UserIDLT applies the LT predicate on the "user_id" field. +func UserIDLT(v string) predicate.NodeReview { + return predicate.NodeReview(sql.FieldLT(FieldUserID, v)) +} + +// UserIDLTE applies the LTE predicate on the "user_id" field. +func UserIDLTE(v string) predicate.NodeReview { + return predicate.NodeReview(sql.FieldLTE(FieldUserID, v)) +} + +// UserIDContains applies the Contains predicate on the "user_id" field. +func UserIDContains(v string) predicate.NodeReview { + return predicate.NodeReview(sql.FieldContains(FieldUserID, v)) +} + +// UserIDHasPrefix applies the HasPrefix predicate on the "user_id" field. +func UserIDHasPrefix(v string) predicate.NodeReview { + return predicate.NodeReview(sql.FieldHasPrefix(FieldUserID, v)) +} + +// UserIDHasSuffix applies the HasSuffix predicate on the "user_id" field. +func UserIDHasSuffix(v string) predicate.NodeReview { + return predicate.NodeReview(sql.FieldHasSuffix(FieldUserID, v)) +} + +// UserIDEqualFold applies the EqualFold predicate on the "user_id" field. +func UserIDEqualFold(v string) predicate.NodeReview { + return predicate.NodeReview(sql.FieldEqualFold(FieldUserID, v)) +} + +// UserIDContainsFold applies the ContainsFold predicate on the "user_id" field. +func UserIDContainsFold(v string) predicate.NodeReview { + return predicate.NodeReview(sql.FieldContainsFold(FieldUserID, v)) +} + +// StarEQ applies the EQ predicate on the "star" field. +func StarEQ(v int) predicate.NodeReview { + return predicate.NodeReview(sql.FieldEQ(FieldStar, v)) +} + +// StarNEQ applies the NEQ predicate on the "star" field. +func StarNEQ(v int) predicate.NodeReview { + return predicate.NodeReview(sql.FieldNEQ(FieldStar, v)) +} + +// StarIn applies the In predicate on the "star" field. +func StarIn(vs ...int) predicate.NodeReview { + return predicate.NodeReview(sql.FieldIn(FieldStar, vs...)) +} + +// StarNotIn applies the NotIn predicate on the "star" field. +func StarNotIn(vs ...int) predicate.NodeReview { + return predicate.NodeReview(sql.FieldNotIn(FieldStar, vs...)) +} + +// StarGT applies the GT predicate on the "star" field. +func StarGT(v int) predicate.NodeReview { + return predicate.NodeReview(sql.FieldGT(FieldStar, v)) +} + +// StarGTE applies the GTE predicate on the "star" field. +func StarGTE(v int) predicate.NodeReview { + return predicate.NodeReview(sql.FieldGTE(FieldStar, v)) +} + +// StarLT applies the LT predicate on the "star" field. +func StarLT(v int) predicate.NodeReview { + return predicate.NodeReview(sql.FieldLT(FieldStar, v)) +} + +// StarLTE applies the LTE predicate on the "star" field. +func StarLTE(v int) predicate.NodeReview { + return predicate.NodeReview(sql.FieldLTE(FieldStar, v)) +} + +// HasUser applies the HasEdge predicate on the "user" edge. +func HasUser() predicate.NodeReview { + return predicate.NodeReview(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, UserTable, UserColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasUserWith applies the HasEdge predicate on the "user" edge with a given conditions (other predicates). +func HasUserWith(preds ...predicate.User) predicate.NodeReview { + return predicate.NodeReview(func(s *sql.Selector) { + step := newUserStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// HasNode applies the HasEdge predicate on the "node" edge. +func HasNode() predicate.NodeReview { + return predicate.NodeReview(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, NodeTable, NodeColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasNodeWith applies the HasEdge predicate on the "node" edge with a given conditions (other predicates). +func HasNodeWith(preds ...predicate.Node) predicate.NodeReview { + return predicate.NodeReview(func(s *sql.Selector) { + step := newNodeStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.NodeReview) predicate.NodeReview { + return predicate.NodeReview(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.NodeReview) predicate.NodeReview { + return predicate.NodeReview(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.NodeReview) predicate.NodeReview { + return predicate.NodeReview(sql.NotPredicates(p)) +} diff --git a/ent/nodereview_create.go b/ent/nodereview_create.go new file mode 100644 index 0000000..4b86b2e --- /dev/null +++ b/ent/nodereview_create.go @@ -0,0 +1,690 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "registry-backend/ent/node" + "registry-backend/ent/nodereview" + "registry-backend/ent/user" + + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/google/uuid" +) + +// NodeReviewCreate is the builder for creating a NodeReview entity. +type NodeReviewCreate struct { + config + mutation *NodeReviewMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetNodeID sets the "node_id" field. +func (nrc *NodeReviewCreate) SetNodeID(s string) *NodeReviewCreate { + nrc.mutation.SetNodeID(s) + return nrc +} + +// SetUserID sets the "user_id" field. +func (nrc *NodeReviewCreate) SetUserID(s string) *NodeReviewCreate { + nrc.mutation.SetUserID(s) + return nrc +} + +// SetStar sets the "star" field. +func (nrc *NodeReviewCreate) SetStar(i int) *NodeReviewCreate { + nrc.mutation.SetStar(i) + return nrc +} + +// SetNillableStar sets the "star" field if the given value is not nil. +func (nrc *NodeReviewCreate) SetNillableStar(i *int) *NodeReviewCreate { + if i != nil { + nrc.SetStar(*i) + } + return nrc +} + +// SetID sets the "id" field. +func (nrc *NodeReviewCreate) SetID(u uuid.UUID) *NodeReviewCreate { + nrc.mutation.SetID(u) + return nrc +} + +// SetNillableID sets the "id" field if the given value is not nil. +func (nrc *NodeReviewCreate) SetNillableID(u *uuid.UUID) *NodeReviewCreate { + if u != nil { + nrc.SetID(*u) + } + return nrc +} + +// SetUser sets the "user" edge to the User entity. +func (nrc *NodeReviewCreate) SetUser(u *User) *NodeReviewCreate { + return nrc.SetUserID(u.ID) +} + +// SetNode sets the "node" edge to the Node entity. +func (nrc *NodeReviewCreate) SetNode(n *Node) *NodeReviewCreate { + return nrc.SetNodeID(n.ID) +} + +// Mutation returns the NodeReviewMutation object of the builder. +func (nrc *NodeReviewCreate) Mutation() *NodeReviewMutation { + return nrc.mutation +} + +// Save creates the NodeReview in the database. +func (nrc *NodeReviewCreate) Save(ctx context.Context) (*NodeReview, error) { + nrc.defaults() + return withHooks(ctx, nrc.sqlSave, nrc.mutation, nrc.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (nrc *NodeReviewCreate) SaveX(ctx context.Context) *NodeReview { + v, err := nrc.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (nrc *NodeReviewCreate) Exec(ctx context.Context) error { + _, err := nrc.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (nrc *NodeReviewCreate) ExecX(ctx context.Context) { + if err := nrc.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (nrc *NodeReviewCreate) defaults() { + if _, ok := nrc.mutation.Star(); !ok { + v := nodereview.DefaultStar + nrc.mutation.SetStar(v) + } + if _, ok := nrc.mutation.ID(); !ok { + v := nodereview.DefaultID() + nrc.mutation.SetID(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (nrc *NodeReviewCreate) check() error { + if _, ok := nrc.mutation.NodeID(); !ok { + return &ValidationError{Name: "node_id", err: errors.New(`ent: missing required field "NodeReview.node_id"`)} + } + if _, ok := nrc.mutation.UserID(); !ok { + return &ValidationError{Name: "user_id", err: errors.New(`ent: missing required field "NodeReview.user_id"`)} + } + if _, ok := nrc.mutation.Star(); !ok { + return &ValidationError{Name: "star", err: errors.New(`ent: missing required field "NodeReview.star"`)} + } + if _, ok := nrc.mutation.UserID(); !ok { + return &ValidationError{Name: "user", err: errors.New(`ent: missing required edge "NodeReview.user"`)} + } + if _, ok := nrc.mutation.NodeID(); !ok { + return &ValidationError{Name: "node", err: errors.New(`ent: missing required edge "NodeReview.node"`)} + } + return nil +} + +func (nrc *NodeReviewCreate) sqlSave(ctx context.Context) (*NodeReview, error) { + if err := nrc.check(); err != nil { + return nil, err + } + _node, _spec := nrc.createSpec() + if err := sqlgraph.CreateNode(ctx, nrc.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + if _spec.ID.Value != nil { + if id, ok := _spec.ID.Value.(*uuid.UUID); ok { + _node.ID = *id + } else if err := _node.ID.Scan(_spec.ID.Value); err != nil { + return nil, err + } + } + nrc.mutation.id = &_node.ID + nrc.mutation.done = true + return _node, nil +} + +func (nrc *NodeReviewCreate) createSpec() (*NodeReview, *sqlgraph.CreateSpec) { + var ( + _node = &NodeReview{config: nrc.config} + _spec = sqlgraph.NewCreateSpec(nodereview.Table, sqlgraph.NewFieldSpec(nodereview.FieldID, field.TypeUUID)) + ) + _spec.OnConflict = nrc.conflict + if id, ok := nrc.mutation.ID(); ok { + _node.ID = id + _spec.ID.Value = &id + } + if value, ok := nrc.mutation.Star(); ok { + _spec.SetField(nodereview.FieldStar, field.TypeInt, value) + _node.Star = value + } + if nodes := nrc.mutation.UserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: nodereview.UserTable, + Columns: []string{nodereview.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeString), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.UserID = nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + if nodes := nrc.mutation.NodeIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: nodereview.NodeTable, + Columns: []string{nodereview.NodeColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(node.FieldID, field.TypeString), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _node.NodeID = nodes[0] + _spec.Edges = append(_spec.Edges, edge) + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.NodeReview.Create(). +// SetNodeID(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.NodeReviewUpsert) { +// SetNodeID(v+v). +// }). +// Exec(ctx) +func (nrc *NodeReviewCreate) OnConflict(opts ...sql.ConflictOption) *NodeReviewUpsertOne { + nrc.conflict = opts + return &NodeReviewUpsertOne{ + create: nrc, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.NodeReview.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (nrc *NodeReviewCreate) OnConflictColumns(columns ...string) *NodeReviewUpsertOne { + nrc.conflict = append(nrc.conflict, sql.ConflictColumns(columns...)) + return &NodeReviewUpsertOne{ + create: nrc, + } +} + +type ( + // NodeReviewUpsertOne is the builder for "upsert"-ing + // one NodeReview node. + NodeReviewUpsertOne struct { + create *NodeReviewCreate + } + + // NodeReviewUpsert is the "OnConflict" setter. + NodeReviewUpsert struct { + *sql.UpdateSet + } +) + +// SetNodeID sets the "node_id" field. +func (u *NodeReviewUpsert) SetNodeID(v string) *NodeReviewUpsert { + u.Set(nodereview.FieldNodeID, v) + return u +} + +// UpdateNodeID sets the "node_id" field to the value that was provided on create. +func (u *NodeReviewUpsert) UpdateNodeID() *NodeReviewUpsert { + u.SetExcluded(nodereview.FieldNodeID) + return u +} + +// SetUserID sets the "user_id" field. +func (u *NodeReviewUpsert) SetUserID(v string) *NodeReviewUpsert { + u.Set(nodereview.FieldUserID, v) + return u +} + +// UpdateUserID sets the "user_id" field to the value that was provided on create. +func (u *NodeReviewUpsert) UpdateUserID() *NodeReviewUpsert { + u.SetExcluded(nodereview.FieldUserID) + return u +} + +// SetStar sets the "star" field. +func (u *NodeReviewUpsert) SetStar(v int) *NodeReviewUpsert { + u.Set(nodereview.FieldStar, v) + return u +} + +// UpdateStar sets the "star" field to the value that was provided on create. +func (u *NodeReviewUpsert) UpdateStar() *NodeReviewUpsert { + u.SetExcluded(nodereview.FieldStar) + return u +} + +// AddStar adds v to the "star" field. +func (u *NodeReviewUpsert) AddStar(v int) *NodeReviewUpsert { + u.Add(nodereview.FieldStar, v) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create except the ID field. +// Using this option is equivalent to using: +// +// client.NodeReview.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// sql.ResolveWith(func(u *sql.UpdateSet) { +// u.SetIgnore(nodereview.FieldID) +// }), +// ). +// Exec(ctx) +func (u *NodeReviewUpsertOne) UpdateNewValues() *NodeReviewUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.ID(); exists { + s.SetIgnore(nodereview.FieldID) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.NodeReview.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *NodeReviewUpsertOne) Ignore() *NodeReviewUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *NodeReviewUpsertOne) DoNothing() *NodeReviewUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the NodeReviewCreate.OnConflict +// documentation for more info. +func (u *NodeReviewUpsertOne) Update(set func(*NodeReviewUpsert)) *NodeReviewUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&NodeReviewUpsert{UpdateSet: update}) + })) + return u +} + +// SetNodeID sets the "node_id" field. +func (u *NodeReviewUpsertOne) SetNodeID(v string) *NodeReviewUpsertOne { + return u.Update(func(s *NodeReviewUpsert) { + s.SetNodeID(v) + }) +} + +// UpdateNodeID sets the "node_id" field to the value that was provided on create. +func (u *NodeReviewUpsertOne) UpdateNodeID() *NodeReviewUpsertOne { + return u.Update(func(s *NodeReviewUpsert) { + s.UpdateNodeID() + }) +} + +// SetUserID sets the "user_id" field. +func (u *NodeReviewUpsertOne) SetUserID(v string) *NodeReviewUpsertOne { + return u.Update(func(s *NodeReviewUpsert) { + s.SetUserID(v) + }) +} + +// UpdateUserID sets the "user_id" field to the value that was provided on create. +func (u *NodeReviewUpsertOne) UpdateUserID() *NodeReviewUpsertOne { + return u.Update(func(s *NodeReviewUpsert) { + s.UpdateUserID() + }) +} + +// SetStar sets the "star" field. +func (u *NodeReviewUpsertOne) SetStar(v int) *NodeReviewUpsertOne { + return u.Update(func(s *NodeReviewUpsert) { + s.SetStar(v) + }) +} + +// AddStar adds v to the "star" field. +func (u *NodeReviewUpsertOne) AddStar(v int) *NodeReviewUpsertOne { + return u.Update(func(s *NodeReviewUpsert) { + s.AddStar(v) + }) +} + +// UpdateStar sets the "star" field to the value that was provided on create. +func (u *NodeReviewUpsertOne) UpdateStar() *NodeReviewUpsertOne { + return u.Update(func(s *NodeReviewUpsert) { + s.UpdateStar() + }) +} + +// Exec executes the query. +func (u *NodeReviewUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for NodeReviewCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *NodeReviewUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *NodeReviewUpsertOne) ID(ctx context.Context) (id uuid.UUID, err error) { + if u.create.driver.Dialect() == dialect.MySQL { + // In case of "ON CONFLICT", there is no way to get back non-numeric ID + // fields from the database since MySQL does not support the RETURNING clause. + return id, errors.New("ent: NodeReviewUpsertOne.ID is not supported by MySQL driver. Use NodeReviewUpsertOne.Exec instead") + } + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *NodeReviewUpsertOne) IDX(ctx context.Context) uuid.UUID { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +// NodeReviewCreateBulk is the builder for creating many NodeReview entities in bulk. +type NodeReviewCreateBulk struct { + config + err error + builders []*NodeReviewCreate + conflict []sql.ConflictOption +} + +// Save creates the NodeReview entities in the database. +func (nrcb *NodeReviewCreateBulk) Save(ctx context.Context) ([]*NodeReview, error) { + if nrcb.err != nil { + return nil, nrcb.err + } + specs := make([]*sqlgraph.CreateSpec, len(nrcb.builders)) + nodes := make([]*NodeReview, len(nrcb.builders)) + mutators := make([]Mutator, len(nrcb.builders)) + for i := range nrcb.builders { + func(i int, root context.Context) { + builder := nrcb.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*NodeReviewMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, nrcb.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = nrcb.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, nrcb.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, nrcb.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (nrcb *NodeReviewCreateBulk) SaveX(ctx context.Context) []*NodeReview { + v, err := nrcb.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (nrcb *NodeReviewCreateBulk) Exec(ctx context.Context) error { + _, err := nrcb.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (nrcb *NodeReviewCreateBulk) ExecX(ctx context.Context) { + if err := nrcb.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.NodeReview.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.NodeReviewUpsert) { +// SetNodeID(v+v). +// }). +// Exec(ctx) +func (nrcb *NodeReviewCreateBulk) OnConflict(opts ...sql.ConflictOption) *NodeReviewUpsertBulk { + nrcb.conflict = opts + return &NodeReviewUpsertBulk{ + create: nrcb, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.NodeReview.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (nrcb *NodeReviewCreateBulk) OnConflictColumns(columns ...string) *NodeReviewUpsertBulk { + nrcb.conflict = append(nrcb.conflict, sql.ConflictColumns(columns...)) + return &NodeReviewUpsertBulk{ + create: nrcb, + } +} + +// NodeReviewUpsertBulk is the builder for "upsert"-ing +// a bulk of NodeReview nodes. +type NodeReviewUpsertBulk struct { + create *NodeReviewCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.NodeReview.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// sql.ResolveWith(func(u *sql.UpdateSet) { +// u.SetIgnore(nodereview.FieldID) +// }), +// ). +// Exec(ctx) +func (u *NodeReviewUpsertBulk) UpdateNewValues() *NodeReviewUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.ID(); exists { + s.SetIgnore(nodereview.FieldID) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.NodeReview.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *NodeReviewUpsertBulk) Ignore() *NodeReviewUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *NodeReviewUpsertBulk) DoNothing() *NodeReviewUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the NodeReviewCreateBulk.OnConflict +// documentation for more info. +func (u *NodeReviewUpsertBulk) Update(set func(*NodeReviewUpsert)) *NodeReviewUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&NodeReviewUpsert{UpdateSet: update}) + })) + return u +} + +// SetNodeID sets the "node_id" field. +func (u *NodeReviewUpsertBulk) SetNodeID(v string) *NodeReviewUpsertBulk { + return u.Update(func(s *NodeReviewUpsert) { + s.SetNodeID(v) + }) +} + +// UpdateNodeID sets the "node_id" field to the value that was provided on create. +func (u *NodeReviewUpsertBulk) UpdateNodeID() *NodeReviewUpsertBulk { + return u.Update(func(s *NodeReviewUpsert) { + s.UpdateNodeID() + }) +} + +// SetUserID sets the "user_id" field. +func (u *NodeReviewUpsertBulk) SetUserID(v string) *NodeReviewUpsertBulk { + return u.Update(func(s *NodeReviewUpsert) { + s.SetUserID(v) + }) +} + +// UpdateUserID sets the "user_id" field to the value that was provided on create. +func (u *NodeReviewUpsertBulk) UpdateUserID() *NodeReviewUpsertBulk { + return u.Update(func(s *NodeReviewUpsert) { + s.UpdateUserID() + }) +} + +// SetStar sets the "star" field. +func (u *NodeReviewUpsertBulk) SetStar(v int) *NodeReviewUpsertBulk { + return u.Update(func(s *NodeReviewUpsert) { + s.SetStar(v) + }) +} + +// AddStar adds v to the "star" field. +func (u *NodeReviewUpsertBulk) AddStar(v int) *NodeReviewUpsertBulk { + return u.Update(func(s *NodeReviewUpsert) { + s.AddStar(v) + }) +} + +// UpdateStar sets the "star" field to the value that was provided on create. +func (u *NodeReviewUpsertBulk) UpdateStar() *NodeReviewUpsertBulk { + return u.Update(func(s *NodeReviewUpsert) { + s.UpdateStar() + }) +} + +// Exec executes the query. +func (u *NodeReviewUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the NodeReviewCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for NodeReviewCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *NodeReviewUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/ent/nodereview_delete.go b/ent/nodereview_delete.go new file mode 100644 index 0000000..f918003 --- /dev/null +++ b/ent/nodereview_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "registry-backend/ent/nodereview" + "registry-backend/ent/predicate" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" +) + +// NodeReviewDelete is the builder for deleting a NodeReview entity. +type NodeReviewDelete struct { + config + hooks []Hook + mutation *NodeReviewMutation +} + +// Where appends a list predicates to the NodeReviewDelete builder. +func (nrd *NodeReviewDelete) Where(ps ...predicate.NodeReview) *NodeReviewDelete { + nrd.mutation.Where(ps...) + return nrd +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (nrd *NodeReviewDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, nrd.sqlExec, nrd.mutation, nrd.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (nrd *NodeReviewDelete) ExecX(ctx context.Context) int { + n, err := nrd.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (nrd *NodeReviewDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(nodereview.Table, sqlgraph.NewFieldSpec(nodereview.FieldID, field.TypeUUID)) + if ps := nrd.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, nrd.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + nrd.mutation.done = true + return affected, err +} + +// NodeReviewDeleteOne is the builder for deleting a single NodeReview entity. +type NodeReviewDeleteOne struct { + nrd *NodeReviewDelete +} + +// Where appends a list predicates to the NodeReviewDelete builder. +func (nrdo *NodeReviewDeleteOne) Where(ps ...predicate.NodeReview) *NodeReviewDeleteOne { + nrdo.nrd.mutation.Where(ps...) + return nrdo +} + +// Exec executes the deletion query. +func (nrdo *NodeReviewDeleteOne) Exec(ctx context.Context) error { + n, err := nrdo.nrd.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{nodereview.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (nrdo *NodeReviewDeleteOne) ExecX(ctx context.Context) { + if err := nrdo.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/ent/nodereview_query.go b/ent/nodereview_query.go new file mode 100644 index 0000000..9515939 --- /dev/null +++ b/ent/nodereview_query.go @@ -0,0 +1,730 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "fmt" + "math" + "registry-backend/ent/node" + "registry-backend/ent/nodereview" + "registry-backend/ent/predicate" + "registry-backend/ent/user" + + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/google/uuid" +) + +// NodeReviewQuery is the builder for querying NodeReview entities. +type NodeReviewQuery struct { + config + ctx *QueryContext + order []nodereview.OrderOption + inters []Interceptor + predicates []predicate.NodeReview + withUser *UserQuery + withNode *NodeQuery + modifiers []func(*sql.Selector) + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the NodeReviewQuery builder. +func (nrq *NodeReviewQuery) Where(ps ...predicate.NodeReview) *NodeReviewQuery { + nrq.predicates = append(nrq.predicates, ps...) + return nrq +} + +// Limit the number of records to be returned by this query. +func (nrq *NodeReviewQuery) Limit(limit int) *NodeReviewQuery { + nrq.ctx.Limit = &limit + return nrq +} + +// Offset to start from. +func (nrq *NodeReviewQuery) Offset(offset int) *NodeReviewQuery { + nrq.ctx.Offset = &offset + return nrq +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (nrq *NodeReviewQuery) Unique(unique bool) *NodeReviewQuery { + nrq.ctx.Unique = &unique + return nrq +} + +// Order specifies how the records should be ordered. +func (nrq *NodeReviewQuery) Order(o ...nodereview.OrderOption) *NodeReviewQuery { + nrq.order = append(nrq.order, o...) + return nrq +} + +// QueryUser chains the current query on the "user" edge. +func (nrq *NodeReviewQuery) QueryUser() *UserQuery { + query := (&UserClient{config: nrq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := nrq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := nrq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(nodereview.Table, nodereview.FieldID, selector), + sqlgraph.To(user.Table, user.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, nodereview.UserTable, nodereview.UserColumn), + ) + fromU = sqlgraph.SetNeighbors(nrq.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// QueryNode chains the current query on the "node" edge. +func (nrq *NodeReviewQuery) QueryNode() *NodeQuery { + query := (&NodeClient{config: nrq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := nrq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := nrq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(nodereview.Table, nodereview.FieldID, selector), + sqlgraph.To(node.Table, node.FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, nodereview.NodeTable, nodereview.NodeColumn), + ) + fromU = sqlgraph.SetNeighbors(nrq.driver.Dialect(), step) + return fromU, nil + } + return query +} + +// First returns the first NodeReview entity from the query. +// Returns a *NotFoundError when no NodeReview was found. +func (nrq *NodeReviewQuery) First(ctx context.Context) (*NodeReview, error) { + nodes, err := nrq.Limit(1).All(setContextOp(ctx, nrq.ctx, "First")) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{nodereview.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (nrq *NodeReviewQuery) FirstX(ctx context.Context) *NodeReview { + node, err := nrq.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first NodeReview ID from the query. +// Returns a *NotFoundError when no NodeReview ID was found. +func (nrq *NodeReviewQuery) FirstID(ctx context.Context) (id uuid.UUID, err error) { + var ids []uuid.UUID + if ids, err = nrq.Limit(1).IDs(setContextOp(ctx, nrq.ctx, "FirstID")); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{nodereview.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (nrq *NodeReviewQuery) FirstIDX(ctx context.Context) uuid.UUID { + id, err := nrq.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single NodeReview entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one NodeReview entity is found. +// Returns a *NotFoundError when no NodeReview entities are found. +func (nrq *NodeReviewQuery) Only(ctx context.Context) (*NodeReview, error) { + nodes, err := nrq.Limit(2).All(setContextOp(ctx, nrq.ctx, "Only")) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{nodereview.Label} + default: + return nil, &NotSingularError{nodereview.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (nrq *NodeReviewQuery) OnlyX(ctx context.Context) *NodeReview { + node, err := nrq.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only NodeReview ID in the query. +// Returns a *NotSingularError when more than one NodeReview ID is found. +// Returns a *NotFoundError when no entities are found. +func (nrq *NodeReviewQuery) OnlyID(ctx context.Context) (id uuid.UUID, err error) { + var ids []uuid.UUID + if ids, err = nrq.Limit(2).IDs(setContextOp(ctx, nrq.ctx, "OnlyID")); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{nodereview.Label} + default: + err = &NotSingularError{nodereview.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (nrq *NodeReviewQuery) OnlyIDX(ctx context.Context) uuid.UUID { + id, err := nrq.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of NodeReviews. +func (nrq *NodeReviewQuery) All(ctx context.Context) ([]*NodeReview, error) { + ctx = setContextOp(ctx, nrq.ctx, "All") + if err := nrq.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*NodeReview, *NodeReviewQuery]() + return withInterceptors[[]*NodeReview](ctx, nrq, qr, nrq.inters) +} + +// AllX is like All, but panics if an error occurs. +func (nrq *NodeReviewQuery) AllX(ctx context.Context) []*NodeReview { + nodes, err := nrq.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of NodeReview IDs. +func (nrq *NodeReviewQuery) IDs(ctx context.Context) (ids []uuid.UUID, err error) { + if nrq.ctx.Unique == nil && nrq.path != nil { + nrq.Unique(true) + } + ctx = setContextOp(ctx, nrq.ctx, "IDs") + if err = nrq.Select(nodereview.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (nrq *NodeReviewQuery) IDsX(ctx context.Context) []uuid.UUID { + ids, err := nrq.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (nrq *NodeReviewQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, nrq.ctx, "Count") + if err := nrq.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, nrq, querierCount[*NodeReviewQuery](), nrq.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (nrq *NodeReviewQuery) CountX(ctx context.Context) int { + count, err := nrq.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (nrq *NodeReviewQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, nrq.ctx, "Exist") + switch _, err := nrq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (nrq *NodeReviewQuery) ExistX(ctx context.Context) bool { + exist, err := nrq.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the NodeReviewQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (nrq *NodeReviewQuery) Clone() *NodeReviewQuery { + if nrq == nil { + return nil + } + return &NodeReviewQuery{ + config: nrq.config, + ctx: nrq.ctx.Clone(), + order: append([]nodereview.OrderOption{}, nrq.order...), + inters: append([]Interceptor{}, nrq.inters...), + predicates: append([]predicate.NodeReview{}, nrq.predicates...), + withUser: nrq.withUser.Clone(), + withNode: nrq.withNode.Clone(), + // clone intermediate query. + sql: nrq.sql.Clone(), + path: nrq.path, + } +} + +// WithUser tells the query-builder to eager-load the nodes that are connected to +// the "user" edge. The optional arguments are used to configure the query builder of the edge. +func (nrq *NodeReviewQuery) WithUser(opts ...func(*UserQuery)) *NodeReviewQuery { + query := (&UserClient{config: nrq.config}).Query() + for _, opt := range opts { + opt(query) + } + nrq.withUser = query + return nrq +} + +// WithNode tells the query-builder to eager-load the nodes that are connected to +// the "node" edge. The optional arguments are used to configure the query builder of the edge. +func (nrq *NodeReviewQuery) WithNode(opts ...func(*NodeQuery)) *NodeReviewQuery { + query := (&NodeClient{config: nrq.config}).Query() + for _, opt := range opts { + opt(query) + } + nrq.withNode = query + return nrq +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// NodeID string `json:"node_id,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.NodeReview.Query(). +// GroupBy(nodereview.FieldNodeID). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (nrq *NodeReviewQuery) GroupBy(field string, fields ...string) *NodeReviewGroupBy { + nrq.ctx.Fields = append([]string{field}, fields...) + grbuild := &NodeReviewGroupBy{build: nrq} + grbuild.flds = &nrq.ctx.Fields + grbuild.label = nodereview.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// NodeID string `json:"node_id,omitempty"` +// } +// +// client.NodeReview.Query(). +// Select(nodereview.FieldNodeID). +// Scan(ctx, &v) +func (nrq *NodeReviewQuery) Select(fields ...string) *NodeReviewSelect { + nrq.ctx.Fields = append(nrq.ctx.Fields, fields...) + sbuild := &NodeReviewSelect{NodeReviewQuery: nrq} + sbuild.label = nodereview.Label + sbuild.flds, sbuild.scan = &nrq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a NodeReviewSelect configured with the given aggregations. +func (nrq *NodeReviewQuery) Aggregate(fns ...AggregateFunc) *NodeReviewSelect { + return nrq.Select().Aggregate(fns...) +} + +func (nrq *NodeReviewQuery) prepareQuery(ctx context.Context) error { + for _, inter := range nrq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, nrq); err != nil { + return err + } + } + } + for _, f := range nrq.ctx.Fields { + if !nodereview.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if nrq.path != nil { + prev, err := nrq.path(ctx) + if err != nil { + return err + } + nrq.sql = prev + } + return nil +} + +func (nrq *NodeReviewQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*NodeReview, error) { + var ( + nodes = []*NodeReview{} + _spec = nrq.querySpec() + loadedTypes = [2]bool{ + nrq.withUser != nil, + nrq.withNode != nil, + } + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*NodeReview).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &NodeReview{config: nrq.config} + nodes = append(nodes, node) + node.Edges.loadedTypes = loadedTypes + return node.assignValues(columns, values) + } + if len(nrq.modifiers) > 0 { + _spec.Modifiers = nrq.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, nrq.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + if query := nrq.withUser; query != nil { + if err := nrq.loadUser(ctx, query, nodes, nil, + func(n *NodeReview, e *User) { n.Edges.User = e }); err != nil { + return nil, err + } + } + if query := nrq.withNode; query != nil { + if err := nrq.loadNode(ctx, query, nodes, nil, + func(n *NodeReview, e *Node) { n.Edges.Node = e }); err != nil { + return nil, err + } + } + return nodes, nil +} + +func (nrq *NodeReviewQuery) loadUser(ctx context.Context, query *UserQuery, nodes []*NodeReview, init func(*NodeReview), assign func(*NodeReview, *User)) error { + ids := make([]string, 0, len(nodes)) + nodeids := make(map[string][]*NodeReview) + for i := range nodes { + fk := nodes[i].UserID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(user.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "user_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} +func (nrq *NodeReviewQuery) loadNode(ctx context.Context, query *NodeQuery, nodes []*NodeReview, init func(*NodeReview), assign func(*NodeReview, *Node)) error { + ids := make([]string, 0, len(nodes)) + nodeids := make(map[string][]*NodeReview) + for i := range nodes { + fk := nodes[i].NodeID + if _, ok := nodeids[fk]; !ok { + ids = append(ids, fk) + } + nodeids[fk] = append(nodeids[fk], nodes[i]) + } + if len(ids) == 0 { + return nil + } + query.Where(node.IDIn(ids...)) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + nodes, ok := nodeids[n.ID] + if !ok { + return fmt.Errorf(`unexpected foreign-key "node_id" returned %v`, n.ID) + } + for i := range nodes { + assign(nodes[i], n) + } + } + return nil +} + +func (nrq *NodeReviewQuery) sqlCount(ctx context.Context) (int, error) { + _spec := nrq.querySpec() + if len(nrq.modifiers) > 0 { + _spec.Modifiers = nrq.modifiers + } + _spec.Node.Columns = nrq.ctx.Fields + if len(nrq.ctx.Fields) > 0 { + _spec.Unique = nrq.ctx.Unique != nil && *nrq.ctx.Unique + } + return sqlgraph.CountNodes(ctx, nrq.driver, _spec) +} + +func (nrq *NodeReviewQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(nodereview.Table, nodereview.Columns, sqlgraph.NewFieldSpec(nodereview.FieldID, field.TypeUUID)) + _spec.From = nrq.sql + if unique := nrq.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if nrq.path != nil { + _spec.Unique = true + } + if fields := nrq.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, nodereview.FieldID) + for i := range fields { + if fields[i] != nodereview.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + if nrq.withUser != nil { + _spec.Node.AddColumnOnce(nodereview.FieldUserID) + } + if nrq.withNode != nil { + _spec.Node.AddColumnOnce(nodereview.FieldNodeID) + } + } + if ps := nrq.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := nrq.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := nrq.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := nrq.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (nrq *NodeReviewQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(nrq.driver.Dialect()) + t1 := builder.Table(nodereview.Table) + columns := nrq.ctx.Fields + if len(columns) == 0 { + columns = nodereview.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if nrq.sql != nil { + selector = nrq.sql + selector.Select(selector.Columns(columns...)...) + } + if nrq.ctx.Unique != nil && *nrq.ctx.Unique { + selector.Distinct() + } + for _, m := range nrq.modifiers { + m(selector) + } + for _, p := range nrq.predicates { + p(selector) + } + for _, p := range nrq.order { + p(selector) + } + if offset := nrq.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := nrq.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// ForUpdate locks the selected rows against concurrent updates, and prevent them from being +// updated, deleted or "selected ... for update" by other sessions, until the transaction is +// either committed or rolled-back. +func (nrq *NodeReviewQuery) ForUpdate(opts ...sql.LockOption) *NodeReviewQuery { + if nrq.driver.Dialect() == dialect.Postgres { + nrq.Unique(false) + } + nrq.modifiers = append(nrq.modifiers, func(s *sql.Selector) { + s.ForUpdate(opts...) + }) + return nrq +} + +// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock +// on any rows that are read. Other sessions can read the rows, but cannot modify them +// until your transaction commits. +func (nrq *NodeReviewQuery) ForShare(opts ...sql.LockOption) *NodeReviewQuery { + if nrq.driver.Dialect() == dialect.Postgres { + nrq.Unique(false) + } + nrq.modifiers = append(nrq.modifiers, func(s *sql.Selector) { + s.ForShare(opts...) + }) + return nrq +} + +// Modify adds a query modifier for attaching custom logic to queries. +func (nrq *NodeReviewQuery) Modify(modifiers ...func(s *sql.Selector)) *NodeReviewSelect { + nrq.modifiers = append(nrq.modifiers, modifiers...) + return nrq.Select() +} + +// NodeReviewGroupBy is the group-by builder for NodeReview entities. +type NodeReviewGroupBy struct { + selector + build *NodeReviewQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (nrgb *NodeReviewGroupBy) Aggregate(fns ...AggregateFunc) *NodeReviewGroupBy { + nrgb.fns = append(nrgb.fns, fns...) + return nrgb +} + +// Scan applies the selector query and scans the result into the given value. +func (nrgb *NodeReviewGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, nrgb.build.ctx, "GroupBy") + if err := nrgb.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*NodeReviewQuery, *NodeReviewGroupBy](ctx, nrgb.build, nrgb, nrgb.build.inters, v) +} + +func (nrgb *NodeReviewGroupBy) sqlScan(ctx context.Context, root *NodeReviewQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(nrgb.fns)) + for _, fn := range nrgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*nrgb.flds)+len(nrgb.fns)) + for _, f := range *nrgb.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*nrgb.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := nrgb.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// NodeReviewSelect is the builder for selecting fields of NodeReview entities. +type NodeReviewSelect struct { + *NodeReviewQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (nrs *NodeReviewSelect) Aggregate(fns ...AggregateFunc) *NodeReviewSelect { + nrs.fns = append(nrs.fns, fns...) + return nrs +} + +// Scan applies the selector query and scans the result into the given value. +func (nrs *NodeReviewSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, nrs.ctx, "Select") + if err := nrs.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*NodeReviewQuery, *NodeReviewSelect](ctx, nrs.NodeReviewQuery, nrs, nrs.inters, v) +} + +func (nrs *NodeReviewSelect) sqlScan(ctx context.Context, root *NodeReviewQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(nrs.fns)) + for _, fn := range nrs.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*nrs.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := nrs.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// Modify adds a query modifier for attaching custom logic to queries. +func (nrs *NodeReviewSelect) Modify(modifiers ...func(s *sql.Selector)) *NodeReviewSelect { + nrs.modifiers = append(nrs.modifiers, modifiers...) + return nrs +} diff --git a/ent/nodereview_update.go b/ent/nodereview_update.go new file mode 100644 index 0000000..352461c --- /dev/null +++ b/ent/nodereview_update.go @@ -0,0 +1,491 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "registry-backend/ent/node" + "registry-backend/ent/nodereview" + "registry-backend/ent/predicate" + "registry-backend/ent/user" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" +) + +// NodeReviewUpdate is the builder for updating NodeReview entities. +type NodeReviewUpdate struct { + config + hooks []Hook + mutation *NodeReviewMutation + modifiers []func(*sql.UpdateBuilder) +} + +// Where appends a list predicates to the NodeReviewUpdate builder. +func (nru *NodeReviewUpdate) Where(ps ...predicate.NodeReview) *NodeReviewUpdate { + nru.mutation.Where(ps...) + return nru +} + +// SetNodeID sets the "node_id" field. +func (nru *NodeReviewUpdate) SetNodeID(s string) *NodeReviewUpdate { + nru.mutation.SetNodeID(s) + return nru +} + +// SetNillableNodeID sets the "node_id" field if the given value is not nil. +func (nru *NodeReviewUpdate) SetNillableNodeID(s *string) *NodeReviewUpdate { + if s != nil { + nru.SetNodeID(*s) + } + return nru +} + +// SetUserID sets the "user_id" field. +func (nru *NodeReviewUpdate) SetUserID(s string) *NodeReviewUpdate { + nru.mutation.SetUserID(s) + return nru +} + +// SetNillableUserID sets the "user_id" field if the given value is not nil. +func (nru *NodeReviewUpdate) SetNillableUserID(s *string) *NodeReviewUpdate { + if s != nil { + nru.SetUserID(*s) + } + return nru +} + +// SetStar sets the "star" field. +func (nru *NodeReviewUpdate) SetStar(i int) *NodeReviewUpdate { + nru.mutation.ResetStar() + nru.mutation.SetStar(i) + return nru +} + +// SetNillableStar sets the "star" field if the given value is not nil. +func (nru *NodeReviewUpdate) SetNillableStar(i *int) *NodeReviewUpdate { + if i != nil { + nru.SetStar(*i) + } + return nru +} + +// AddStar adds i to the "star" field. +func (nru *NodeReviewUpdate) AddStar(i int) *NodeReviewUpdate { + nru.mutation.AddStar(i) + return nru +} + +// SetUser sets the "user" edge to the User entity. +func (nru *NodeReviewUpdate) SetUser(u *User) *NodeReviewUpdate { + return nru.SetUserID(u.ID) +} + +// SetNode sets the "node" edge to the Node entity. +func (nru *NodeReviewUpdate) SetNode(n *Node) *NodeReviewUpdate { + return nru.SetNodeID(n.ID) +} + +// Mutation returns the NodeReviewMutation object of the builder. +func (nru *NodeReviewUpdate) Mutation() *NodeReviewMutation { + return nru.mutation +} + +// ClearUser clears the "user" edge to the User entity. +func (nru *NodeReviewUpdate) ClearUser() *NodeReviewUpdate { + nru.mutation.ClearUser() + return nru +} + +// ClearNode clears the "node" edge to the Node entity. +func (nru *NodeReviewUpdate) ClearNode() *NodeReviewUpdate { + nru.mutation.ClearNode() + return nru +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (nru *NodeReviewUpdate) Save(ctx context.Context) (int, error) { + return withHooks(ctx, nru.sqlSave, nru.mutation, nru.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (nru *NodeReviewUpdate) SaveX(ctx context.Context) int { + affected, err := nru.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (nru *NodeReviewUpdate) Exec(ctx context.Context) error { + _, err := nru.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (nru *NodeReviewUpdate) ExecX(ctx context.Context) { + if err := nru.Exec(ctx); err != nil { + panic(err) + } +} + +// check runs all checks and user-defined validators on the builder. +func (nru *NodeReviewUpdate) check() error { + if _, ok := nru.mutation.UserID(); nru.mutation.UserCleared() && !ok { + return errors.New(`ent: clearing a required unique edge "NodeReview.user"`) + } + if _, ok := nru.mutation.NodeID(); nru.mutation.NodeCleared() && !ok { + return errors.New(`ent: clearing a required unique edge "NodeReview.node"`) + } + return nil +} + +// Modify adds a statement modifier for attaching custom logic to the UPDATE statement. +func (nru *NodeReviewUpdate) Modify(modifiers ...func(u *sql.UpdateBuilder)) *NodeReviewUpdate { + nru.modifiers = append(nru.modifiers, modifiers...) + return nru +} + +func (nru *NodeReviewUpdate) sqlSave(ctx context.Context) (n int, err error) { + if err := nru.check(); err != nil { + return n, err + } + _spec := sqlgraph.NewUpdateSpec(nodereview.Table, nodereview.Columns, sqlgraph.NewFieldSpec(nodereview.FieldID, field.TypeUUID)) + if ps := nru.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := nru.mutation.Star(); ok { + _spec.SetField(nodereview.FieldStar, field.TypeInt, value) + } + if value, ok := nru.mutation.AddedStar(); ok { + _spec.AddField(nodereview.FieldStar, field.TypeInt, value) + } + if nru.mutation.UserCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: nodereview.UserTable, + Columns: []string{nodereview.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeString), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := nru.mutation.UserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: nodereview.UserTable, + Columns: []string{nodereview.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeString), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if nru.mutation.NodeCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: nodereview.NodeTable, + Columns: []string{nodereview.NodeColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(node.FieldID, field.TypeString), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := nru.mutation.NodeIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: nodereview.NodeTable, + Columns: []string{nodereview.NodeColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(node.FieldID, field.TypeString), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _spec.AddModifiers(nru.modifiers...) + if n, err = sqlgraph.UpdateNodes(ctx, nru.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{nodereview.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + nru.mutation.done = true + return n, nil +} + +// NodeReviewUpdateOne is the builder for updating a single NodeReview entity. +type NodeReviewUpdateOne struct { + config + fields []string + hooks []Hook + mutation *NodeReviewMutation + modifiers []func(*sql.UpdateBuilder) +} + +// SetNodeID sets the "node_id" field. +func (nruo *NodeReviewUpdateOne) SetNodeID(s string) *NodeReviewUpdateOne { + nruo.mutation.SetNodeID(s) + return nruo +} + +// SetNillableNodeID sets the "node_id" field if the given value is not nil. +func (nruo *NodeReviewUpdateOne) SetNillableNodeID(s *string) *NodeReviewUpdateOne { + if s != nil { + nruo.SetNodeID(*s) + } + return nruo +} + +// SetUserID sets the "user_id" field. +func (nruo *NodeReviewUpdateOne) SetUserID(s string) *NodeReviewUpdateOne { + nruo.mutation.SetUserID(s) + return nruo +} + +// SetNillableUserID sets the "user_id" field if the given value is not nil. +func (nruo *NodeReviewUpdateOne) SetNillableUserID(s *string) *NodeReviewUpdateOne { + if s != nil { + nruo.SetUserID(*s) + } + return nruo +} + +// SetStar sets the "star" field. +func (nruo *NodeReviewUpdateOne) SetStar(i int) *NodeReviewUpdateOne { + nruo.mutation.ResetStar() + nruo.mutation.SetStar(i) + return nruo +} + +// SetNillableStar sets the "star" field if the given value is not nil. +func (nruo *NodeReviewUpdateOne) SetNillableStar(i *int) *NodeReviewUpdateOne { + if i != nil { + nruo.SetStar(*i) + } + return nruo +} + +// AddStar adds i to the "star" field. +func (nruo *NodeReviewUpdateOne) AddStar(i int) *NodeReviewUpdateOne { + nruo.mutation.AddStar(i) + return nruo +} + +// SetUser sets the "user" edge to the User entity. +func (nruo *NodeReviewUpdateOne) SetUser(u *User) *NodeReviewUpdateOne { + return nruo.SetUserID(u.ID) +} + +// SetNode sets the "node" edge to the Node entity. +func (nruo *NodeReviewUpdateOne) SetNode(n *Node) *NodeReviewUpdateOne { + return nruo.SetNodeID(n.ID) +} + +// Mutation returns the NodeReviewMutation object of the builder. +func (nruo *NodeReviewUpdateOne) Mutation() *NodeReviewMutation { + return nruo.mutation +} + +// ClearUser clears the "user" edge to the User entity. +func (nruo *NodeReviewUpdateOne) ClearUser() *NodeReviewUpdateOne { + nruo.mutation.ClearUser() + return nruo +} + +// ClearNode clears the "node" edge to the Node entity. +func (nruo *NodeReviewUpdateOne) ClearNode() *NodeReviewUpdateOne { + nruo.mutation.ClearNode() + return nruo +} + +// Where appends a list predicates to the NodeReviewUpdate builder. +func (nruo *NodeReviewUpdateOne) Where(ps ...predicate.NodeReview) *NodeReviewUpdateOne { + nruo.mutation.Where(ps...) + return nruo +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (nruo *NodeReviewUpdateOne) Select(field string, fields ...string) *NodeReviewUpdateOne { + nruo.fields = append([]string{field}, fields...) + return nruo +} + +// Save executes the query and returns the updated NodeReview entity. +func (nruo *NodeReviewUpdateOne) Save(ctx context.Context) (*NodeReview, error) { + return withHooks(ctx, nruo.sqlSave, nruo.mutation, nruo.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (nruo *NodeReviewUpdateOne) SaveX(ctx context.Context) *NodeReview { + node, err := nruo.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (nruo *NodeReviewUpdateOne) Exec(ctx context.Context) error { + _, err := nruo.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (nruo *NodeReviewUpdateOne) ExecX(ctx context.Context) { + if err := nruo.Exec(ctx); err != nil { + panic(err) + } +} + +// check runs all checks and user-defined validators on the builder. +func (nruo *NodeReviewUpdateOne) check() error { + if _, ok := nruo.mutation.UserID(); nruo.mutation.UserCleared() && !ok { + return errors.New(`ent: clearing a required unique edge "NodeReview.user"`) + } + if _, ok := nruo.mutation.NodeID(); nruo.mutation.NodeCleared() && !ok { + return errors.New(`ent: clearing a required unique edge "NodeReview.node"`) + } + return nil +} + +// Modify adds a statement modifier for attaching custom logic to the UPDATE statement. +func (nruo *NodeReviewUpdateOne) Modify(modifiers ...func(u *sql.UpdateBuilder)) *NodeReviewUpdateOne { + nruo.modifiers = append(nruo.modifiers, modifiers...) + return nruo +} + +func (nruo *NodeReviewUpdateOne) sqlSave(ctx context.Context) (_node *NodeReview, err error) { + if err := nruo.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(nodereview.Table, nodereview.Columns, sqlgraph.NewFieldSpec(nodereview.FieldID, field.TypeUUID)) + id, ok := nruo.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "NodeReview.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := nruo.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, nodereview.FieldID) + for _, f := range fields { + if !nodereview.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != nodereview.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := nruo.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := nruo.mutation.Star(); ok { + _spec.SetField(nodereview.FieldStar, field.TypeInt, value) + } + if value, ok := nruo.mutation.AddedStar(); ok { + _spec.AddField(nodereview.FieldStar, field.TypeInt, value) + } + if nruo.mutation.UserCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: nodereview.UserTable, + Columns: []string{nodereview.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeString), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := nruo.mutation.UserIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: nodereview.UserTable, + Columns: []string{nodereview.UserColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(user.FieldID, field.TypeString), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + if nruo.mutation.NodeCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: nodereview.NodeTable, + Columns: []string{nodereview.NodeColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(node.FieldID, field.TypeString), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := nruo.mutation.NodeIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.M2O, + Inverse: true, + Table: nodereview.NodeTable, + Columns: []string{nodereview.NodeColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(node.FieldID, field.TypeString), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _spec.AddModifiers(nruo.modifiers...) + _node = &NodeReview{config: nruo.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, nruo.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{nodereview.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + nruo.mutation.done = true + return _node, nil +} diff --git a/ent/nodeversion.go b/ent/nodeversion.go index 0222910..63d13cc 100644 --- a/ent/nodeversion.go +++ b/ent/nodeversion.go @@ -7,6 +7,7 @@ import ( "fmt" "registry-backend/ent/node" "registry-backend/ent/nodeversion" + "registry-backend/ent/schema" "registry-backend/ent/storagefile" "strings" "time" @@ -35,6 +36,10 @@ type NodeVersion struct { PipDependencies []string `json:"pip_dependencies,omitempty"` // Deprecated holds the value of the "deprecated" field. Deprecated bool `json:"deprecated,omitempty"` + // Status holds the value of the "status" field. + Status schema.NodeVersionStatus `json:"status,omitempty"` + // Give a reason for the status change. Eg. 'Banned due to security vulnerability' + StatusReason string `json:"status_reason,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the NodeVersionQuery when eager-loading is set. Edges NodeVersionEdges `json:"edges"` @@ -84,7 +89,7 @@ func (*NodeVersion) scanValues(columns []string) ([]any, error) { values[i] = new([]byte) case nodeversion.FieldDeprecated: values[i] = new(sql.NullBool) - case nodeversion.FieldNodeID, nodeversion.FieldVersion, nodeversion.FieldChangelog: + case nodeversion.FieldNodeID, nodeversion.FieldVersion, nodeversion.FieldChangelog, nodeversion.FieldStatus, nodeversion.FieldStatusReason: values[i] = new(sql.NullString) case nodeversion.FieldCreateTime, nodeversion.FieldUpdateTime: values[i] = new(sql.NullTime) @@ -157,6 +162,18 @@ func (nv *NodeVersion) assignValues(columns []string, values []any) error { } else if value.Valid { nv.Deprecated = value.Bool } + case nodeversion.FieldStatus: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field status", values[i]) + } else if value.Valid { + nv.Status = schema.NodeVersionStatus(value.String) + } + case nodeversion.FieldStatusReason: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field status_reason", values[i]) + } else if value.Valid { + nv.StatusReason = value.String + } case nodeversion.ForeignKeys[0]: if value, ok := values[i].(*sql.NullScanner); !ok { return fmt.Errorf("unexpected type %T for field node_version_storage_file", values[i]) @@ -230,6 +247,12 @@ func (nv *NodeVersion) String() string { builder.WriteString(", ") builder.WriteString("deprecated=") builder.WriteString(fmt.Sprintf("%v", nv.Deprecated)) + builder.WriteString(", ") + builder.WriteString("status=") + builder.WriteString(fmt.Sprintf("%v", nv.Status)) + builder.WriteString(", ") + builder.WriteString("status_reason=") + builder.WriteString(nv.StatusReason) builder.WriteByte(')') return builder.String() } diff --git a/ent/nodeversion/nodeversion.go b/ent/nodeversion/nodeversion.go index 3507958..bb5b59d 100644 --- a/ent/nodeversion/nodeversion.go +++ b/ent/nodeversion/nodeversion.go @@ -3,6 +3,8 @@ package nodeversion import ( + "fmt" + "registry-backend/ent/schema" "time" "entgo.io/ent/dialect/sql" @@ -29,6 +31,10 @@ const ( FieldPipDependencies = "pip_dependencies" // FieldDeprecated holds the string denoting the deprecated field in the database. FieldDeprecated = "deprecated" + // FieldStatus holds the string denoting the status field in the database. + FieldStatus = "status" + // FieldStatusReason holds the string denoting the status_reason field in the database. + FieldStatusReason = "status_reason" // EdgeNode holds the string denoting the node edge name in mutations. EdgeNode = "node" // EdgeStorageFile holds the string denoting the storage_file edge name in mutations. @@ -61,6 +67,8 @@ var Columns = []string{ FieldChangelog, FieldPipDependencies, FieldDeprecated, + FieldStatus, + FieldStatusReason, } // ForeignKeys holds the SQL foreign-keys that are owned by the "node_versions" @@ -93,10 +101,24 @@ var ( UpdateDefaultUpdateTime func() time.Time // DefaultDeprecated holds the default value on creation for the "deprecated" field. DefaultDeprecated bool + // DefaultStatusReason holds the default value on creation for the "status_reason" field. + DefaultStatusReason string // DefaultID holds the default value on creation for the "id" field. DefaultID func() uuid.UUID ) +const DefaultStatus schema.NodeVersionStatus = "pending" + +// StatusValidator is a validator for the "status" field enum values. It is called by the builders before save. +func StatusValidator(s schema.NodeVersionStatus) error { + switch s { + case "active", "banned", "deleted", "pending", "flagged": + return nil + default: + return fmt.Errorf("nodeversion: invalid enum value for status field: %q", s) + } +} + // OrderOption defines the ordering options for the NodeVersion queries. type OrderOption func(*sql.Selector) @@ -135,6 +157,16 @@ func ByDeprecated(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldDeprecated, opts...).ToFunc() } +// ByStatus orders the results by the status field. +func ByStatus(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldStatus, opts...).ToFunc() +} + +// ByStatusReason orders the results by the status_reason field. +func ByStatusReason(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldStatusReason, opts...).ToFunc() +} + // ByNodeField orders the results by node field. func ByNodeField(field string, opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { diff --git a/ent/nodeversion/where.go b/ent/nodeversion/where.go index 683c2eb..65a2871 100644 --- a/ent/nodeversion/where.go +++ b/ent/nodeversion/where.go @@ -4,6 +4,7 @@ package nodeversion import ( "registry-backend/ent/predicate" + "registry-backend/ent/schema" "time" "entgo.io/ent/dialect/sql" @@ -86,6 +87,11 @@ func Deprecated(v bool) predicate.NodeVersion { return predicate.NodeVersion(sql.FieldEQ(FieldDeprecated, v)) } +// StatusReason applies equality check predicate on the "status_reason" field. It's identical to StatusReasonEQ. +func StatusReason(v string) predicate.NodeVersion { + return predicate.NodeVersion(sql.FieldEQ(FieldStatusReason, v)) +} + // CreateTimeEQ applies the EQ predicate on the "create_time" field. func CreateTimeEQ(v time.Time) predicate.NodeVersion { return predicate.NodeVersion(sql.FieldEQ(FieldCreateTime, v)) @@ -381,6 +387,101 @@ func DeprecatedNEQ(v bool) predicate.NodeVersion { return predicate.NodeVersion(sql.FieldNEQ(FieldDeprecated, v)) } +// StatusEQ applies the EQ predicate on the "status" field. +func StatusEQ(v schema.NodeVersionStatus) predicate.NodeVersion { + vc := v + return predicate.NodeVersion(sql.FieldEQ(FieldStatus, vc)) +} + +// StatusNEQ applies the NEQ predicate on the "status" field. +func StatusNEQ(v schema.NodeVersionStatus) predicate.NodeVersion { + vc := v + return predicate.NodeVersion(sql.FieldNEQ(FieldStatus, vc)) +} + +// StatusIn applies the In predicate on the "status" field. +func StatusIn(vs ...schema.NodeVersionStatus) predicate.NodeVersion { + v := make([]any, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.NodeVersion(sql.FieldIn(FieldStatus, v...)) +} + +// StatusNotIn applies the NotIn predicate on the "status" field. +func StatusNotIn(vs ...schema.NodeVersionStatus) predicate.NodeVersion { + v := make([]any, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.NodeVersion(sql.FieldNotIn(FieldStatus, v...)) +} + +// StatusReasonEQ applies the EQ predicate on the "status_reason" field. +func StatusReasonEQ(v string) predicate.NodeVersion { + return predicate.NodeVersion(sql.FieldEQ(FieldStatusReason, v)) +} + +// StatusReasonNEQ applies the NEQ predicate on the "status_reason" field. +func StatusReasonNEQ(v string) predicate.NodeVersion { + return predicate.NodeVersion(sql.FieldNEQ(FieldStatusReason, v)) +} + +// StatusReasonIn applies the In predicate on the "status_reason" field. +func StatusReasonIn(vs ...string) predicate.NodeVersion { + return predicate.NodeVersion(sql.FieldIn(FieldStatusReason, vs...)) +} + +// StatusReasonNotIn applies the NotIn predicate on the "status_reason" field. +func StatusReasonNotIn(vs ...string) predicate.NodeVersion { + return predicate.NodeVersion(sql.FieldNotIn(FieldStatusReason, vs...)) +} + +// StatusReasonGT applies the GT predicate on the "status_reason" field. +func StatusReasonGT(v string) predicate.NodeVersion { + return predicate.NodeVersion(sql.FieldGT(FieldStatusReason, v)) +} + +// StatusReasonGTE applies the GTE predicate on the "status_reason" field. +func StatusReasonGTE(v string) predicate.NodeVersion { + return predicate.NodeVersion(sql.FieldGTE(FieldStatusReason, v)) +} + +// StatusReasonLT applies the LT predicate on the "status_reason" field. +func StatusReasonLT(v string) predicate.NodeVersion { + return predicate.NodeVersion(sql.FieldLT(FieldStatusReason, v)) +} + +// StatusReasonLTE applies the LTE predicate on the "status_reason" field. +func StatusReasonLTE(v string) predicate.NodeVersion { + return predicate.NodeVersion(sql.FieldLTE(FieldStatusReason, v)) +} + +// StatusReasonContains applies the Contains predicate on the "status_reason" field. +func StatusReasonContains(v string) predicate.NodeVersion { + return predicate.NodeVersion(sql.FieldContains(FieldStatusReason, v)) +} + +// StatusReasonHasPrefix applies the HasPrefix predicate on the "status_reason" field. +func StatusReasonHasPrefix(v string) predicate.NodeVersion { + return predicate.NodeVersion(sql.FieldHasPrefix(FieldStatusReason, v)) +} + +// StatusReasonHasSuffix applies the HasSuffix predicate on the "status_reason" field. +func StatusReasonHasSuffix(v string) predicate.NodeVersion { + return predicate.NodeVersion(sql.FieldHasSuffix(FieldStatusReason, v)) +} + +// StatusReasonEqualFold applies the EqualFold predicate on the "status_reason" field. +func StatusReasonEqualFold(v string) predicate.NodeVersion { + return predicate.NodeVersion(sql.FieldEqualFold(FieldStatusReason, v)) +} + +// StatusReasonContainsFold applies the ContainsFold predicate on the "status_reason" field. +func StatusReasonContainsFold(v string) predicate.NodeVersion { + return predicate.NodeVersion(sql.FieldContainsFold(FieldStatusReason, v)) +} + // HasNode applies the HasEdge predicate on the "node" edge. func HasNode() predicate.NodeVersion { return predicate.NodeVersion(func(s *sql.Selector) { diff --git a/ent/nodeversion_create.go b/ent/nodeversion_create.go index 611fe2e..0b7b17d 100644 --- a/ent/nodeversion_create.go +++ b/ent/nodeversion_create.go @@ -8,6 +8,7 @@ import ( "fmt" "registry-backend/ent/node" "registry-backend/ent/nodeversion" + "registry-backend/ent/schema" "registry-backend/ent/storagefile" "time" @@ -100,6 +101,34 @@ func (nvc *NodeVersionCreate) SetNillableDeprecated(b *bool) *NodeVersionCreate return nvc } +// SetStatus sets the "status" field. +func (nvc *NodeVersionCreate) SetStatus(svs schema.NodeVersionStatus) *NodeVersionCreate { + nvc.mutation.SetStatus(svs) + return nvc +} + +// SetNillableStatus sets the "status" field if the given value is not nil. +func (nvc *NodeVersionCreate) SetNillableStatus(svs *schema.NodeVersionStatus) *NodeVersionCreate { + if svs != nil { + nvc.SetStatus(*svs) + } + return nvc +} + +// SetStatusReason sets the "status_reason" field. +func (nvc *NodeVersionCreate) SetStatusReason(s string) *NodeVersionCreate { + nvc.mutation.SetStatusReason(s) + return nvc +} + +// SetNillableStatusReason sets the "status_reason" field if the given value is not nil. +func (nvc *NodeVersionCreate) SetNillableStatusReason(s *string) *NodeVersionCreate { + if s != nil { + nvc.SetStatusReason(*s) + } + return nvc +} + // SetID sets the "id" field. func (nvc *NodeVersionCreate) SetID(u uuid.UUID) *NodeVersionCreate { nvc.mutation.SetID(u) @@ -185,6 +214,14 @@ func (nvc *NodeVersionCreate) defaults() { v := nodeversion.DefaultDeprecated nvc.mutation.SetDeprecated(v) } + if _, ok := nvc.mutation.Status(); !ok { + v := nodeversion.DefaultStatus + nvc.mutation.SetStatus(v) + } + if _, ok := nvc.mutation.StatusReason(); !ok { + v := nodeversion.DefaultStatusReason + nvc.mutation.SetStatusReason(v) + } if _, ok := nvc.mutation.ID(); !ok { v := nodeversion.DefaultID() nvc.mutation.SetID(v) @@ -211,6 +248,17 @@ func (nvc *NodeVersionCreate) check() error { if _, ok := nvc.mutation.Deprecated(); !ok { return &ValidationError{Name: "deprecated", err: errors.New(`ent: missing required field "NodeVersion.deprecated"`)} } + if _, ok := nvc.mutation.Status(); !ok { + return &ValidationError{Name: "status", err: errors.New(`ent: missing required field "NodeVersion.status"`)} + } + if v, ok := nvc.mutation.Status(); ok { + if err := nodeversion.StatusValidator(v); err != nil { + return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "NodeVersion.status": %w`, err)} + } + } + if _, ok := nvc.mutation.StatusReason(); !ok { + return &ValidationError{Name: "status_reason", err: errors.New(`ent: missing required field "NodeVersion.status_reason"`)} + } if _, ok := nvc.mutation.NodeID(); !ok { return &ValidationError{Name: "node", err: errors.New(`ent: missing required edge "NodeVersion.node"`)} } @@ -274,6 +322,14 @@ func (nvc *NodeVersionCreate) createSpec() (*NodeVersion, *sqlgraph.CreateSpec) _spec.SetField(nodeversion.FieldDeprecated, field.TypeBool, value) _node.Deprecated = value } + if value, ok := nvc.mutation.Status(); ok { + _spec.SetField(nodeversion.FieldStatus, field.TypeEnum, value) + _node.Status = value + } + if value, ok := nvc.mutation.StatusReason(); ok { + _spec.SetField(nodeversion.FieldStatusReason, field.TypeString, value) + _node.StatusReason = value + } if nodes := nvc.mutation.NodeIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, @@ -438,6 +494,30 @@ func (u *NodeVersionUpsert) UpdateDeprecated() *NodeVersionUpsert { return u } +// SetStatus sets the "status" field. +func (u *NodeVersionUpsert) SetStatus(v schema.NodeVersionStatus) *NodeVersionUpsert { + u.Set(nodeversion.FieldStatus, v) + return u +} + +// UpdateStatus sets the "status" field to the value that was provided on create. +func (u *NodeVersionUpsert) UpdateStatus() *NodeVersionUpsert { + u.SetExcluded(nodeversion.FieldStatus) + return u +} + +// SetStatusReason sets the "status_reason" field. +func (u *NodeVersionUpsert) SetStatusReason(v string) *NodeVersionUpsert { + u.Set(nodeversion.FieldStatusReason, v) + return u +} + +// UpdateStatusReason sets the "status_reason" field to the value that was provided on create. +func (u *NodeVersionUpsert) UpdateStatusReason() *NodeVersionUpsert { + u.SetExcluded(nodeversion.FieldStatusReason) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create except the ID field. // Using this option is equivalent to using: // @@ -580,6 +660,34 @@ func (u *NodeVersionUpsertOne) UpdateDeprecated() *NodeVersionUpsertOne { }) } +// SetStatus sets the "status" field. +func (u *NodeVersionUpsertOne) SetStatus(v schema.NodeVersionStatus) *NodeVersionUpsertOne { + return u.Update(func(s *NodeVersionUpsert) { + s.SetStatus(v) + }) +} + +// UpdateStatus sets the "status" field to the value that was provided on create. +func (u *NodeVersionUpsertOne) UpdateStatus() *NodeVersionUpsertOne { + return u.Update(func(s *NodeVersionUpsert) { + s.UpdateStatus() + }) +} + +// SetStatusReason sets the "status_reason" field. +func (u *NodeVersionUpsertOne) SetStatusReason(v string) *NodeVersionUpsertOne { + return u.Update(func(s *NodeVersionUpsert) { + s.SetStatusReason(v) + }) +} + +// UpdateStatusReason sets the "status_reason" field to the value that was provided on create. +func (u *NodeVersionUpsertOne) UpdateStatusReason() *NodeVersionUpsertOne { + return u.Update(func(s *NodeVersionUpsert) { + s.UpdateStatusReason() + }) +} + // Exec executes the query. func (u *NodeVersionUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -889,6 +997,34 @@ func (u *NodeVersionUpsertBulk) UpdateDeprecated() *NodeVersionUpsertBulk { }) } +// SetStatus sets the "status" field. +func (u *NodeVersionUpsertBulk) SetStatus(v schema.NodeVersionStatus) *NodeVersionUpsertBulk { + return u.Update(func(s *NodeVersionUpsert) { + s.SetStatus(v) + }) +} + +// UpdateStatus sets the "status" field to the value that was provided on create. +func (u *NodeVersionUpsertBulk) UpdateStatus() *NodeVersionUpsertBulk { + return u.Update(func(s *NodeVersionUpsert) { + s.UpdateStatus() + }) +} + +// SetStatusReason sets the "status_reason" field. +func (u *NodeVersionUpsertBulk) SetStatusReason(v string) *NodeVersionUpsertBulk { + return u.Update(func(s *NodeVersionUpsert) { + s.SetStatusReason(v) + }) +} + +// UpdateStatusReason sets the "status_reason" field to the value that was provided on create. +func (u *NodeVersionUpsertBulk) UpdateStatusReason() *NodeVersionUpsertBulk { + return u.Update(func(s *NodeVersionUpsert) { + s.UpdateStatusReason() + }) +} + // Exec executes the query. func (u *NodeVersionUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/ent/nodeversion_query.go b/ent/nodeversion_query.go index 7f0cb2c..05074c1 100644 --- a/ent/nodeversion_query.go +++ b/ent/nodeversion_query.go @@ -635,6 +635,12 @@ func (nvq *NodeVersionQuery) ForShare(opts ...sql.LockOption) *NodeVersionQuery return nvq } +// Modify adds a query modifier for attaching custom logic to queries. +func (nvq *NodeVersionQuery) Modify(modifiers ...func(s *sql.Selector)) *NodeVersionSelect { + nvq.modifiers = append(nvq.modifiers, modifiers...) + return nvq.Select() +} + // NodeVersionGroupBy is the group-by builder for NodeVersion entities. type NodeVersionGroupBy struct { selector @@ -724,3 +730,9 @@ func (nvs *NodeVersionSelect) sqlScan(ctx context.Context, root *NodeVersionQuer defer rows.Close() return sql.ScanSlice(rows, v) } + +// Modify adds a query modifier for attaching custom logic to queries. +func (nvs *NodeVersionSelect) Modify(modifiers ...func(s *sql.Selector)) *NodeVersionSelect { + nvs.modifiers = append(nvs.modifiers, modifiers...) + return nvs +} diff --git a/ent/nodeversion_update.go b/ent/nodeversion_update.go index 65dd779..21fcba9 100644 --- a/ent/nodeversion_update.go +++ b/ent/nodeversion_update.go @@ -9,6 +9,7 @@ import ( "registry-backend/ent/node" "registry-backend/ent/nodeversion" "registry-backend/ent/predicate" + "registry-backend/ent/schema" "registry-backend/ent/storagefile" "time" @@ -22,8 +23,9 @@ import ( // NodeVersionUpdate is the builder for updating NodeVersion entities. type NodeVersionUpdate struct { config - hooks []Hook - mutation *NodeVersionMutation + hooks []Hook + mutation *NodeVersionMutation + modifiers []func(*sql.UpdateBuilder) } // Where appends a list predicates to the NodeVersionUpdate builder. @@ -112,6 +114,34 @@ func (nvu *NodeVersionUpdate) SetNillableDeprecated(b *bool) *NodeVersionUpdate return nvu } +// SetStatus sets the "status" field. +func (nvu *NodeVersionUpdate) SetStatus(svs schema.NodeVersionStatus) *NodeVersionUpdate { + nvu.mutation.SetStatus(svs) + return nvu +} + +// SetNillableStatus sets the "status" field if the given value is not nil. +func (nvu *NodeVersionUpdate) SetNillableStatus(svs *schema.NodeVersionStatus) *NodeVersionUpdate { + if svs != nil { + nvu.SetStatus(*svs) + } + return nvu +} + +// SetStatusReason sets the "status_reason" field. +func (nvu *NodeVersionUpdate) SetStatusReason(s string) *NodeVersionUpdate { + nvu.mutation.SetStatusReason(s) + return nvu +} + +// SetNillableStatusReason sets the "status_reason" field if the given value is not nil. +func (nvu *NodeVersionUpdate) SetNillableStatusReason(s *string) *NodeVersionUpdate { + if s != nil { + nvu.SetStatusReason(*s) + } + return nvu +} + // SetNode sets the "node" edge to the Node entity. func (nvu *NodeVersionUpdate) SetNode(n *Node) *NodeVersionUpdate { return nvu.SetNodeID(n.ID) @@ -191,12 +221,23 @@ func (nvu *NodeVersionUpdate) defaults() { // check runs all checks and user-defined validators on the builder. func (nvu *NodeVersionUpdate) check() error { + if v, ok := nvu.mutation.Status(); ok { + if err := nodeversion.StatusValidator(v); err != nil { + return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "NodeVersion.status": %w`, err)} + } + } if _, ok := nvu.mutation.NodeID(); nvu.mutation.NodeCleared() && !ok { return errors.New(`ent: clearing a required unique edge "NodeVersion.node"`) } return nil } +// Modify adds a statement modifier for attaching custom logic to the UPDATE statement. +func (nvu *NodeVersionUpdate) Modify(modifiers ...func(u *sql.UpdateBuilder)) *NodeVersionUpdate { + nvu.modifiers = append(nvu.modifiers, modifiers...) + return nvu +} + func (nvu *NodeVersionUpdate) sqlSave(ctx context.Context) (n int, err error) { if err := nvu.check(); err != nil { return n, err @@ -232,6 +273,12 @@ func (nvu *NodeVersionUpdate) sqlSave(ctx context.Context) (n int, err error) { if value, ok := nvu.mutation.Deprecated(); ok { _spec.SetField(nodeversion.FieldDeprecated, field.TypeBool, value) } + if value, ok := nvu.mutation.Status(); ok { + _spec.SetField(nodeversion.FieldStatus, field.TypeEnum, value) + } + if value, ok := nvu.mutation.StatusReason(); ok { + _spec.SetField(nodeversion.FieldStatusReason, field.TypeString, value) + } if nvu.mutation.NodeCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, @@ -290,6 +337,7 @@ func (nvu *NodeVersionUpdate) sqlSave(ctx context.Context) (n int, err error) { } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + _spec.AddModifiers(nvu.modifiers...) if n, err = sqlgraph.UpdateNodes(ctx, nvu.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{nodeversion.Label} @@ -305,9 +353,10 @@ func (nvu *NodeVersionUpdate) sqlSave(ctx context.Context) (n int, err error) { // NodeVersionUpdateOne is the builder for updating a single NodeVersion entity. type NodeVersionUpdateOne struct { config - fields []string - hooks []Hook - mutation *NodeVersionMutation + fields []string + hooks []Hook + mutation *NodeVersionMutation + modifiers []func(*sql.UpdateBuilder) } // SetUpdateTime sets the "update_time" field. @@ -390,6 +439,34 @@ func (nvuo *NodeVersionUpdateOne) SetNillableDeprecated(b *bool) *NodeVersionUpd return nvuo } +// SetStatus sets the "status" field. +func (nvuo *NodeVersionUpdateOne) SetStatus(svs schema.NodeVersionStatus) *NodeVersionUpdateOne { + nvuo.mutation.SetStatus(svs) + return nvuo +} + +// SetNillableStatus sets the "status" field if the given value is not nil. +func (nvuo *NodeVersionUpdateOne) SetNillableStatus(svs *schema.NodeVersionStatus) *NodeVersionUpdateOne { + if svs != nil { + nvuo.SetStatus(*svs) + } + return nvuo +} + +// SetStatusReason sets the "status_reason" field. +func (nvuo *NodeVersionUpdateOne) SetStatusReason(s string) *NodeVersionUpdateOne { + nvuo.mutation.SetStatusReason(s) + return nvuo +} + +// SetNillableStatusReason sets the "status_reason" field if the given value is not nil. +func (nvuo *NodeVersionUpdateOne) SetNillableStatusReason(s *string) *NodeVersionUpdateOne { + if s != nil { + nvuo.SetStatusReason(*s) + } + return nvuo +} + // SetNode sets the "node" edge to the Node entity. func (nvuo *NodeVersionUpdateOne) SetNode(n *Node) *NodeVersionUpdateOne { return nvuo.SetNodeID(n.ID) @@ -482,12 +559,23 @@ func (nvuo *NodeVersionUpdateOne) defaults() { // check runs all checks and user-defined validators on the builder. func (nvuo *NodeVersionUpdateOne) check() error { + if v, ok := nvuo.mutation.Status(); ok { + if err := nodeversion.StatusValidator(v); err != nil { + return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "NodeVersion.status": %w`, err)} + } + } if _, ok := nvuo.mutation.NodeID(); nvuo.mutation.NodeCleared() && !ok { return errors.New(`ent: clearing a required unique edge "NodeVersion.node"`) } return nil } +// Modify adds a statement modifier for attaching custom logic to the UPDATE statement. +func (nvuo *NodeVersionUpdateOne) Modify(modifiers ...func(u *sql.UpdateBuilder)) *NodeVersionUpdateOne { + nvuo.modifiers = append(nvuo.modifiers, modifiers...) + return nvuo +} + func (nvuo *NodeVersionUpdateOne) sqlSave(ctx context.Context) (_node *NodeVersion, err error) { if err := nvuo.check(); err != nil { return _node, err @@ -540,6 +628,12 @@ func (nvuo *NodeVersionUpdateOne) sqlSave(ctx context.Context) (_node *NodeVersi if value, ok := nvuo.mutation.Deprecated(); ok { _spec.SetField(nodeversion.FieldDeprecated, field.TypeBool, value) } + if value, ok := nvuo.mutation.Status(); ok { + _spec.SetField(nodeversion.FieldStatus, field.TypeEnum, value) + } + if value, ok := nvuo.mutation.StatusReason(); ok { + _spec.SetField(nodeversion.FieldStatusReason, field.TypeString, value) + } if nvuo.mutation.NodeCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.M2O, @@ -598,6 +692,7 @@ func (nvuo *NodeVersionUpdateOne) sqlSave(ctx context.Context) (_node *NodeVersi } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + _spec.AddModifiers(nvuo.modifiers...) _node = &NodeVersion{config: nvuo.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/ent/personalaccesstoken_query.go b/ent/personalaccesstoken_query.go index 1bdce41..d0d6a39 100644 --- a/ent/personalaccesstoken_query.go +++ b/ent/personalaccesstoken_query.go @@ -552,6 +552,12 @@ func (patq *PersonalAccessTokenQuery) ForShare(opts ...sql.LockOption) *Personal return patq } +// Modify adds a query modifier for attaching custom logic to queries. +func (patq *PersonalAccessTokenQuery) Modify(modifiers ...func(s *sql.Selector)) *PersonalAccessTokenSelect { + patq.modifiers = append(patq.modifiers, modifiers...) + return patq.Select() +} + // PersonalAccessTokenGroupBy is the group-by builder for PersonalAccessToken entities. type PersonalAccessTokenGroupBy struct { selector @@ -641,3 +647,9 @@ func (pats *PersonalAccessTokenSelect) sqlScan(ctx context.Context, root *Person defer rows.Close() return sql.ScanSlice(rows, v) } + +// Modify adds a query modifier for attaching custom logic to queries. +func (pats *PersonalAccessTokenSelect) Modify(modifiers ...func(s *sql.Selector)) *PersonalAccessTokenSelect { + pats.modifiers = append(pats.modifiers, modifiers...) + return pats +} diff --git a/ent/personalaccesstoken_update.go b/ent/personalaccesstoken_update.go index 1ab2945..e41faf7 100644 --- a/ent/personalaccesstoken_update.go +++ b/ent/personalaccesstoken_update.go @@ -19,8 +19,9 @@ import ( // PersonalAccessTokenUpdate is the builder for updating PersonalAccessToken entities. type PersonalAccessTokenUpdate struct { config - hooks []Hook - mutation *PersonalAccessTokenMutation + hooks []Hook + mutation *PersonalAccessTokenMutation + modifiers []func(*sql.UpdateBuilder) } // Where appends a list predicates to the PersonalAccessTokenUpdate builder. @@ -151,6 +152,12 @@ func (patu *PersonalAccessTokenUpdate) check() error { return nil } +// Modify adds a statement modifier for attaching custom logic to the UPDATE statement. +func (patu *PersonalAccessTokenUpdate) Modify(modifiers ...func(u *sql.UpdateBuilder)) *PersonalAccessTokenUpdate { + patu.modifiers = append(patu.modifiers, modifiers...) + return patu +} + func (patu *PersonalAccessTokenUpdate) sqlSave(ctx context.Context) (n int, err error) { if err := patu.check(); err != nil { return n, err @@ -204,6 +211,7 @@ func (patu *PersonalAccessTokenUpdate) sqlSave(ctx context.Context) (n int, err } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + _spec.AddModifiers(patu.modifiers...) if n, err = sqlgraph.UpdateNodes(ctx, patu.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{personalaccesstoken.Label} @@ -219,9 +227,10 @@ func (patu *PersonalAccessTokenUpdate) sqlSave(ctx context.Context) (n int, err // PersonalAccessTokenUpdateOne is the builder for updating a single PersonalAccessToken entity. type PersonalAccessTokenUpdateOne struct { config - fields []string - hooks []Hook - mutation *PersonalAccessTokenMutation + fields []string + hooks []Hook + mutation *PersonalAccessTokenMutation + modifiers []func(*sql.UpdateBuilder) } // SetUpdateTime sets the "update_time" field. @@ -359,6 +368,12 @@ func (patuo *PersonalAccessTokenUpdateOne) check() error { return nil } +// Modify adds a statement modifier for attaching custom logic to the UPDATE statement. +func (patuo *PersonalAccessTokenUpdateOne) Modify(modifiers ...func(u *sql.UpdateBuilder)) *PersonalAccessTokenUpdateOne { + patuo.modifiers = append(patuo.modifiers, modifiers...) + return patuo +} + func (patuo *PersonalAccessTokenUpdateOne) sqlSave(ctx context.Context) (_node *PersonalAccessToken, err error) { if err := patuo.check(); err != nil { return _node, err @@ -429,6 +444,7 @@ func (patuo *PersonalAccessTokenUpdateOne) sqlSave(ctx context.Context) (_node * } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + _spec.AddModifiers(patuo.modifiers...) _node = &PersonalAccessToken{config: patuo.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/ent/predicate/predicate.go b/ent/predicate/predicate.go index c29ba05..37bf309 100644 --- a/ent/predicate/predicate.go +++ b/ent/predicate/predicate.go @@ -15,6 +15,9 @@ type GitCommit func(*sql.Selector) // Node is the predicate function for node builders. type Node func(*sql.Selector) +// NodeReview is the predicate function for nodereview builders. +type NodeReview func(*sql.Selector) + // NodeVersion is the predicate function for nodeversion builders. type NodeVersion func(*sql.Selector) diff --git a/ent/publisher.go b/ent/publisher.go index 9c35eb8..f4b52db 100644 --- a/ent/publisher.go +++ b/ent/publisher.go @@ -5,6 +5,7 @@ package ent import ( "fmt" "registry-backend/ent/publisher" + "registry-backend/ent/schema" "strings" "time" @@ -34,6 +35,8 @@ type Publisher struct { SourceCodeRepo string `json:"source_code_repo,omitempty"` // LogoURL holds the value of the "logo_url" field. LogoURL string `json:"logo_url,omitempty"` + // Status holds the value of the "status" field. + Status schema.PublisherStatusType `json:"status,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the PublisherQuery when eager-loading is set. Edges PublisherEdges `json:"edges"` @@ -85,7 +88,7 @@ func (*Publisher) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { - case publisher.FieldID, publisher.FieldName, publisher.FieldDescription, publisher.FieldWebsite, publisher.FieldSupportEmail, publisher.FieldSourceCodeRepo, publisher.FieldLogoURL: + case publisher.FieldID, publisher.FieldName, publisher.FieldDescription, publisher.FieldWebsite, publisher.FieldSupportEmail, publisher.FieldSourceCodeRepo, publisher.FieldLogoURL, publisher.FieldStatus: values[i] = new(sql.NullString) case publisher.FieldCreateTime, publisher.FieldUpdateTime: values[i] = new(sql.NullTime) @@ -158,6 +161,12 @@ func (pu *Publisher) assignValues(columns []string, values []any) error { } else if value.Valid { pu.LogoURL = value.String } + case publisher.FieldStatus: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field status", values[i]) + } else if value.Valid { + pu.Status = schema.PublisherStatusType(value.String) + } default: pu.selectValues.Set(columns[i], values[i]) } @@ -232,6 +241,9 @@ func (pu *Publisher) String() string { builder.WriteString(", ") builder.WriteString("logo_url=") builder.WriteString(pu.LogoURL) + builder.WriteString(", ") + builder.WriteString("status=") + builder.WriteString(fmt.Sprintf("%v", pu.Status)) builder.WriteByte(')') return builder.String() } diff --git a/ent/publisher/publisher.go b/ent/publisher/publisher.go index a905a76..15d6330 100644 --- a/ent/publisher/publisher.go +++ b/ent/publisher/publisher.go @@ -3,6 +3,8 @@ package publisher import ( + "fmt" + "registry-backend/ent/schema" "time" "entgo.io/ent/dialect/sql" @@ -30,6 +32,8 @@ const ( FieldSourceCodeRepo = "source_code_repo" // FieldLogoURL holds the string denoting the logo_url field in the database. FieldLogoURL = "logo_url" + // FieldStatus holds the string denoting the status field in the database. + FieldStatus = "status" // EdgePublisherPermissions holds the string denoting the publisher_permissions edge name in mutations. EdgePublisherPermissions = "publisher_permissions" // EdgeNodes holds the string denoting the nodes edge name in mutations. @@ -72,6 +76,7 @@ var Columns = []string{ FieldSupportEmail, FieldSourceCodeRepo, FieldLogoURL, + FieldStatus, } // ValidColumn reports if the column name is valid (part of the table columns). @@ -93,6 +98,18 @@ var ( UpdateDefaultUpdateTime func() time.Time ) +const DefaultStatus schema.PublisherStatusType = "ACTIVE" + +// StatusValidator is a validator for the "status" field enum values. It is called by the builders before save. +func StatusValidator(s schema.PublisherStatusType) error { + switch s { + case "ACTIVE", "BANNED": + return nil + default: + return fmt.Errorf("publisher: invalid enum value for status field: %q", s) + } +} + // OrderOption defines the ordering options for the Publisher queries. type OrderOption func(*sql.Selector) @@ -141,6 +158,11 @@ func ByLogoURL(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldLogoURL, opts...).ToFunc() } +// ByStatus orders the results by the status field. +func ByStatus(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldStatus, opts...).ToFunc() +} + // ByPublisherPermissionsCount orders the results by publisher_permissions count. func ByPublisherPermissionsCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { diff --git a/ent/publisher/where.go b/ent/publisher/where.go index bb77b21..da8cad2 100644 --- a/ent/publisher/where.go +++ b/ent/publisher/where.go @@ -4,6 +4,7 @@ package publisher import ( "registry-backend/ent/predicate" + "registry-backend/ent/schema" "time" "entgo.io/ent/dialect/sql" @@ -625,6 +626,36 @@ func LogoURLContainsFold(v string) predicate.Publisher { return predicate.Publisher(sql.FieldContainsFold(FieldLogoURL, v)) } +// StatusEQ applies the EQ predicate on the "status" field. +func StatusEQ(v schema.PublisherStatusType) predicate.Publisher { + vc := v + return predicate.Publisher(sql.FieldEQ(FieldStatus, vc)) +} + +// StatusNEQ applies the NEQ predicate on the "status" field. +func StatusNEQ(v schema.PublisherStatusType) predicate.Publisher { + vc := v + return predicate.Publisher(sql.FieldNEQ(FieldStatus, vc)) +} + +// StatusIn applies the In predicate on the "status" field. +func StatusIn(vs ...schema.PublisherStatusType) predicate.Publisher { + v := make([]any, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.Publisher(sql.FieldIn(FieldStatus, v...)) +} + +// StatusNotIn applies the NotIn predicate on the "status" field. +func StatusNotIn(vs ...schema.PublisherStatusType) predicate.Publisher { + v := make([]any, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.Publisher(sql.FieldNotIn(FieldStatus, v...)) +} + // HasPublisherPermissions applies the HasEdge predicate on the "publisher_permissions" edge. func HasPublisherPermissions() predicate.Publisher { return predicate.Publisher(func(s *sql.Selector) { diff --git a/ent/publisher_create.go b/ent/publisher_create.go index 77be4ed..24636c3 100644 --- a/ent/publisher_create.go +++ b/ent/publisher_create.go @@ -10,6 +10,7 @@ import ( "registry-backend/ent/personalaccesstoken" "registry-backend/ent/publisher" "registry-backend/ent/publisherpermission" + "registry-backend/ent/schema" "time" "entgo.io/ent/dialect" @@ -131,6 +132,20 @@ func (pc *PublisherCreate) SetNillableLogoURL(s *string) *PublisherCreate { return pc } +// SetStatus sets the "status" field. +func (pc *PublisherCreate) SetStatus(sst schema.PublisherStatusType) *PublisherCreate { + pc.mutation.SetStatus(sst) + return pc +} + +// SetNillableStatus sets the "status" field if the given value is not nil. +func (pc *PublisherCreate) SetNillableStatus(sst *schema.PublisherStatusType) *PublisherCreate { + if sst != nil { + pc.SetStatus(*sst) + } + return pc +} + // SetID sets the "id" field. func (pc *PublisherCreate) SetID(s string) *PublisherCreate { pc.mutation.SetID(s) @@ -225,6 +240,10 @@ func (pc *PublisherCreate) defaults() { v := publisher.DefaultUpdateTime() pc.mutation.SetUpdateTime(v) } + if _, ok := pc.mutation.Status(); !ok { + v := publisher.DefaultStatus + pc.mutation.SetStatus(v) + } } // check runs all checks and user-defined validators on the builder. @@ -238,6 +257,14 @@ func (pc *PublisherCreate) check() error { if _, ok := pc.mutation.Name(); !ok { return &ValidationError{Name: "name", err: errors.New(`ent: missing required field "Publisher.name"`)} } + if _, ok := pc.mutation.Status(); !ok { + return &ValidationError{Name: "status", err: errors.New(`ent: missing required field "Publisher.status"`)} + } + if v, ok := pc.mutation.Status(); ok { + if err := publisher.StatusValidator(v); err != nil { + return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "Publisher.status": %w`, err)} + } + } return nil } @@ -306,6 +333,10 @@ func (pc *PublisherCreate) createSpec() (*Publisher, *sqlgraph.CreateSpec) { _spec.SetField(publisher.FieldLogoURL, field.TypeString, value) _node.LogoURL = value } + if value, ok := pc.mutation.Status(); ok { + _spec.SetField(publisher.FieldStatus, field.TypeEnum, value) + _node.Status = value + } if nodes := pc.mutation.PublisherPermissionsIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -520,6 +551,18 @@ func (u *PublisherUpsert) ClearLogoURL() *PublisherUpsert { return u } +// SetStatus sets the "status" field. +func (u *PublisherUpsert) SetStatus(v schema.PublisherStatusType) *PublisherUpsert { + u.Set(publisher.FieldStatus, v) + return u +} + +// UpdateStatus sets the "status" field to the value that was provided on create. +func (u *PublisherUpsert) UpdateStatus() *PublisherUpsert { + u.SetExcluded(publisher.FieldStatus) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create except the ID field. // Using this option is equivalent to using: // @@ -704,6 +747,20 @@ func (u *PublisherUpsertOne) ClearLogoURL() *PublisherUpsertOne { }) } +// SetStatus sets the "status" field. +func (u *PublisherUpsertOne) SetStatus(v schema.PublisherStatusType) *PublisherUpsertOne { + return u.Update(func(s *PublisherUpsert) { + s.SetStatus(v) + }) +} + +// UpdateStatus sets the "status" field to the value that was provided on create. +func (u *PublisherUpsertOne) UpdateStatus() *PublisherUpsertOne { + return u.Update(func(s *PublisherUpsert) { + s.UpdateStatus() + }) +} + // Exec executes the query. func (u *PublisherUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -1055,6 +1112,20 @@ func (u *PublisherUpsertBulk) ClearLogoURL() *PublisherUpsertBulk { }) } +// SetStatus sets the "status" field. +func (u *PublisherUpsertBulk) SetStatus(v schema.PublisherStatusType) *PublisherUpsertBulk { + return u.Update(func(s *PublisherUpsert) { + s.SetStatus(v) + }) +} + +// UpdateStatus sets the "status" field to the value that was provided on create. +func (u *PublisherUpsertBulk) UpdateStatus() *PublisherUpsertBulk { + return u.Update(func(s *PublisherUpsert) { + s.UpdateStatus() + }) +} + // Exec executes the query. func (u *PublisherUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/ent/publisher_query.go b/ent/publisher_query.go index b3f8c79..fd9f618 100644 --- a/ent/publisher_query.go +++ b/ent/publisher_query.go @@ -703,6 +703,12 @@ func (pq *PublisherQuery) ForShare(opts ...sql.LockOption) *PublisherQuery { return pq } +// Modify adds a query modifier for attaching custom logic to queries. +func (pq *PublisherQuery) Modify(modifiers ...func(s *sql.Selector)) *PublisherSelect { + pq.modifiers = append(pq.modifiers, modifiers...) + return pq.Select() +} + // PublisherGroupBy is the group-by builder for Publisher entities. type PublisherGroupBy struct { selector @@ -792,3 +798,9 @@ func (ps *PublisherSelect) sqlScan(ctx context.Context, root *PublisherQuery, v defer rows.Close() return sql.ScanSlice(rows, v) } + +// Modify adds a query modifier for attaching custom logic to queries. +func (ps *PublisherSelect) Modify(modifiers ...func(s *sql.Selector)) *PublisherSelect { + ps.modifiers = append(ps.modifiers, modifiers...) + return ps +} diff --git a/ent/publisher_update.go b/ent/publisher_update.go index 652a361..3d86690 100644 --- a/ent/publisher_update.go +++ b/ent/publisher_update.go @@ -11,6 +11,7 @@ import ( "registry-backend/ent/predicate" "registry-backend/ent/publisher" "registry-backend/ent/publisherpermission" + "registry-backend/ent/schema" "time" "entgo.io/ent/dialect/sql" @@ -22,8 +23,9 @@ import ( // PublisherUpdate is the builder for updating Publisher entities. type PublisherUpdate struct { config - hooks []Hook - mutation *PublisherMutation + hooks []Hook + mutation *PublisherMutation + modifiers []func(*sql.UpdateBuilder) } // Where appends a list predicates to the PublisherUpdate builder. @@ -152,6 +154,20 @@ func (pu *PublisherUpdate) ClearLogoURL() *PublisherUpdate { return pu } +// SetStatus sets the "status" field. +func (pu *PublisherUpdate) SetStatus(sst schema.PublisherStatusType) *PublisherUpdate { + pu.mutation.SetStatus(sst) + return pu +} + +// SetNillableStatus sets the "status" field if the given value is not nil. +func (pu *PublisherUpdate) SetNillableStatus(sst *schema.PublisherStatusType) *PublisherUpdate { + if sst != nil { + pu.SetStatus(*sst) + } + return pu +} + // AddPublisherPermissionIDs adds the "publisher_permissions" edge to the PublisherPermission entity by IDs. func (pu *PublisherUpdate) AddPublisherPermissionIDs(ids ...int) *PublisherUpdate { pu.mutation.AddPublisherPermissionIDs(ids...) @@ -301,7 +317,26 @@ func (pu *PublisherUpdate) defaults() { } } +// check runs all checks and user-defined validators on the builder. +func (pu *PublisherUpdate) check() error { + if v, ok := pu.mutation.Status(); ok { + if err := publisher.StatusValidator(v); err != nil { + return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "Publisher.status": %w`, err)} + } + } + return nil +} + +// Modify adds a statement modifier for attaching custom logic to the UPDATE statement. +func (pu *PublisherUpdate) Modify(modifiers ...func(u *sql.UpdateBuilder)) *PublisherUpdate { + pu.modifiers = append(pu.modifiers, modifiers...) + return pu +} + func (pu *PublisherUpdate) sqlSave(ctx context.Context) (n int, err error) { + if err := pu.check(); err != nil { + return n, err + } _spec := sqlgraph.NewUpdateSpec(publisher.Table, publisher.Columns, sqlgraph.NewFieldSpec(publisher.FieldID, field.TypeString)) if ps := pu.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { @@ -346,6 +381,9 @@ func (pu *PublisherUpdate) sqlSave(ctx context.Context) (n int, err error) { if pu.mutation.LogoURLCleared() { _spec.ClearField(publisher.FieldLogoURL, field.TypeString) } + if value, ok := pu.mutation.Status(); ok { + _spec.SetField(publisher.FieldStatus, field.TypeEnum, value) + } if pu.mutation.PublisherPermissionsCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -481,6 +519,7 @@ func (pu *PublisherUpdate) sqlSave(ctx context.Context) (n int, err error) { } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + _spec.AddModifiers(pu.modifiers...) if n, err = sqlgraph.UpdateNodes(ctx, pu.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{publisher.Label} @@ -496,9 +535,10 @@ func (pu *PublisherUpdate) sqlSave(ctx context.Context) (n int, err error) { // PublisherUpdateOne is the builder for updating a single Publisher entity. type PublisherUpdateOne struct { config - fields []string - hooks []Hook - mutation *PublisherMutation + fields []string + hooks []Hook + mutation *PublisherMutation + modifiers []func(*sql.UpdateBuilder) } // SetUpdateTime sets the "update_time" field. @@ -621,6 +661,20 @@ func (puo *PublisherUpdateOne) ClearLogoURL() *PublisherUpdateOne { return puo } +// SetStatus sets the "status" field. +func (puo *PublisherUpdateOne) SetStatus(sst schema.PublisherStatusType) *PublisherUpdateOne { + puo.mutation.SetStatus(sst) + return puo +} + +// SetNillableStatus sets the "status" field if the given value is not nil. +func (puo *PublisherUpdateOne) SetNillableStatus(sst *schema.PublisherStatusType) *PublisherUpdateOne { + if sst != nil { + puo.SetStatus(*sst) + } + return puo +} + // AddPublisherPermissionIDs adds the "publisher_permissions" edge to the PublisherPermission entity by IDs. func (puo *PublisherUpdateOne) AddPublisherPermissionIDs(ids ...int) *PublisherUpdateOne { puo.mutation.AddPublisherPermissionIDs(ids...) @@ -783,7 +837,26 @@ func (puo *PublisherUpdateOne) defaults() { } } +// check runs all checks and user-defined validators on the builder. +func (puo *PublisherUpdateOne) check() error { + if v, ok := puo.mutation.Status(); ok { + if err := publisher.StatusValidator(v); err != nil { + return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "Publisher.status": %w`, err)} + } + } + return nil +} + +// Modify adds a statement modifier for attaching custom logic to the UPDATE statement. +func (puo *PublisherUpdateOne) Modify(modifiers ...func(u *sql.UpdateBuilder)) *PublisherUpdateOne { + puo.modifiers = append(puo.modifiers, modifiers...) + return puo +} + func (puo *PublisherUpdateOne) sqlSave(ctx context.Context) (_node *Publisher, err error) { + if err := puo.check(); err != nil { + return _node, err + } _spec := sqlgraph.NewUpdateSpec(publisher.Table, publisher.Columns, sqlgraph.NewFieldSpec(publisher.FieldID, field.TypeString)) id, ok := puo.mutation.ID() if !ok { @@ -845,6 +918,9 @@ func (puo *PublisherUpdateOne) sqlSave(ctx context.Context) (_node *Publisher, e if puo.mutation.LogoURLCleared() { _spec.ClearField(publisher.FieldLogoURL, field.TypeString) } + if value, ok := puo.mutation.Status(); ok { + _spec.SetField(publisher.FieldStatus, field.TypeEnum, value) + } if puo.mutation.PublisherPermissionsCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -980,6 +1056,7 @@ func (puo *PublisherUpdateOne) sqlSave(ctx context.Context) (_node *Publisher, e } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + _spec.AddModifiers(puo.modifiers...) _node = &Publisher{config: puo.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/ent/publisherpermission_query.go b/ent/publisherpermission_query.go index 00e51d0..0f767cf 100644 --- a/ent/publisherpermission_query.go +++ b/ent/publisherpermission_query.go @@ -626,6 +626,12 @@ func (ppq *PublisherPermissionQuery) ForShare(opts ...sql.LockOption) *Publisher return ppq } +// Modify adds a query modifier for attaching custom logic to queries. +func (ppq *PublisherPermissionQuery) Modify(modifiers ...func(s *sql.Selector)) *PublisherPermissionSelect { + ppq.modifiers = append(ppq.modifiers, modifiers...) + return ppq.Select() +} + // PublisherPermissionGroupBy is the group-by builder for PublisherPermission entities. type PublisherPermissionGroupBy struct { selector @@ -715,3 +721,9 @@ func (pps *PublisherPermissionSelect) sqlScan(ctx context.Context, root *Publish defer rows.Close() return sql.ScanSlice(rows, v) } + +// Modify adds a query modifier for attaching custom logic to queries. +func (pps *PublisherPermissionSelect) Modify(modifiers ...func(s *sql.Selector)) *PublisherPermissionSelect { + pps.modifiers = append(pps.modifiers, modifiers...) + return pps +} diff --git a/ent/publisherpermission_update.go b/ent/publisherpermission_update.go index 5610f42..ac1da53 100644 --- a/ent/publisherpermission_update.go +++ b/ent/publisherpermission_update.go @@ -20,8 +20,9 @@ import ( // PublisherPermissionUpdate is the builder for updating PublisherPermission entities. type PublisherPermissionUpdate struct { config - hooks []Hook - mutation *PublisherPermissionMutation + hooks []Hook + mutation *PublisherPermissionMutation + modifiers []func(*sql.UpdateBuilder) } // Where appends a list predicates to the PublisherPermissionUpdate builder. @@ -142,6 +143,12 @@ func (ppu *PublisherPermissionUpdate) check() error { return nil } +// Modify adds a statement modifier for attaching custom logic to the UPDATE statement. +func (ppu *PublisherPermissionUpdate) Modify(modifiers ...func(u *sql.UpdateBuilder)) *PublisherPermissionUpdate { + ppu.modifiers = append(ppu.modifiers, modifiers...) + return ppu +} + func (ppu *PublisherPermissionUpdate) sqlSave(ctx context.Context) (n int, err error) { if err := ppu.check(); err != nil { return n, err @@ -215,6 +222,7 @@ func (ppu *PublisherPermissionUpdate) sqlSave(ctx context.Context) (n int, err e } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + _spec.AddModifiers(ppu.modifiers...) if n, err = sqlgraph.UpdateNodes(ctx, ppu.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{publisherpermission.Label} @@ -230,9 +238,10 @@ func (ppu *PublisherPermissionUpdate) sqlSave(ctx context.Context) (n int, err e // PublisherPermissionUpdateOne is the builder for updating a single PublisherPermission entity. type PublisherPermissionUpdateOne struct { config - fields []string - hooks []Hook - mutation *PublisherPermissionMutation + fields []string + hooks []Hook + mutation *PublisherPermissionMutation + modifiers []func(*sql.UpdateBuilder) } // SetPermission sets the "permission" field. @@ -360,6 +369,12 @@ func (ppuo *PublisherPermissionUpdateOne) check() error { return nil } +// Modify adds a statement modifier for attaching custom logic to the UPDATE statement. +func (ppuo *PublisherPermissionUpdateOne) Modify(modifiers ...func(u *sql.UpdateBuilder)) *PublisherPermissionUpdateOne { + ppuo.modifiers = append(ppuo.modifiers, modifiers...) + return ppuo +} + func (ppuo *PublisherPermissionUpdateOne) sqlSave(ctx context.Context) (_node *PublisherPermission, err error) { if err := ppuo.check(); err != nil { return _node, err @@ -450,6 +465,7 @@ func (ppuo *PublisherPermissionUpdateOne) sqlSave(ctx context.Context) (_node *P } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + _spec.AddModifiers(ppuo.modifiers...) _node = &PublisherPermission{config: ppuo.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/ent/runtime.go b/ent/runtime.go index 63e4103..69c967e 100644 --- a/ent/runtime.go +++ b/ent/runtime.go @@ -6,6 +6,7 @@ import ( "registry-backend/ent/ciworkflowresult" "registry-backend/ent/gitcommit" "registry-backend/ent/node" + "registry-backend/ent/nodereview" "registry-backend/ent/nodeversion" "registry-backend/ent/personalaccesstoken" "registry-backend/ent/publisher" @@ -36,6 +37,10 @@ func init() { ciworkflowresult.DefaultUpdateTime = ciworkflowresultDescUpdateTime.Default.(func() time.Time) // ciworkflowresult.UpdateDefaultUpdateTime holds the default value on update for the update_time field. ciworkflowresult.UpdateDefaultUpdateTime = ciworkflowresultDescUpdateTime.UpdateDefault.(func() time.Time) + // ciworkflowresultDescStatus is the schema descriptor for status field. + ciworkflowresultDescStatus := ciworkflowresultFields[5].Descriptor() + // ciworkflowresult.DefaultStatus holds the default value on creation for the status field. + ciworkflowresult.DefaultStatus = schema.WorkflowRunStatusType(ciworkflowresultDescStatus.Default.(string)) // ciworkflowresultDescID is the schema descriptor for id field. ciworkflowresultDescID := ciworkflowresultFields[0].Descriptor() // ciworkflowresult.DefaultID holds the default value on creation for the id field. @@ -75,9 +80,31 @@ func init() { // node.UpdateDefaultUpdateTime holds the default value on update for the update_time field. node.UpdateDefaultUpdateTime = nodeDescUpdateTime.UpdateDefault.(func() time.Time) // nodeDescTags is the schema descriptor for tags field. - nodeDescTags := nodeFields[8].Descriptor() + nodeDescTags := nodeFields[9].Descriptor() // node.DefaultTags holds the default value on creation for the tags field. node.DefaultTags = nodeDescTags.Default.([]string) + // nodeDescTotalInstall is the schema descriptor for total_install field. + nodeDescTotalInstall := nodeFields[10].Descriptor() + // node.DefaultTotalInstall holds the default value on creation for the total_install field. + node.DefaultTotalInstall = nodeDescTotalInstall.Default.(int64) + // nodeDescTotalStar is the schema descriptor for total_star field. + nodeDescTotalStar := nodeFields[11].Descriptor() + // node.DefaultTotalStar holds the default value on creation for the total_star field. + node.DefaultTotalStar = nodeDescTotalStar.Default.(int64) + // nodeDescTotalReview is the schema descriptor for total_review field. + nodeDescTotalReview := nodeFields[12].Descriptor() + // node.DefaultTotalReview holds the default value on creation for the total_review field. + node.DefaultTotalReview = nodeDescTotalReview.Default.(int64) + nodereviewFields := schema.NodeReview{}.Fields() + _ = nodereviewFields + // nodereviewDescStar is the schema descriptor for star field. + nodereviewDescStar := nodereviewFields[3].Descriptor() + // nodereview.DefaultStar holds the default value on creation for the star field. + nodereview.DefaultStar = nodereviewDescStar.Default.(int) + // nodereviewDescID is the schema descriptor for id field. + nodereviewDescID := nodereviewFields[0].Descriptor() + // nodereview.DefaultID holds the default value on creation for the id field. + nodereview.DefaultID = nodereviewDescID.Default.(func() uuid.UUID) nodeversionMixin := schema.NodeVersion{}.Mixin() nodeversionMixinFields0 := nodeversionMixin[0].Fields() _ = nodeversionMixinFields0 @@ -97,6 +124,10 @@ func init() { nodeversionDescDeprecated := nodeversionFields[5].Descriptor() // nodeversion.DefaultDeprecated holds the default value on creation for the deprecated field. nodeversion.DefaultDeprecated = nodeversionDescDeprecated.Default.(bool) + // nodeversionDescStatusReason is the schema descriptor for status_reason field. + nodeversionDescStatusReason := nodeversionFields[7].Descriptor() + // nodeversion.DefaultStatusReason holds the default value on creation for the status_reason field. + nodeversion.DefaultStatusReason = nodeversionDescStatusReason.Default.(string) // nodeversionDescID is the schema descriptor for id field. nodeversionDescID := nodeversionFields[0].Descriptor() // nodeversion.DefaultID holds the default value on creation for the id field. diff --git a/ent/schema/ci_workflow_result.go b/ent/schema/ci_workflow_result.go index 3bd3f96..86f3bf3 100644 --- a/ent/schema/ci_workflow_result.go +++ b/ent/schema/ci_workflow_result.go @@ -9,7 +9,7 @@ import ( "github.com/google/uuid" ) -// CIWorkflowResult holds the schema definition for the CIWorkflowResult entity. +// Stores the artifacts and metadata about a single workflow execution for Comfy CI. type CIWorkflowResult struct { ent.Schema } @@ -21,31 +21,54 @@ func (CIWorkflowResult) Fields() []ent.Field { field.String("operating_system").SchemaType(map[string]string{ dialect.Postgres: "text", }), - field.String("gpu_type").SchemaType(map[string]string{ - dialect.Postgres: "text", - }).Optional(), - field.String("pytorch_version").SchemaType(map[string]string{ - dialect.Postgres: "text", - }).Optional(), field.String("workflow_name").SchemaType(map[string]string{ dialect.Postgres: "text", }).Optional(), field.String("run_id").SchemaType(map[string]string{ dialect.Postgres: "text", }).Optional(), - field.String("status").SchemaType(map[string]string{ + field.String("job_id").SchemaType(map[string]string{ dialect.Postgres: "text", }).Optional(), + field.String("status").GoType(WorkflowRunStatusType("")).Default(string(WorkflowRunStatusTypeStarted)), field.Int64("start_time").Optional(), field.Int64("end_time").Optional(), + field.String("python_version").SchemaType(map[string]string{ + dialect.Postgres: "text", + }).Optional(), + field.String("pytorch_version").SchemaType(map[string]string{ + dialect.Postgres: "text", + }).Optional(), + field.String("cuda_version").SchemaType(map[string]string{ + dialect.Postgres: "text", + }).Optional(), + field.String("comfy_run_flags").SchemaType(map[string]string{ + dialect.Postgres: "text", + }).Optional(), + field.Int("avg_vram").Optional().Comment("Average amount of VRAM used by the workflow in Megabytes"), + field.Int("peak_vram").Optional().Comment("Peak amount of VRAM used by the workflow in Megabytes"), + field.String("job_trigger_user").SchemaType(map[string]string{ + dialect.Postgres: "text", + }).Optional().Comment("User who triggered the job"), + field.JSON("metadata", map[string]interface{}{}).SchemaType(map[string]string{ + dialect.Postgres: "jsonb", + }).Optional().Comment("Stores miscellaneous metadata for each workflow run."), } } +type WorkflowRunStatusType string + +const ( + WorkflowRunStatusTypeCompleted WorkflowRunStatusType = "COMPLETED" + WorkflowRunStatusTypeFailed WorkflowRunStatusType = "FAILED" + WorkflowRunStatusTypeStarted WorkflowRunStatusType = "STARTED" +) + // Edges of the CIWorkflowResult. func (CIWorkflowResult) Edges() []ent.Edge { return []ent.Edge{ edge.From("gitcommit", GitCommit.Type).Ref("results").Unique(), - edge.To("storage_file", StorageFile.Type).Unique(), + edge.To("storage_file", StorageFile.Type), } } diff --git a/ent/schema/git_commit.go b/ent/schema/git_commit.go index 6bac90c..391691e 100644 --- a/ent/schema/git_commit.go +++ b/ent/schema/git_commit.go @@ -36,6 +36,9 @@ func (GitCommit) Fields() []ent.Field { dialect.Postgres: "text", }).Optional(), field.Time("timestamp").Optional(), + field.String("pr_number").SchemaType(map[string]string{ + dialect.Postgres: "text", + }).Optional(), } } diff --git a/ent/schema/node.go b/ent/schema/node.go index 43efc2b..eff4878 100644 --- a/ent/schema/node.go +++ b/ent/schema/node.go @@ -28,6 +28,9 @@ func (Node) Fields() []ent.Field { field.String("description").Optional().SchemaType(map[string]string{ dialect.Postgres: "text", }), + field.String("category").SchemaType(map[string]string{ + dialect.Postgres: "text", + }).Optional(), field.String("author").SchemaType(map[string]string{ dialect.Postgres: "text", }).Optional(), @@ -43,6 +46,15 @@ func (Node) Fields() []ent.Field { field.Strings("tags").SchemaType(map[string]string{ dialect.Postgres: "text", }).Default([]string{}), + field.Int64("total_install").Default(0), + field.Int64("total_star").Default(0), + field.Int64("total_review").Default(0), + field.Enum("status"). + GoType(NodeStatus("")). + Default(string(NodeStatusActive)), + field.String("status_detail").SchemaType(map[string]string{ + dialect.Postgres: "text", + }).Optional(), } } @@ -56,5 +68,22 @@ func (Node) Edges() []ent.Edge { return []ent.Edge{ edge.From("publisher", Publisher.Type).Field("publisher_id").Ref("nodes").Required().Unique(), edge.To("versions", NodeVersion.Type), + edge.To("reviews", NodeReview.Type), + } +} + +type NodeStatus string + +const ( + NodeStatusActive NodeStatus = "active" + NodeStatusDeleted NodeStatus = "deleted" + NodeStatusBanned NodeStatus = "banned" +) + +func (NodeStatus) Values() (types []string) { + return []string{ + string(NodeStatusActive), + string(NodeStatusBanned), + string(NodeStatusDeleted), } } diff --git a/ent/schema/node_review.go b/ent/schema/node_review.go new file mode 100644 index 0000000..50224d1 --- /dev/null +++ b/ent/schema/node_review.go @@ -0,0 +1,39 @@ +package schema + +import ( + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/schema/edge" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" + "github.com/google/uuid" +) + +type NodeReview struct { + ent.Schema +} + +func (NodeReview) Fields() []ent.Field { + return []ent.Field{ + field.UUID("id", uuid.UUID{}).Default(uuid.New), + field.String("node_id").SchemaType(map[string]string{ + dialect.Postgres: "text", + }), + field.String("user_id").SchemaType(map[string]string{ + dialect.Postgres: "text", + }), + field.Int("star").Default(0), + } +} + +func (NodeReview) Edges() []ent.Edge { + return []ent.Edge{ + edge.From("user", User.Type).Field("user_id").Ref("reviews").Unique().Required(), + edge.From("node", Node.Type).Field("node_id").Ref("reviews").Unique().Required(), + } +} +func (NodeReview) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("node_id", "user_id").Unique(), + } +} diff --git a/ent/schema/node_version.go b/ent/schema/node_version.go index 26e95bd..40c8581 100644 --- a/ent/schema/node_version.go +++ b/ent/schema/node_version.go @@ -32,6 +32,12 @@ func (NodeVersion) Fields() []ent.Field { dialect.Postgres: "text", }), field.Bool("deprecated").Default(false), + field.Enum("status"). + GoType(NodeVersionStatus("")). + Default(string(NodeVersionStatusPending)), + field.String("status_reason").SchemaType(map[string]string{ + dialect.Postgres: "text", + }).Default("").Comment("Give a reason for the status change. Eg. 'Banned due to security vulnerability'"), } } @@ -53,3 +59,23 @@ func (NodeVersion) Indexes() []ent.Index { index.Fields("node_id", "version").Unique(), } } + +type NodeVersionStatus string + +const ( + NodeVersionStatusActive NodeVersionStatus = "active" + NodeVersionStatusDeleted NodeVersionStatus = "deleted" + NodeVersionStatusBanned NodeVersionStatus = "banned" + NodeVersionStatusPending NodeVersionStatus = "pending" + NodeVersionStatusFlagged NodeVersionStatus = "flagged" +) + +func (NodeVersionStatus) Values() (types []string) { + return []string{ + string(NodeVersionStatusActive), + string(NodeVersionStatusBanned), + string(NodeVersionStatusDeleted), + string(NodeVersionStatusPending), + string(NodeVersionStatusFlagged), + } +} diff --git a/ent/schema/publisher.go b/ent/schema/publisher.go index 95c4e98..606e676 100644 --- a/ent/schema/publisher.go +++ b/ent/schema/publisher.go @@ -36,6 +36,7 @@ func (Publisher) Fields() []ent.Field { field.String("logo_url").Optional().SchemaType(map[string]string{ dialect.Postgres: "text", }).Optional(), + field.Enum("status").GoType(PublisherStatusType("")).Default(string(PublisherStatusTypeActive)), } } @@ -52,3 +53,17 @@ func (Publisher) Edges() []ent.Edge { edge.To("personal_access_tokens", PersonalAccessToken.Type), } } + +type PublisherStatusType string + +const ( + PublisherStatusTypeActive PublisherStatusType = "ACTIVE" + PublisherStatusTypeBanned PublisherStatusType = "BANNED" +) + +func (PublisherStatusType) Values() (types []string) { + return []string{ + string(PublisherStatusTypeActive), + string(PublisherStatusTypeBanned), + } +} diff --git a/ent/schema/user.go b/ent/schema/user.go index 67ec683..7818a53 100644 --- a/ent/schema/user.go +++ b/ent/schema/user.go @@ -20,6 +20,7 @@ func (User) Fields() []ent.Field { field.String("name").Optional(), field.Bool("is_approved").Default(false).Comment("Whether the user is approved to use the platform"), field.Bool("is_admin").Default(false).Comment("Whether the user is approved to use the platform"), + field.Enum("status").GoType(UserStatusType("")).Default(string(UserStatusTypeActive)), } } @@ -33,5 +34,20 @@ func (User) Mixin() []ent.Mixin { func (User) Edges() []ent.Edge { return []ent.Edge{ edge.To("publisher_permissions", PublisherPermission.Type), + edge.To("reviews", NodeReview.Type), + } +} + +type UserStatusType string + +const ( + UserStatusTypeActive UserStatusType = "ACTIVE" + UserStatusTypeBanned UserStatusType = "BANNED" +) + +func (UserStatusType) Values() (types []string) { + return []string{ + string(UserStatusTypeActive), + string(UserStatusTypeBanned), } } diff --git a/ent/storagefile.go b/ent/storagefile.go index df85918..6b121e2 100644 --- a/ent/storagefile.go +++ b/ent/storagefile.go @@ -31,8 +31,9 @@ type StorageFile struct { // e.g., image, video FileType string `json:"file_type,omitempty"` // Publicly accessible URL of the file, if available - FileURL string `json:"file_url,omitempty"` - selectValues sql.SelectValues + FileURL string `json:"file_url,omitempty"` + ci_workflow_result_storage_file *uuid.UUID + selectValues sql.SelectValues } // scanValues returns the types for scanning values from sql.Rows. @@ -46,6 +47,8 @@ func (*StorageFile) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullTime) case storagefile.FieldID: values[i] = new(uuid.UUID) + case storagefile.ForeignKeys[0]: // ci_workflow_result_storage_file + values[i] = &sql.NullScanner{S: new(uuid.UUID)} default: values[i] = new(sql.UnknownType) } @@ -109,6 +112,13 @@ func (sf *StorageFile) assignValues(columns []string, values []any) error { } else if value.Valid { sf.FileURL = value.String } + case storagefile.ForeignKeys[0]: + if value, ok := values[i].(*sql.NullScanner); !ok { + return fmt.Errorf("unexpected type %T for field ci_workflow_result_storage_file", values[i]) + } else if value.Valid { + sf.ci_workflow_result_storage_file = new(uuid.UUID) + *sf.ci_workflow_result_storage_file = *value.S.(*uuid.UUID) + } default: sf.selectValues.Set(columns[i], values[i]) } diff --git a/ent/storagefile/storagefile.go b/ent/storagefile/storagefile.go index c73252c..30db172 100644 --- a/ent/storagefile/storagefile.go +++ b/ent/storagefile/storagefile.go @@ -44,6 +44,12 @@ var Columns = []string{ FieldFileURL, } +// ForeignKeys holds the SQL foreign-keys that are owned by the "storage_files" +// table and are not defined as standalone fields in the schema. +var ForeignKeys = []string{ + "ci_workflow_result_storage_file", +} + // ValidColumn reports if the column name is valid (part of the table columns). func ValidColumn(column string) bool { for i := range Columns { @@ -51,6 +57,11 @@ func ValidColumn(column string) bool { return true } } + for i := range ForeignKeys { + if column == ForeignKeys[i] { + return true + } + } return false } diff --git a/ent/storagefile_query.go b/ent/storagefile_query.go index f2fe14e..3d014aa 100644 --- a/ent/storagefile_query.go +++ b/ent/storagefile_query.go @@ -23,6 +23,7 @@ type StorageFileQuery struct { order []storagefile.OrderOption inters []Interceptor predicates []predicate.StorageFile + withFKs bool modifiers []func(*sql.Selector) // intermediate query (i.e. traversal path). sql *sql.Selector @@ -334,9 +335,13 @@ func (sfq *StorageFileQuery) prepareQuery(ctx context.Context) error { func (sfq *StorageFileQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*StorageFile, error) { var ( - nodes = []*StorageFile{} - _spec = sfq.querySpec() + nodes = []*StorageFile{} + withFKs = sfq.withFKs + _spec = sfq.querySpec() ) + if withFKs { + _spec.Node.Columns = append(_spec.Node.Columns, storagefile.ForeignKeys...) + } _spec.ScanValues = func(columns []string) ([]any, error) { return (*StorageFile).scanValues(nil, columns) } @@ -473,6 +478,12 @@ func (sfq *StorageFileQuery) ForShare(opts ...sql.LockOption) *StorageFileQuery return sfq } +// Modify adds a query modifier for attaching custom logic to queries. +func (sfq *StorageFileQuery) Modify(modifiers ...func(s *sql.Selector)) *StorageFileSelect { + sfq.modifiers = append(sfq.modifiers, modifiers...) + return sfq.Select() +} + // StorageFileGroupBy is the group-by builder for StorageFile entities. type StorageFileGroupBy struct { selector @@ -562,3 +573,9 @@ func (sfs *StorageFileSelect) sqlScan(ctx context.Context, root *StorageFileQuer defer rows.Close() return sql.ScanSlice(rows, v) } + +// Modify adds a query modifier for attaching custom logic to queries. +func (sfs *StorageFileSelect) Modify(modifiers ...func(s *sql.Selector)) *StorageFileSelect { + sfs.modifiers = append(sfs.modifiers, modifiers...) + return sfs +} diff --git a/ent/storagefile_update.go b/ent/storagefile_update.go index 26bebda..2ed14c7 100644 --- a/ent/storagefile_update.go +++ b/ent/storagefile_update.go @@ -18,8 +18,9 @@ import ( // StorageFileUpdate is the builder for updating StorageFile entities. type StorageFileUpdate struct { config - hooks []Hook - mutation *StorageFileMutation + hooks []Hook + mutation *StorageFileMutation + modifiers []func(*sql.UpdateBuilder) } // Where appends a list predicates to the StorageFileUpdate builder. @@ -182,6 +183,12 @@ func (sfu *StorageFileUpdate) check() error { return nil } +// Modify adds a statement modifier for attaching custom logic to the UPDATE statement. +func (sfu *StorageFileUpdate) Modify(modifiers ...func(u *sql.UpdateBuilder)) *StorageFileUpdate { + sfu.modifiers = append(sfu.modifiers, modifiers...) + return sfu +} + func (sfu *StorageFileUpdate) sqlSave(ctx context.Context) (n int, err error) { if err := sfu.check(); err != nil { return n, err @@ -218,6 +225,7 @@ func (sfu *StorageFileUpdate) sqlSave(ctx context.Context) (n int, err error) { if sfu.mutation.FileURLCleared() { _spec.ClearField(storagefile.FieldFileURL, field.TypeString) } + _spec.AddModifiers(sfu.modifiers...) if n, err = sqlgraph.UpdateNodes(ctx, sfu.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{storagefile.Label} @@ -233,9 +241,10 @@ func (sfu *StorageFileUpdate) sqlSave(ctx context.Context) (n int, err error) { // StorageFileUpdateOne is the builder for updating a single StorageFile entity. type StorageFileUpdateOne struct { config - fields []string - hooks []Hook - mutation *StorageFileMutation + fields []string + hooks []Hook + mutation *StorageFileMutation + modifiers []func(*sql.UpdateBuilder) } // SetUpdateTime sets the "update_time" field. @@ -405,6 +414,12 @@ func (sfuo *StorageFileUpdateOne) check() error { return nil } +// Modify adds a statement modifier for attaching custom logic to the UPDATE statement. +func (sfuo *StorageFileUpdateOne) Modify(modifiers ...func(u *sql.UpdateBuilder)) *StorageFileUpdateOne { + sfuo.modifiers = append(sfuo.modifiers, modifiers...) + return sfuo +} + func (sfuo *StorageFileUpdateOne) sqlSave(ctx context.Context) (_node *StorageFile, err error) { if err := sfuo.check(); err != nil { return _node, err @@ -458,6 +473,7 @@ func (sfuo *StorageFileUpdateOne) sqlSave(ctx context.Context) (_node *StorageFi if sfuo.mutation.FileURLCleared() { _spec.ClearField(storagefile.FieldFileURL, field.TypeString) } + _spec.AddModifiers(sfuo.modifiers...) _node = &StorageFile{config: sfuo.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/ent/tx.go b/ent/tx.go index 2fae588..cfef450 100644 --- a/ent/tx.go +++ b/ent/tx.go @@ -18,6 +18,8 @@ type Tx struct { GitCommit *GitCommitClient // Node is the client for interacting with the Node builders. Node *NodeClient + // NodeReview is the client for interacting with the NodeReview builders. + NodeReview *NodeReviewClient // NodeVersion is the client for interacting with the NodeVersion builders. NodeVersion *NodeVersionClient // PersonalAccessToken is the client for interacting with the PersonalAccessToken builders. @@ -164,6 +166,7 @@ func (tx *Tx) init() { tx.CIWorkflowResult = NewCIWorkflowResultClient(tx.config) tx.GitCommit = NewGitCommitClient(tx.config) tx.Node = NewNodeClient(tx.config) + tx.NodeReview = NewNodeReviewClient(tx.config) tx.NodeVersion = NewNodeVersionClient(tx.config) tx.PersonalAccessToken = NewPersonalAccessTokenClient(tx.config) tx.Publisher = NewPublisherClient(tx.config) diff --git a/ent/user.go b/ent/user.go index 23b524f..9c67679 100644 --- a/ent/user.go +++ b/ent/user.go @@ -4,6 +4,7 @@ package ent import ( "fmt" + "registry-backend/ent/schema" "registry-backend/ent/user" "strings" "time" @@ -30,6 +31,8 @@ type User struct { IsApproved bool `json:"is_approved,omitempty"` // Whether the user is approved to use the platform IsAdmin bool `json:"is_admin,omitempty"` + // Status holds the value of the "status" field. + Status schema.UserStatusType `json:"status,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the UserQuery when eager-loading is set. Edges UserEdges `json:"edges"` @@ -40,9 +43,11 @@ type User struct { type UserEdges struct { // PublisherPermissions holds the value of the publisher_permissions edge. PublisherPermissions []*PublisherPermission `json:"publisher_permissions,omitempty"` + // Reviews holds the value of the reviews edge. + Reviews []*NodeReview `json:"reviews,omitempty"` // loadedTypes holds the information for reporting if a // type was loaded (or requested) in eager-loading or not. - loadedTypes [1]bool + loadedTypes [2]bool } // PublisherPermissionsOrErr returns the PublisherPermissions value or an error if the edge @@ -54,6 +59,15 @@ func (e UserEdges) PublisherPermissionsOrErr() ([]*PublisherPermission, error) { return nil, &NotLoadedError{edge: "publisher_permissions"} } +// ReviewsOrErr returns the Reviews value or an error if the edge +// was not loaded in eager-loading. +func (e UserEdges) ReviewsOrErr() ([]*NodeReview, error) { + if e.loadedTypes[1] { + return e.Reviews, nil + } + return nil, &NotLoadedError{edge: "reviews"} +} + // scanValues returns the types for scanning values from sql.Rows. func (*User) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) @@ -61,7 +75,7 @@ func (*User) scanValues(columns []string) ([]any, error) { switch columns[i] { case user.FieldIsApproved, user.FieldIsAdmin: values[i] = new(sql.NullBool) - case user.FieldID, user.FieldEmail, user.FieldName: + case user.FieldID, user.FieldEmail, user.FieldName, user.FieldStatus: values[i] = new(sql.NullString) case user.FieldCreateTime, user.FieldUpdateTime: values[i] = new(sql.NullTime) @@ -122,6 +136,12 @@ func (u *User) assignValues(columns []string, values []any) error { } else if value.Valid { u.IsAdmin = value.Bool } + case user.FieldStatus: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field status", values[i]) + } else if value.Valid { + u.Status = schema.UserStatusType(value.String) + } default: u.selectValues.Set(columns[i], values[i]) } @@ -140,6 +160,11 @@ func (u *User) QueryPublisherPermissions() *PublisherPermissionQuery { return NewUserClient(u.config).QueryPublisherPermissions(u) } +// QueryReviews queries the "reviews" edge of the User entity. +func (u *User) QueryReviews() *NodeReviewQuery { + return NewUserClient(u.config).QueryReviews(u) +} + // Update returns a builder for updating this User. // Note that you need to call User.Unwrap() before calling this method if this User // was returned from a transaction, and the transaction was committed or rolled back. @@ -180,6 +205,9 @@ func (u *User) String() string { builder.WriteString(", ") builder.WriteString("is_admin=") builder.WriteString(fmt.Sprintf("%v", u.IsAdmin)) + builder.WriteString(", ") + builder.WriteString("status=") + builder.WriteString(fmt.Sprintf("%v", u.Status)) builder.WriteByte(')') return builder.String() } diff --git a/ent/user/user.go b/ent/user/user.go index 656c8fb..e9e0072 100644 --- a/ent/user/user.go +++ b/ent/user/user.go @@ -3,6 +3,8 @@ package user import ( + "fmt" + "registry-backend/ent/schema" "time" "entgo.io/ent/dialect/sql" @@ -26,8 +28,12 @@ const ( FieldIsApproved = "is_approved" // FieldIsAdmin holds the string denoting the is_admin field in the database. FieldIsAdmin = "is_admin" + // FieldStatus holds the string denoting the status field in the database. + FieldStatus = "status" // EdgePublisherPermissions holds the string denoting the publisher_permissions edge name in mutations. EdgePublisherPermissions = "publisher_permissions" + // EdgeReviews holds the string denoting the reviews edge name in mutations. + EdgeReviews = "reviews" // Table holds the table name of the user in the database. Table = "users" // PublisherPermissionsTable is the table that holds the publisher_permissions relation/edge. @@ -37,6 +43,13 @@ const ( PublisherPermissionsInverseTable = "publisher_permissions" // PublisherPermissionsColumn is the table column denoting the publisher_permissions relation/edge. PublisherPermissionsColumn = "user_id" + // ReviewsTable is the table that holds the reviews relation/edge. + ReviewsTable = "node_reviews" + // ReviewsInverseTable is the table name for the NodeReview entity. + // It exists in this package in order to avoid circular dependency with the "nodereview" package. + ReviewsInverseTable = "node_reviews" + // ReviewsColumn is the table column denoting the reviews relation/edge. + ReviewsColumn = "user_id" ) // Columns holds all SQL columns for user fields. @@ -48,6 +61,7 @@ var Columns = []string{ FieldName, FieldIsApproved, FieldIsAdmin, + FieldStatus, } // ValidColumn reports if the column name is valid (part of the table columns). @@ -73,6 +87,18 @@ var ( DefaultIsAdmin bool ) +const DefaultStatus schema.UserStatusType = "ACTIVE" + +// StatusValidator is a validator for the "status" field enum values. It is called by the builders before save. +func StatusValidator(s schema.UserStatusType) error { + switch s { + case "ACTIVE", "BANNED": + return nil + default: + return fmt.Errorf("user: invalid enum value for status field: %q", s) + } +} + // OrderOption defines the ordering options for the User queries. type OrderOption func(*sql.Selector) @@ -111,6 +137,11 @@ func ByIsAdmin(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldIsAdmin, opts...).ToFunc() } +// ByStatus orders the results by the status field. +func ByStatus(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldStatus, opts...).ToFunc() +} + // ByPublisherPermissionsCount orders the results by publisher_permissions count. func ByPublisherPermissionsCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { @@ -124,6 +155,20 @@ func ByPublisherPermissions(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOpt sqlgraph.OrderByNeighborTerms(s, newPublisherPermissionsStep(), append([]sql.OrderTerm{term}, terms...)...) } } + +// ByReviewsCount orders the results by reviews count. +func ByReviewsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newReviewsStep(), opts...) + } +} + +// ByReviews orders the results by reviews terms. +func ByReviews(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newReviewsStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} func newPublisherPermissionsStep() *sqlgraph.Step { return sqlgraph.NewStep( sqlgraph.From(Table, FieldID), @@ -131,3 +176,10 @@ func newPublisherPermissionsStep() *sqlgraph.Step { sqlgraph.Edge(sqlgraph.O2M, false, PublisherPermissionsTable, PublisherPermissionsColumn), ) } +func newReviewsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(ReviewsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, ReviewsTable, ReviewsColumn), + ) +} diff --git a/ent/user/where.go b/ent/user/where.go index 0471803..b2b3f7b 100644 --- a/ent/user/where.go +++ b/ent/user/where.go @@ -4,6 +4,7 @@ package user import ( "registry-backend/ent/predicate" + "registry-backend/ent/schema" "time" "entgo.io/ent/dialect/sql" @@ -345,6 +346,36 @@ func IsAdminNEQ(v bool) predicate.User { return predicate.User(sql.FieldNEQ(FieldIsAdmin, v)) } +// StatusEQ applies the EQ predicate on the "status" field. +func StatusEQ(v schema.UserStatusType) predicate.User { + vc := v + return predicate.User(sql.FieldEQ(FieldStatus, vc)) +} + +// StatusNEQ applies the NEQ predicate on the "status" field. +func StatusNEQ(v schema.UserStatusType) predicate.User { + vc := v + return predicate.User(sql.FieldNEQ(FieldStatus, vc)) +} + +// StatusIn applies the In predicate on the "status" field. +func StatusIn(vs ...schema.UserStatusType) predicate.User { + v := make([]any, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.User(sql.FieldIn(FieldStatus, v...)) +} + +// StatusNotIn applies the NotIn predicate on the "status" field. +func StatusNotIn(vs ...schema.UserStatusType) predicate.User { + v := make([]any, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.User(sql.FieldNotIn(FieldStatus, v...)) +} + // HasPublisherPermissions applies the HasEdge predicate on the "publisher_permissions" edge. func HasPublisherPermissions() predicate.User { return predicate.User(func(s *sql.Selector) { @@ -368,6 +399,29 @@ func HasPublisherPermissionsWith(preds ...predicate.PublisherPermission) predica }) } +// HasReviews applies the HasEdge predicate on the "reviews" edge. +func HasReviews() predicate.User { + return predicate.User(func(s *sql.Selector) { + step := sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, ReviewsTable, ReviewsColumn), + ) + sqlgraph.HasNeighbors(s, step) + }) +} + +// HasReviewsWith applies the HasEdge predicate on the "reviews" edge with a given conditions (other predicates). +func HasReviewsWith(preds ...predicate.NodeReview) predicate.User { + return predicate.User(func(s *sql.Selector) { + step := newReviewsStep() + sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { + for _, p := range preds { + p(s) + } + }) + }) +} + // And groups predicates with the AND operator between them. func And(predicates ...predicate.User) predicate.User { return predicate.User(sql.AndPredicates(predicates...)) diff --git a/ent/user_create.go b/ent/user_create.go index 7f669c3..cebcc73 100644 --- a/ent/user_create.go +++ b/ent/user_create.go @@ -6,7 +6,9 @@ import ( "context" "errors" "fmt" + "registry-backend/ent/nodereview" "registry-backend/ent/publisherpermission" + "registry-backend/ent/schema" "registry-backend/ent/user" "time" @@ -14,6 +16,7 @@ import ( "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" + "github.com/google/uuid" ) // UserCreate is the builder for creating a User entity. @@ -108,6 +111,20 @@ func (uc *UserCreate) SetNillableIsAdmin(b *bool) *UserCreate { return uc } +// SetStatus sets the "status" field. +func (uc *UserCreate) SetStatus(sst schema.UserStatusType) *UserCreate { + uc.mutation.SetStatus(sst) + return uc +} + +// SetNillableStatus sets the "status" field if the given value is not nil. +func (uc *UserCreate) SetNillableStatus(sst *schema.UserStatusType) *UserCreate { + if sst != nil { + uc.SetStatus(*sst) + } + return uc +} + // SetID sets the "id" field. func (uc *UserCreate) SetID(s string) *UserCreate { uc.mutation.SetID(s) @@ -129,6 +146,21 @@ func (uc *UserCreate) AddPublisherPermissions(p ...*PublisherPermission) *UserCr return uc.AddPublisherPermissionIDs(ids...) } +// AddReviewIDs adds the "reviews" edge to the NodeReview entity by IDs. +func (uc *UserCreate) AddReviewIDs(ids ...uuid.UUID) *UserCreate { + uc.mutation.AddReviewIDs(ids...) + return uc +} + +// AddReviews adds the "reviews" edges to the NodeReview entity. +func (uc *UserCreate) AddReviews(n ...*NodeReview) *UserCreate { + ids := make([]uuid.UUID, len(n)) + for i := range n { + ids[i] = n[i].ID + } + return uc.AddReviewIDs(ids...) +} + // Mutation returns the UserMutation object of the builder. func (uc *UserCreate) Mutation() *UserMutation { return uc.mutation @@ -180,6 +212,10 @@ func (uc *UserCreate) defaults() { v := user.DefaultIsAdmin uc.mutation.SetIsAdmin(v) } + if _, ok := uc.mutation.Status(); !ok { + v := user.DefaultStatus + uc.mutation.SetStatus(v) + } } // check runs all checks and user-defined validators on the builder. @@ -196,6 +232,14 @@ func (uc *UserCreate) check() error { if _, ok := uc.mutation.IsAdmin(); !ok { return &ValidationError{Name: "is_admin", err: errors.New(`ent: missing required field "User.is_admin"`)} } + if _, ok := uc.mutation.Status(); !ok { + return &ValidationError{Name: "status", err: errors.New(`ent: missing required field "User.status"`)} + } + if v, ok := uc.mutation.Status(); ok { + if err := user.StatusValidator(v); err != nil { + return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "User.status": %w`, err)} + } + } return nil } @@ -256,6 +300,10 @@ func (uc *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) { _spec.SetField(user.FieldIsAdmin, field.TypeBool, value) _node.IsAdmin = value } + if value, ok := uc.mutation.Status(); ok { + _spec.SetField(user.FieldStatus, field.TypeEnum, value) + _node.Status = value + } if nodes := uc.mutation.PublisherPermissionsIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -272,6 +320,22 @@ func (uc *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) { } _spec.Edges = append(_spec.Edges, edge) } + if nodes := uc.mutation.ReviewsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.ReviewsTable, + Columns: []string{user.ReviewsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(nodereview.FieldID, field.TypeUUID), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges = append(_spec.Edges, edge) + } return _node, _spec } @@ -396,6 +460,18 @@ func (u *UserUpsert) UpdateIsAdmin() *UserUpsert { return u } +// SetStatus sets the "status" field. +func (u *UserUpsert) SetStatus(v schema.UserStatusType) *UserUpsert { + u.Set(user.FieldStatus, v) + return u +} + +// UpdateStatus sets the "status" field to the value that was provided on create. +func (u *UserUpsert) UpdateStatus() *UserUpsert { + u.SetExcluded(user.FieldStatus) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create except the ID field. // Using this option is equivalent to using: // @@ -531,6 +607,20 @@ func (u *UserUpsertOne) UpdateIsAdmin() *UserUpsertOne { }) } +// SetStatus sets the "status" field. +func (u *UserUpsertOne) SetStatus(v schema.UserStatusType) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetStatus(v) + }) +} + +// UpdateStatus sets the "status" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateStatus() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateStatus() + }) +} + // Exec executes the query. func (u *UserUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -833,6 +923,20 @@ func (u *UserUpsertBulk) UpdateIsAdmin() *UserUpsertBulk { }) } +// SetStatus sets the "status" field. +func (u *UserUpsertBulk) SetStatus(v schema.UserStatusType) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetStatus(v) + }) +} + +// UpdateStatus sets the "status" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdateStatus() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateStatus() + }) +} + // Exec executes the query. func (u *UserUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/ent/user_query.go b/ent/user_query.go index 3c42ebc..a1548b7 100644 --- a/ent/user_query.go +++ b/ent/user_query.go @@ -7,6 +7,7 @@ import ( "database/sql/driver" "fmt" "math" + "registry-backend/ent/nodereview" "registry-backend/ent/predicate" "registry-backend/ent/publisherpermission" "registry-backend/ent/user" @@ -25,6 +26,7 @@ type UserQuery struct { inters []Interceptor predicates []predicate.User withPublisherPermissions *PublisherPermissionQuery + withReviews *NodeReviewQuery modifiers []func(*sql.Selector) // intermediate query (i.e. traversal path). sql *sql.Selector @@ -84,6 +86,28 @@ func (uq *UserQuery) QueryPublisherPermissions() *PublisherPermissionQuery { return query } +// QueryReviews chains the current query on the "reviews" edge. +func (uq *UserQuery) QueryReviews() *NodeReviewQuery { + query := (&NodeReviewClient{config: uq.config}).Query() + query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { + if err := uq.prepareQuery(ctx); err != nil { + return nil, err + } + selector := uq.sqlQuery(ctx) + if err := selector.Err(); err != nil { + return nil, err + } + step := sqlgraph.NewStep( + sqlgraph.From(user.Table, user.FieldID, selector), + sqlgraph.To(nodereview.Table, nodereview.FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, user.ReviewsTable, user.ReviewsColumn), + ) + fromU = sqlgraph.SetNeighbors(uq.driver.Dialect(), step) + return fromU, nil + } + return query +} + // First returns the first User entity from the query. // Returns a *NotFoundError when no User was found. func (uq *UserQuery) First(ctx context.Context) (*User, error) { @@ -277,6 +301,7 @@ func (uq *UserQuery) Clone() *UserQuery { inters: append([]Interceptor{}, uq.inters...), predicates: append([]predicate.User{}, uq.predicates...), withPublisherPermissions: uq.withPublisherPermissions.Clone(), + withReviews: uq.withReviews.Clone(), // clone intermediate query. sql: uq.sql.Clone(), path: uq.path, @@ -294,6 +319,17 @@ func (uq *UserQuery) WithPublisherPermissions(opts ...func(*PublisherPermissionQ return uq } +// WithReviews tells the query-builder to eager-load the nodes that are connected to +// the "reviews" edge. The optional arguments are used to configure the query builder of the edge. +func (uq *UserQuery) WithReviews(opts ...func(*NodeReviewQuery)) *UserQuery { + query := (&NodeReviewClient{config: uq.config}).Query() + for _, opt := range opts { + opt(query) + } + uq.withReviews = query + return uq +} + // GroupBy is used to group vertices by one or more fields/columns. // It is often used with aggregate functions, like: count, max, mean, min, sum. // @@ -372,8 +408,9 @@ func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e var ( nodes = []*User{} _spec = uq.querySpec() - loadedTypes = [1]bool{ + loadedTypes = [2]bool{ uq.withPublisherPermissions != nil, + uq.withReviews != nil, } ) _spec.ScanValues = func(columns []string) ([]any, error) { @@ -406,6 +443,13 @@ func (uq *UserQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*User, e return nil, err } } + if query := uq.withReviews; query != nil { + if err := uq.loadReviews(ctx, query, nodes, + func(n *User) { n.Edges.Reviews = []*NodeReview{} }, + func(n *User, e *NodeReview) { n.Edges.Reviews = append(n.Edges.Reviews, e) }); err != nil { + return nil, err + } + } return nodes, nil } @@ -439,6 +483,36 @@ func (uq *UserQuery) loadPublisherPermissions(ctx context.Context, query *Publis } return nil } +func (uq *UserQuery) loadReviews(ctx context.Context, query *NodeReviewQuery, nodes []*User, init func(*User), assign func(*User, *NodeReview)) error { + fks := make([]driver.Value, 0, len(nodes)) + nodeids := make(map[string]*User) + for i := range nodes { + fks = append(fks, nodes[i].ID) + nodeids[nodes[i].ID] = nodes[i] + if init != nil { + init(nodes[i]) + } + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(nodereview.FieldUserID) + } + query.Where(predicate.NodeReview(func(s *sql.Selector) { + s.Where(sql.InValues(s.C(user.ReviewsColumn), fks...)) + })) + neighbors, err := query.All(ctx) + if err != nil { + return err + } + for _, n := range neighbors { + fk := n.UserID + node, ok := nodeids[fk] + if !ok { + return fmt.Errorf(`unexpected referenced foreign-key "user_id" returned %v for node %v`, fk, n.ID) + } + assign(node, n) + } + return nil +} func (uq *UserQuery) sqlCount(ctx context.Context) (int, error) { _spec := uq.querySpec() @@ -553,6 +627,12 @@ func (uq *UserQuery) ForShare(opts ...sql.LockOption) *UserQuery { return uq } +// Modify adds a query modifier for attaching custom logic to queries. +func (uq *UserQuery) Modify(modifiers ...func(s *sql.Selector)) *UserSelect { + uq.modifiers = append(uq.modifiers, modifiers...) + return uq.Select() +} + // UserGroupBy is the group-by builder for User entities. type UserGroupBy struct { selector @@ -642,3 +722,9 @@ func (us *UserSelect) sqlScan(ctx context.Context, root *UserQuery, v any) error defer rows.Close() return sql.ScanSlice(rows, v) } + +// Modify adds a query modifier for attaching custom logic to queries. +func (us *UserSelect) Modify(modifiers ...func(s *sql.Selector)) *UserSelect { + us.modifiers = append(us.modifiers, modifiers...) + return us +} diff --git a/ent/user_update.go b/ent/user_update.go index 28e0135..66846ba 100644 --- a/ent/user_update.go +++ b/ent/user_update.go @@ -6,21 +6,25 @@ import ( "context" "errors" "fmt" + "registry-backend/ent/nodereview" "registry-backend/ent/predicate" "registry-backend/ent/publisherpermission" + "registry-backend/ent/schema" "registry-backend/ent/user" "time" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" "entgo.io/ent/schema/field" + "github.com/google/uuid" ) // UserUpdate is the builder for updating User entities. type UserUpdate struct { config - hooks []Hook - mutation *UserMutation + hooks []Hook + mutation *UserMutation + modifiers []func(*sql.UpdateBuilder) } // Where appends a list predicates to the UserUpdate builder. @@ -103,6 +107,20 @@ func (uu *UserUpdate) SetNillableIsAdmin(b *bool) *UserUpdate { return uu } +// SetStatus sets the "status" field. +func (uu *UserUpdate) SetStatus(sst schema.UserStatusType) *UserUpdate { + uu.mutation.SetStatus(sst) + return uu +} + +// SetNillableStatus sets the "status" field if the given value is not nil. +func (uu *UserUpdate) SetNillableStatus(sst *schema.UserStatusType) *UserUpdate { + if sst != nil { + uu.SetStatus(*sst) + } + return uu +} + // AddPublisherPermissionIDs adds the "publisher_permissions" edge to the PublisherPermission entity by IDs. func (uu *UserUpdate) AddPublisherPermissionIDs(ids ...int) *UserUpdate { uu.mutation.AddPublisherPermissionIDs(ids...) @@ -118,6 +136,21 @@ func (uu *UserUpdate) AddPublisherPermissions(p ...*PublisherPermission) *UserUp return uu.AddPublisherPermissionIDs(ids...) } +// AddReviewIDs adds the "reviews" edge to the NodeReview entity by IDs. +func (uu *UserUpdate) AddReviewIDs(ids ...uuid.UUID) *UserUpdate { + uu.mutation.AddReviewIDs(ids...) + return uu +} + +// AddReviews adds the "reviews" edges to the NodeReview entity. +func (uu *UserUpdate) AddReviews(n ...*NodeReview) *UserUpdate { + ids := make([]uuid.UUID, len(n)) + for i := range n { + ids[i] = n[i].ID + } + return uu.AddReviewIDs(ids...) +} + // Mutation returns the UserMutation object of the builder. func (uu *UserUpdate) Mutation() *UserMutation { return uu.mutation @@ -144,6 +177,27 @@ func (uu *UserUpdate) RemovePublisherPermissions(p ...*PublisherPermission) *Use return uu.RemovePublisherPermissionIDs(ids...) } +// ClearReviews clears all "reviews" edges to the NodeReview entity. +func (uu *UserUpdate) ClearReviews() *UserUpdate { + uu.mutation.ClearReviews() + return uu +} + +// RemoveReviewIDs removes the "reviews" edge to NodeReview entities by IDs. +func (uu *UserUpdate) RemoveReviewIDs(ids ...uuid.UUID) *UserUpdate { + uu.mutation.RemoveReviewIDs(ids...) + return uu +} + +// RemoveReviews removes "reviews" edges to NodeReview entities. +func (uu *UserUpdate) RemoveReviews(n ...*NodeReview) *UserUpdate { + ids := make([]uuid.UUID, len(n)) + for i := range n { + ids[i] = n[i].ID + } + return uu.RemoveReviewIDs(ids...) +} + // Save executes the query and returns the number of nodes affected by the update operation. func (uu *UserUpdate) Save(ctx context.Context) (int, error) { uu.defaults() @@ -180,7 +234,26 @@ func (uu *UserUpdate) defaults() { } } +// check runs all checks and user-defined validators on the builder. +func (uu *UserUpdate) check() error { + if v, ok := uu.mutation.Status(); ok { + if err := user.StatusValidator(v); err != nil { + return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "User.status": %w`, err)} + } + } + return nil +} + +// Modify adds a statement modifier for attaching custom logic to the UPDATE statement. +func (uu *UserUpdate) Modify(modifiers ...func(u *sql.UpdateBuilder)) *UserUpdate { + uu.modifiers = append(uu.modifiers, modifiers...) + return uu +} + func (uu *UserUpdate) sqlSave(ctx context.Context) (n int, err error) { + if err := uu.check(); err != nil { + return n, err + } _spec := sqlgraph.NewUpdateSpec(user.Table, user.Columns, sqlgraph.NewFieldSpec(user.FieldID, field.TypeString)) if ps := uu.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { @@ -210,6 +283,9 @@ func (uu *UserUpdate) sqlSave(ctx context.Context) (n int, err error) { if value, ok := uu.mutation.IsAdmin(); ok { _spec.SetField(user.FieldIsAdmin, field.TypeBool, value) } + if value, ok := uu.mutation.Status(); ok { + _spec.SetField(user.FieldStatus, field.TypeEnum, value) + } if uu.mutation.PublisherPermissionsCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -255,6 +331,52 @@ func (uu *UserUpdate) sqlSave(ctx context.Context) (n int, err error) { } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if uu.mutation.ReviewsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.ReviewsTable, + Columns: []string{user.ReviewsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(nodereview.FieldID, field.TypeUUID), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := uu.mutation.RemovedReviewsIDs(); len(nodes) > 0 && !uu.mutation.ReviewsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.ReviewsTable, + Columns: []string{user.ReviewsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(nodereview.FieldID, field.TypeUUID), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := uu.mutation.ReviewsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.ReviewsTable, + Columns: []string{user.ReviewsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(nodereview.FieldID, field.TypeUUID), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _spec.AddModifiers(uu.modifiers...) if n, err = sqlgraph.UpdateNodes(ctx, uu.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{user.Label} @@ -270,9 +392,10 @@ func (uu *UserUpdate) sqlSave(ctx context.Context) (n int, err error) { // UserUpdateOne is the builder for updating a single User entity. type UserUpdateOne struct { config - fields []string - hooks []Hook - mutation *UserMutation + fields []string + hooks []Hook + mutation *UserMutation + modifiers []func(*sql.UpdateBuilder) } // SetUpdateTime sets the "update_time" field. @@ -349,6 +472,20 @@ func (uuo *UserUpdateOne) SetNillableIsAdmin(b *bool) *UserUpdateOne { return uuo } +// SetStatus sets the "status" field. +func (uuo *UserUpdateOne) SetStatus(sst schema.UserStatusType) *UserUpdateOne { + uuo.mutation.SetStatus(sst) + return uuo +} + +// SetNillableStatus sets the "status" field if the given value is not nil. +func (uuo *UserUpdateOne) SetNillableStatus(sst *schema.UserStatusType) *UserUpdateOne { + if sst != nil { + uuo.SetStatus(*sst) + } + return uuo +} + // AddPublisherPermissionIDs adds the "publisher_permissions" edge to the PublisherPermission entity by IDs. func (uuo *UserUpdateOne) AddPublisherPermissionIDs(ids ...int) *UserUpdateOne { uuo.mutation.AddPublisherPermissionIDs(ids...) @@ -364,6 +501,21 @@ func (uuo *UserUpdateOne) AddPublisherPermissions(p ...*PublisherPermission) *Us return uuo.AddPublisherPermissionIDs(ids...) } +// AddReviewIDs adds the "reviews" edge to the NodeReview entity by IDs. +func (uuo *UserUpdateOne) AddReviewIDs(ids ...uuid.UUID) *UserUpdateOne { + uuo.mutation.AddReviewIDs(ids...) + return uuo +} + +// AddReviews adds the "reviews" edges to the NodeReview entity. +func (uuo *UserUpdateOne) AddReviews(n ...*NodeReview) *UserUpdateOne { + ids := make([]uuid.UUID, len(n)) + for i := range n { + ids[i] = n[i].ID + } + return uuo.AddReviewIDs(ids...) +} + // Mutation returns the UserMutation object of the builder. func (uuo *UserUpdateOne) Mutation() *UserMutation { return uuo.mutation @@ -390,6 +542,27 @@ func (uuo *UserUpdateOne) RemovePublisherPermissions(p ...*PublisherPermission) return uuo.RemovePublisherPermissionIDs(ids...) } +// ClearReviews clears all "reviews" edges to the NodeReview entity. +func (uuo *UserUpdateOne) ClearReviews() *UserUpdateOne { + uuo.mutation.ClearReviews() + return uuo +} + +// RemoveReviewIDs removes the "reviews" edge to NodeReview entities by IDs. +func (uuo *UserUpdateOne) RemoveReviewIDs(ids ...uuid.UUID) *UserUpdateOne { + uuo.mutation.RemoveReviewIDs(ids...) + return uuo +} + +// RemoveReviews removes "reviews" edges to NodeReview entities. +func (uuo *UserUpdateOne) RemoveReviews(n ...*NodeReview) *UserUpdateOne { + ids := make([]uuid.UUID, len(n)) + for i := range n { + ids[i] = n[i].ID + } + return uuo.RemoveReviewIDs(ids...) +} + // Where appends a list predicates to the UserUpdate builder. func (uuo *UserUpdateOne) Where(ps ...predicate.User) *UserUpdateOne { uuo.mutation.Where(ps...) @@ -439,7 +612,26 @@ func (uuo *UserUpdateOne) defaults() { } } +// check runs all checks and user-defined validators on the builder. +func (uuo *UserUpdateOne) check() error { + if v, ok := uuo.mutation.Status(); ok { + if err := user.StatusValidator(v); err != nil { + return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "User.status": %w`, err)} + } + } + return nil +} + +// Modify adds a statement modifier for attaching custom logic to the UPDATE statement. +func (uuo *UserUpdateOne) Modify(modifiers ...func(u *sql.UpdateBuilder)) *UserUpdateOne { + uuo.modifiers = append(uuo.modifiers, modifiers...) + return uuo +} + func (uuo *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) { + if err := uuo.check(); err != nil { + return _node, err + } _spec := sqlgraph.NewUpdateSpec(user.Table, user.Columns, sqlgraph.NewFieldSpec(user.FieldID, field.TypeString)) id, ok := uuo.mutation.ID() if !ok { @@ -486,6 +678,9 @@ func (uuo *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) if value, ok := uuo.mutation.IsAdmin(); ok { _spec.SetField(user.FieldIsAdmin, field.TypeBool, value) } + if value, ok := uuo.mutation.Status(); ok { + _spec.SetField(user.FieldStatus, field.TypeEnum, value) + } if uuo.mutation.PublisherPermissionsCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -531,6 +726,52 @@ func (uuo *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) } _spec.Edges.Add = append(_spec.Edges.Add, edge) } + if uuo.mutation.ReviewsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.ReviewsTable, + Columns: []string{user.ReviewsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(nodereview.FieldID, field.TypeUUID), + }, + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := uuo.mutation.RemovedReviewsIDs(); len(nodes) > 0 && !uuo.mutation.ReviewsCleared() { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.ReviewsTable, + Columns: []string{user.ReviewsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(nodereview.FieldID, field.TypeUUID), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Clear = append(_spec.Edges.Clear, edge) + } + if nodes := uuo.mutation.ReviewsIDs(); len(nodes) > 0 { + edge := &sqlgraph.EdgeSpec{ + Rel: sqlgraph.O2M, + Inverse: false, + Table: user.ReviewsTable, + Columns: []string{user.ReviewsColumn}, + Bidi: false, + Target: &sqlgraph.EdgeTarget{ + IDSpec: sqlgraph.NewFieldSpec(nodereview.FieldID, field.TypeUUID), + }, + } + for _, k := range nodes { + edge.Target.Nodes = append(edge.Target.Nodes, k) + } + _spec.Edges.Add = append(_spec.Edges.Add, edge) + } + _spec.AddModifiers(uuo.modifiers...) _node = &User{config: uuo.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/gateways/algolia/algolia.go b/gateways/algolia/algolia.go new file mode 100644 index 0000000..0bf307f --- /dev/null +++ b/gateways/algolia/algolia.go @@ -0,0 +1,129 @@ +package algolia + +import ( + "context" + "fmt" + "os" + "registry-backend/ent" + + "github.com/algolia/algoliasearch-client-go/v3/algolia/search" + "github.com/rs/zerolog/log" +) + +// AlgoliaService defines the interface for interacting with Algolia search. +type AlgoliaService interface { + IndexNodes(ctx context.Context, nodes ...*ent.Node) error + SearchNodes(ctx context.Context, query string, opts ...interface{}) ([]*ent.Node, error) + DeleteNode(ctx context.Context, node *ent.Node) error +} + +// Ensure algolia struct implements AlgoliaService interface +var _ AlgoliaService = (*algolia)(nil) + +// algolia struct holds the Algolia client. +type algolia struct { + client *search.Client +} + +// New creates a new Algolia service with the provided app ID and API key. +func New(appid, apikey string) (AlgoliaService, error) { + return &algolia{ + client: search.NewClient(appid, apikey), + }, nil +} + +// NewFromEnv creates a new Algolia service using environment variables for configuration. +func NewFromEnv() (AlgoliaService, error) { + appid, ok := os.LookupEnv("ALGOLIA_APP_ID") + if !ok { + return nil, fmt.Errorf("required env variable ALGOLIA_APP_ID is not set") + } + apikey, ok := os.LookupEnv("ALGOLIA_API_KEY") + if !ok { + return nil, fmt.Errorf("required env variable ALGOLIA_API_KEY is not set") + } + return New(appid, apikey) +} + +// IndexNodes indexes the provided nodes in Algolia. +func (a *algolia) IndexNodes(ctx context.Context, nodes ...*ent.Node) error { + index := a.client.InitIndex("nodes_index") + objects := make([]struct { + ObjectID string `json:"objectID"` + *ent.Node + }, len(nodes)) + + for i, n := range nodes { + objects[i] = struct { + ObjectID string `json:"objectID"` + *ent.Node + }{ + ObjectID: n.ID, + Node: n, + } + } + + res, err := index.SaveObjects(objects) + if err != nil { + return fmt.Errorf("failed to index nodes: %w", err) + } + + return res.Wait() +} + +// SearchNodes searches for nodes in Algolia matching the query. +func (a *algolia) SearchNodes(ctx context.Context, query string, opts ...interface{}) ([]*ent.Node, error) { + index := a.client.InitIndex("nodes_index") + res, err := index.Search(query, opts...) + if err != nil { + return nil, fmt.Errorf("failed to search nodes: %w", err) + } + + var nodes []*ent.Node + if err := res.UnmarshalHits(&nodes); err != nil { + return nil, fmt.Errorf("failed to unmarshal search results: %w", err) + } + return nodes, nil +} + +// DeleteNode deletes the specified node from Algolia. +func (a *algolia) DeleteNode(ctx context.Context, node *ent.Node) error { + index := a.client.InitIndex("nodes_index") + res, err := index.DeleteObject(node.ID) + if err != nil { + return fmt.Errorf("failed to delete node: %w", err) + } + return res.Wait() +} + +var _ AlgoliaService = (*algolianoop)(nil) + +type algolianoop struct{} + +func NewFromEnvOrNoop() (AlgoliaService, error) { + id := os.Getenv("ALGOLIA_APP_ID") + key := os.Getenv("ALGOLIA_API_KEY") + if id == "" && key == "" { + return &algolianoop{}, nil + } + + return NewFromEnv() +} + +// DeleteNode implements AlgoliaService. +func (a *algolianoop) DeleteNode(ctx context.Context, node *ent.Node) error { + log.Ctx(ctx).Info().Msgf("algolia noop: delete node: %s", node.ID) + return nil +} + +// IndexNodes implements AlgoliaService. +func (a *algolianoop) IndexNodes(ctx context.Context, nodes ...*ent.Node) error { + log.Ctx(ctx).Info().Msgf("algolia noop: index nodes: %d number of nodes", len(nodes)) + return nil +} + +// SearchNodes implements AlgoliaService. +func (a *algolianoop) SearchNodes(ctx context.Context, query string, opts ...interface{}) ([]*ent.Node, error) { + log.Ctx(ctx).Info().Msgf("algolia noop: search nodes: %s", query) + return nil, nil +} diff --git a/gateways/algolia/algolia_test.go b/gateways/algolia/algolia_test.go new file mode 100644 index 0000000..577639e --- /dev/null +++ b/gateways/algolia/algolia_test.go @@ -0,0 +1,55 @@ +package algolia + +import ( + "context" + "os" + "registry-backend/ent" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestIndex(t *testing.T) { + _, ok := os.LookupEnv("ALGOLIA_APP_ID") + if !ok { + t.Skip("Required env variables `ALGOLIA_APP_ID` is not set") + } + _, ok = os.LookupEnv("ALGOLIA_API_KEY") + if !ok { + t.Skip("Required env variables `ALGOLIA_API_KEY` is not set") + } + + algolia, err := NewFromEnv() + require.NoError(t, err) + + ctx := context.Background() + node := &ent.Node{ + ID: uuid.NewString(), + Name: t.Name() + "-" + uuid.NewString(), + TotalStar: 98, + TotalReview: 20, + } + for i := 0; i < 10; i++ { + err = algolia.IndexNodes(ctx, node) + require.NoError(t, err) + } + + <-time.After(time.Second * 10) + nodes, err := algolia.SearchNodes(ctx, node.Name) + require.NoError(t, err) + require.Len(t, nodes, 1) + assert.Equal(t, node, nodes[0]) + +} + +func TestNoop(t *testing.T) { + t.Setenv("ALGOLIA_APP_ID", "") + t.Setenv("ALGOLIA_API_KEY", "") + a, err := NewFromEnvOrNoop() + require.NoError(t, err) + require.NoError(t, a.IndexNodes(context.Background(), &ent.Node{})) + require.NoError(t, a.DeleteNode(context.Background(), &ent.Node{})) +} diff --git a/gateways/discord/discord.go b/gateways/discord/discord.go new file mode 100644 index 0000000..e1779cd --- /dev/null +++ b/gateways/discord/discord.go @@ -0,0 +1,71 @@ +package discord + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "registry-backend/config" +) + +type DiscordService interface { + SendSecurityCouncilMessage(msg string) error +} + +type DripDiscordService struct { + securityDiscordChannelWebhook string + config *config.Config +} + +func NewDiscordService(config *config.Config) *DripDiscordService { + return &DripDiscordService{ + config: config, + securityDiscordChannelWebhook: config.DiscordSecurityChannelWebhook, + } +} + +type discordRequestBody struct { + Content string `json:"content"` +} + +func (s *DripDiscordService) SendSecurityCouncilMessage(msg string) error { + if s.config.DripEnv == "prod" { + return sendDiscordNotification(msg, s.securityDiscordChannelWebhook) + } else { + println("Skipping sending message to Discord in non-prod environment. " + msg) + } + return nil +} + +func sendDiscordNotification(msg string, discordWebhookURL string) error { + if discordWebhookURL == "" { + return fmt.Errorf("no Discord webhook URL provided, skipping sending message to Discord") + } + + body, err := json.Marshal(discordRequestBody{Content: msg}) + if err != nil { + return err + } + + req, err := http.NewRequest(http.MethodPost, discordWebhookURL, bytes.NewBuffer(body)) + if err != nil { + return err + } + + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return err + } + + defer resp.Body.Close() + + if resp.StatusCode >= 400 { + // You can handle or log the HTTP error status code here + return fmt.Errorf("request to Discord returned error status: %d", resp.StatusCode) + } + + return nil +} diff --git a/gateways/slack/slack.go b/gateways/slack/slack.go index 3bdcd16..a54cb39 100644 --- a/gateways/slack/slack.go +++ b/gateways/slack/slack.go @@ -5,19 +5,23 @@ import ( "encoding/json" "fmt" "net/http" + "registry-backend/config" ) -const registrySlackChannelWebhook = "https://hooks.slack.com/services/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" - type SlackService interface { SendRegistryMessageToSlack(msg string) error } type DripSlackService struct { + registrySlackChannelWebhook string + config *config.Config } -func NewSlackService() *DripSlackService { - return &DripSlackService{} +func NewSlackService(config *config.Config) *DripSlackService { + return &DripSlackService{ + config: config, + registrySlackChannelWebhook: config.SlackRegistryChannelWebhook, + } } @@ -26,10 +30,18 @@ type slackRequestBody struct { } func (s *DripSlackService) SendRegistryMessageToSlack(msg string) error { - return sendSlackNotification(msg, registrySlackChannelWebhook) + if s.config.DripEnv == "prod" { + return sendSlackNotification(msg, s.registrySlackChannelWebhook) + } + return nil } func sendSlackNotification(msg string, slackWebhookURL string) error { + if slackWebhookURL == "" { + println("No Slack webhook URL provided, skipping sending message to Slack") + return nil + } + body, err := json.Marshal(slackRequestBody{Text: msg}) if err != nil { return err diff --git a/go.mod b/go.mod index 2578032..8f665b5 100644 --- a/go.mod +++ b/go.mod @@ -8,15 +8,17 @@ require ( entgo.io/ent v0.13.1 firebase.google.com/go v3.13.0+incompatible github.com/Masterminds/semver/v3 v3.2.1 + github.com/algolia/algoliasearch-client-go/v3 v3.31.1 github.com/deepmap/oapi-codegen/v2 v2.1.0 github.com/getkin/kin-openapi v0.123.0 + github.com/golang-jwt/jwt/v5 v5.2.1 github.com/google/uuid v1.6.0 github.com/labstack/echo/v4 v4.11.4 github.com/lib/pq v1.10.9 github.com/mixpanel/mixpanel-go v1.2.1 github.com/oapi-codegen/runtime v1.1.1 github.com/rs/zerolog v1.32.0 - github.com/stretchr/testify v1.8.4 + github.com/stretchr/testify v1.9.0 github.com/testcontainers/testcontainers-go v0.28.0 github.com/testcontainers/testcontainers-go/modules/postgres v0.28.0 google.golang.org/api v0.165.0 @@ -24,6 +26,8 @@ require ( google.golang.org/protobuf v1.32.0 ) +require github.com/golang-jwt/jwt v3.2.2+incompatible // indirect + require ( ariga.io/atlas v0.19.1-0.20240203083654-5948b60a8e43 // indirect cloud.google.com/go v0.112.0 // indirect @@ -56,7 +60,6 @@ require ( github.com/go-openapi/jsonpointer v0.20.2 // indirect github.com/go-openapi/swag v0.22.9 // indirect github.com/gogo/protobuf v1.3.2 // indirect - github.com/golang-jwt/jwt v3.2.2+incompatible // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.3 // indirect github.com/google/go-cmp v0.6.0 // indirect @@ -90,7 +93,7 @@ require ( github.com/shirou/gopsutil/v3 v3.23.12 // indirect github.com/shoenig/go-m1cpu v0.1.6 // indirect github.com/sirupsen/logrus v1.9.3 // indirect - github.com/stretchr/objx v0.5.0 // indirect + github.com/stretchr/objx v0.5.2 // indirect github.com/tklauser/go-sysconf v0.3.12 // indirect github.com/tklauser/numcpus v0.6.1 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect @@ -103,16 +106,16 @@ require ( go.opentelemetry.io/otel v1.23.1 // indirect go.opentelemetry.io/otel/metric v1.23.1 // indirect go.opentelemetry.io/otel/trace v1.23.1 // indirect - golang.org/x/crypto v0.19.0 // indirect - golang.org/x/exp v0.0.0-20230510235704-dd950f8aeaea // indirect - golang.org/x/mod v0.15.0 // indirect - golang.org/x/net v0.21.0 // indirect + golang.org/x/crypto v0.22.0 // indirect + golang.org/x/exp v0.0.0-20240416160154-fe59bbe5cc7f // indirect + golang.org/x/mod v0.17.0 // indirect + golang.org/x/net v0.24.0 // indirect golang.org/x/oauth2 v0.17.0 // indirect - golang.org/x/sync v0.6.0 // indirect - golang.org/x/sys v0.17.0 // indirect + golang.org/x/sync v0.7.0 // indirect + golang.org/x/sys v0.19.0 // indirect golang.org/x/text v0.14.0 // indirect golang.org/x/time v0.5.0 // indirect - golang.org/x/tools v0.18.0 // indirect + golang.org/x/tools v0.20.0 // indirect google.golang.org/appengine v1.6.8 // indirect google.golang.org/genproto v0.0.0-20240213162025-012b6fc9bca9 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240213162025-012b6fc9bca9 // indirect diff --git a/go.sum b/go.sum index 1862205..77a5a35 100644 --- a/go.sum +++ b/go.sum @@ -39,6 +39,8 @@ github.com/Microsoft/hcsshim v0.11.4/go.mod h1:smjE4dvqPX9Zldna+t5FG3rnoHhaB7QYx github.com/RaveNoX/go-jsoncommentstrip v1.0.0/go.mod h1:78ihd09MekBnJnxpICcwzCMzGrKSKYe4AqU6PDYYpjk= github.com/agext/levenshtein v1.2.3 h1:YB2fHEn0UJagG8T1rrWknE3ZQzWM06O8AMAatNn7lmo= github.com/agext/levenshtein v1.2.3/go.mod h1:JEDfjyjHDjOF/1e4FlBE/PkbqA9OfWu2ki2W0IB5558= +github.com/algolia/algoliasearch-client-go/v3 v3.31.1 h1:xXA/RK4/EuXyUCgAXUB7Ala9T7sGMeNqlU2SIy7V/qY= +github.com/algolia/algoliasearch-client-go/v3 v3.31.1/go.mod h1:i7tLoP7TYDmHX3Q7vkIOL4syVse/k5VJ+k0i8WqFiJk= github.com/apapsch/go-jsonmerge/v2 v2.0.0 h1:axGnT1gRIfimI7gJifB699GoE/oq+F2MU7Dml6nw9rQ= github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP2+08jFMw88y4klk= github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY= @@ -103,6 +105,8 @@ github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= @@ -250,16 +254,19 @@ github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An github.com/spkg/bom v0.0.0-20160624110644-59b7046e48ad/go.mod h1:qLr4V1qq6nMqFKkMo8ZTx3f+BZEkzsRUY10Xsm2mwU0= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/testcontainers/testcontainers-go v0.28.0 h1:1HLm9qm+J5VikzFDYhOd+Zw12NtOl+8drH2E8nTY1r8= github.com/testcontainers/testcontainers-go v0.28.0/go.mod h1:COlDpUXbwW3owtpMkEB1zo9gwb1CoKVKlyrVPejF4AU= github.com/testcontainers/testcontainers-go/modules/postgres v0.28.0 h1:ff0s4JdYIdNAVSi/SrpN2Pdt1f+IjIw3AKjbHau8Un4= @@ -305,11 +312,11 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo= -golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= +golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30= +golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/exp v0.0.0-20230510235704-dd950f8aeaea h1:vLCWI/yYrdEHyN2JzIzPO3aaQJHQdp89IZBA/+azVC4= -golang.org/x/exp v0.0.0-20230510235704-dd950f8aeaea/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= +golang.org/x/exp v0.0.0-20240416160154-fe59bbe5cc7f h1:99ci1mjWVBWwJiEKYY6jWa4d2nTQVIEhZIptnrVb1XY= +golang.org/x/exp v0.0.0-20240416160154-fe59bbe5cc7f/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= @@ -318,8 +325,8 @@ golang.org/x/lint v0.0.0-20210508222113-6edffad5e616/go.mod h1:3xt1FjdF8hUf6vQPI golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/mod v0.15.0 h1:SernR4v+D55NyBH2QiEQrlBAnj1ECL6AGrA5+dPaMY8= -golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= +golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -331,8 +338,8 @@ golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwY golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= -golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= +golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w= +golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.17.0 h1:6m3ZPmLEFdVxKKWnKq4VqZ60gutO35zm+zrAHVmHyDQ= golang.org/x/oauth2 v0.17.0/go.mod h1:OzPDGQiuQMguemayvdylqddI7qcD9lnSDb+1FiwQ5HA= @@ -342,8 +349,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= -golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -362,8 +369,8 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= -golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o= +golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -383,8 +390,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/tools v0.18.0 h1:k8NLag8AGHnn+PHbl7g43CtqZAwG60vZkLqgyZgIHgQ= -golang.org/x/tools v0.18.0/go.mod h1:GL7B4CwcLLeo59yx/9UWWuNOW1n3VZ4f5axWfML7Lcg= +golang.org/x/tools v0.20.0 h1:hz/CVckiOxybQvFw6h7b/q80NTr9IUQb4s1IIzW7KNY= +golang.org/x/tools v0.20.0/go.mod h1:WvitBU7JJf6A4jOdg4S1tviW9bhUxkgeCui/0JHctQg= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -432,6 +439,7 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntN gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/integration-tests/ban_test.go b/integration-tests/ban_test.go new file mode 100644 index 0000000..7a0b21d --- /dev/null +++ b/integration-tests/ban_test.go @@ -0,0 +1,304 @@ +package integration + +import ( + "context" + "net/http" + "registry-backend/config" + "registry-backend/drip" + "registry-backend/ent/schema" + "registry-backend/mock/gateways" + "registry-backend/server/implementation" + drip_authorization "registry-backend/server/middleware/authorization" + "testing" + + "google.golang.org/protobuf/proto" + + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func TestBan(t *testing.T) { + clientCtx := context.Background() + client, cleanup := setupDB(t, clientCtx) + defer cleanup() + + // Initialize the Service + mockStorageService := new(gateways.MockStorageService) + mockSlackService := new(gateways.MockSlackService) + mockDiscordService := new(gateways.MockDiscordService) + mockSlackService. + On("SendRegistryMessageToSlack", mock.Anything). + Return(nil) // Do nothing for all slack messsage calls. + mockAlgolia := new(gateways.MockAlgoliaService) + mockAlgolia. + On("IndexNodes", mock.Anything, mock.Anything). + Return(nil) + + impl := implementation.NewStrictServerImplementation( + client, &config.Config{}, mockStorageService, mockSlackService, mockDiscordService, mockAlgolia) + + authz := drip_authorization.NewAuthorizationManager(client, impl.RegistryService).AuthorizationMiddleware() + + t.Run("Publisher", func(t *testing.T) { + t.Run("Ban", func(t *testing.T) { + ctx, user := setUpTest(client) + + publisherId := "test-publisher" + description := "test-description" + source_code_repo := "test-source-code-repo" + website := "test-website" + support := "test-support" + logo := "test-logo" + name := "test-name" + _, err := withMiddleware(authz, impl.CreatePublisher)(ctx, drip.CreatePublisherRequestObject{ + Body: &drip.Publisher{ + Id: &publisherId, + Description: &description, + SourceCodeRepo: &source_code_repo, + Website: &website, + Support: &support, + Logo: &logo, + Name: &name, + }, + }) + require.NoError(t, err, "should return created publisher") + + nodeId := "test-node" + nodeDescription := "test-node-description" + nodeAuthor := "test-node-author" + nodeLicense := "test-node-license" + nodeName := "test-node-name" + nodeTags := []string{"test-node-tag"} + icon := "https://wwww.github.com/test-icon.svg" + githubUrl := "https://www.github.com/test-github-url" + _, err = withMiddleware(authz, impl.CreateNode)(ctx, drip.CreateNodeRequestObject{ + PublisherId: publisherId, + Body: &drip.Node{ + Id: &nodeId, + Name: &nodeName, + Description: &nodeDescription, + Author: &nodeAuthor, + License: &nodeLicense, + Tags: &nodeTags, + Icon: &icon, + Repository: &githubUrl, + }, + }) + require.NoError(t, err, "should return created node") + + t.Run("By Non Admin", func(t *testing.T) { + ctx, _ := setUpTest(client) + res, err := withMiddleware(authz, impl.BanPublisher)(ctx, drip.BanPublisherRequestObject{PublisherId: publisherId}) + require.NoError(t, err, "should not ban publisher") + require.IsType(t, drip.BanPublisher403JSONResponse{}, res) + }) + + t.Run("By Admin", func(t *testing.T) { + ctx, admin := setUpTest(client) + err = admin.Update().SetIsAdmin(true).Exec(clientCtx) + require.NoError(t, err) + _, err = withMiddleware(authz, impl.BanPublisher)(ctx, drip.BanPublisherRequestObject{PublisherId: publisherId}) + require.NoError(t, err) + + pub, err := client.Publisher.Get(ctx, publisherId) + require.NoError(t, err) + assert.Equal(t, schema.PublisherStatusTypeBanned, pub.Status, "should ban publisher") + user, err := client.User.Get(ctx, user.ID) + require.NoError(t, err) + assert.Equal(t, schema.UserStatusTypeBanned, user.Status, "should ban user") + node, err := client.Node.Get(ctx, nodeId) + require.NoError(t, err) + assert.Equal(t, schema.NodeStatusBanned, node.Status, "should ban node") + }) + }) + + t.Run("Access", func(t *testing.T) { + testtable := []struct { + name string + invoke func(ctx context.Context) error + }{ + { + name: "CreatePublisher", + invoke: func(ctx context.Context) error { + _, err := withMiddleware(authz, impl.CreatePublisher)(ctx, drip.CreatePublisherRequestObject{Body: &drip.Publisher{}}) + return err + }, + }, + { + name: "DeleteNodeVersion", + invoke: func(ctx context.Context) error { + _, err := withMiddleware(authz, impl.DeleteNodeVersion)(ctx, drip.DeleteNodeVersionRequestObject{}) + return err + }, + }, + } + + t.Run("Banned", func(t *testing.T) { + ctxBanned, testUserBanned := setUpTest(client) + err := testUserBanned.Update().SetStatus(schema.UserStatusTypeBanned).Exec(ctxBanned) + require.NoError(t, err) + for _, tt := range testtable { + t.Run(tt.name, func(t *testing.T) { + err = tt.invoke(ctxBanned) + require.Error(t, err, "should return error") + require.IsType(t, &echo.HTTPError{}, err, "should return echo http error") + echoErr := err.(*echo.HTTPError) + assert.Equal(t, http.StatusForbidden, echoErr.Code, "should return 403") + }) + } + }) + + t.Run("Not Banned", func(t *testing.T) { + ctx, _ := setUpTest(client) + for _, tt := range testtable { + t.Run(tt.name, func(t *testing.T) { + err := tt.invoke(ctx) + _, ok := err.(*echo.HTTPError) + assert.False(t, ok, err, "should pass the authorization middleware") + }) + } + }) + }) + }) + + t.Run("Node", func(t *testing.T) { + ctx, _ := setUpTest(client) + + publisherId := "test-publisher-1" + description := "test-description" + source_code_repo := "test-source-code-repo" + website := "test-website" + support := "test-support" + logo := "test-logo" + name := "test-name" + _, err := withMiddleware(authz, impl.CreatePublisher)(ctx, drip.CreatePublisherRequestObject{ + Body: &drip.Publisher{ + Id: &publisherId, + Description: &description, + SourceCodeRepo: &source_code_repo, + Website: &website, + Support: &support, + Logo: &logo, + Name: &name, + }, + }) + require.NoError(t, err, "should return created publisher") + + nodeId := "test-node-1" + nodeDescription := "test-node-description" + nodeAuthor := "test-node-author" + nodeLicense := "test-node-license" + nodeName := "test-node-name" + nodeTags := []string{"test-node-tag"} + icon := "https://wwww.github.com/test-icon.svg" + githubUrl := "https://www.github.com/test-github-url" + _, err = withMiddleware(authz, impl.CreateNode)(ctx, drip.CreateNodeRequestObject{ + PublisherId: publisherId, + Body: &drip.Node{ + Id: &nodeId, + Name: &nodeName, + Description: &nodeDescription, + Author: &nodeAuthor, + License: &nodeLicense, + Tags: &nodeTags, + Icon: &icon, + Repository: &githubUrl, + }, + }) + require.NoError(t, err, "should return created node") + + tokenName := "name" + tokenDescription := "name" + res, err := withMiddleware(authz, impl.CreatePersonalAccessToken)(ctx, drip.CreatePersonalAccessTokenRequestObject{ + PublisherId: publisherId, + Body: &drip.PersonalAccessToken{ + Name: &tokenName, + Description: &tokenDescription, + }, + }) + require.NoError(t, err, "should return created token") + require.IsType(t, drip.CreatePersonalAccessToken201JSONResponse{}, res) + pat := res.(drip.CreatePersonalAccessToken201JSONResponse).Token + + t.Run("Ban", func(t *testing.T) { + t.Run("By Non Admin", func(t *testing.T) { + ctx, _ := setUpTest(client) + res, err := withMiddleware(authz, impl.BanPublisherNode)(ctx, drip.BanPublisherNodeRequestObject{PublisherId: publisherId, NodeId: nodeId}) + require.NoError(t, err, "should not ban publisher node") + require.IsType(t, drip.BanPublisherNode403JSONResponse{}, res) + }) + + t.Run("By Admin", func(t *testing.T) { + ctx, admin := setUpTest(client) + err = admin.Update().SetIsAdmin(true).Exec(clientCtx) + require.NoError(t, err) + _, err = withMiddleware(authz, impl.BanPublisherNode)(ctx, drip.BanPublisherNodeRequestObject{PublisherId: publisherId, NodeId: nodeId}) + require.NoError(t, err) + + node, err := client.Node.Get(ctx, nodeId) + require.NoError(t, err) + assert.Equal(t, schema.NodeStatusBanned, node.Status, "should ban node") + }) + }) + + t.Run("Operate", func(t *testing.T) { + t.Run("Get", func(t *testing.T) { + f := withMiddleware(authz, impl.GetNode) + _, err := f(ctx, drip.GetNodeRequestObject{NodeId: nodeId}) + require.Error(t, err) + require.IsType(t, &echo.HTTPError{}, err) + assert.Equal(t, err.(*echo.HTTPError).Code, http.StatusForbidden) + }) + t.Run("Update", func(t *testing.T) { + f := withMiddleware(authz, impl.UpdateNode) + _, err := f(ctx, drip.UpdateNodeRequestObject{PublisherId: publisherId, NodeId: nodeId}) + require.Error(t, err) + require.IsType(t, &echo.HTTPError{}, err) + assert.Equal(t, err.(*echo.HTTPError).Code, http.StatusForbidden) + }) + t.Run("ListNodeVersion", func(t *testing.T) { + f := withMiddleware(authz, impl.ListNodeVersions) + _, err := f(ctx, drip.ListNodeVersionsRequestObject{NodeId: nodeId}) + require.Error(t, err) + require.IsType(t, &echo.HTTPError{}, err) + assert.Equal(t, err.(*echo.HTTPError).Code, http.StatusForbidden) + }) + t.Run("PublishNodeVersion", func(t *testing.T) { + f := withMiddleware(authz, impl.PublishNodeVersion) + _, err := f(ctx, drip.PublishNodeVersionRequestObject{ + PublisherId: publisherId, NodeId: nodeId, + Body: &drip.PublishNodeVersionJSONRequestBody{PersonalAccessToken: *pat}, + }) + require.Error(t, err) + require.IsType(t, &echo.HTTPError{}, err) + assert.Equal(t, err.(*echo.HTTPError).Code, http.StatusForbidden) + }) + t.Run("InstallNode", func(t *testing.T) { + f := withMiddleware(authz, impl.InstallNode) + _, err := f(ctx, drip.InstallNodeRequestObject{NodeId: nodeId}) + require.Error(t, err) + require.IsType(t, &echo.HTTPError{}, err) + assert.Equal(t, err.(*echo.HTTPError).Code, http.StatusForbidden) + }) + t.Run("SearchNodes", func(t *testing.T) { + f := withMiddleware(authz, impl.SearchNodes) + res, err := f(ctx, drip.SearchNodesRequestObject{ + Params: drip.SearchNodesParams{}, + }) + require.NoError(t, err) + require.IsType(t, drip.SearchNodes200JSONResponse{}, res) + require.Empty(t, res.(drip.SearchNodes200JSONResponse).Nodes) + + res, err = f(ctx, drip.SearchNodesRequestObject{ + Params: drip.SearchNodesParams{IncludeBanned: proto.Bool(true)}, + }) + require.NoError(t, err) + require.IsType(t, drip.SearchNodes200JSONResponse{}, res) + require.NotEmpty(t, res.(drip.SearchNodes200JSONResponse).Nodes) + }) + }) + }) + +} diff --git a/integration-tests/ci_cd_integration_test.go b/integration-tests/ci_cd_integration_test.go new file mode 100644 index 0000000..24fce1e --- /dev/null +++ b/integration-tests/ci_cd_integration_test.go @@ -0,0 +1,174 @@ +package integration + +import ( + "context" + "fmt" + "registry-backend/config" + "registry-backend/drip" + "registry-backend/ent/gitcommit" + "registry-backend/mock/gateways" + "registry-backend/server/implementation" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" +) + +func TestCICD(t *testing.T) { + clientCtx := context.Background() + client, cleanup := setupDB(t, clientCtx) + defer cleanup() + + // Initialize the Service + mockStorageService := new(gateways.MockStorageService) + mockSlackService := new(gateways.MockSlackService) + mockDiscordService := new(gateways.MockDiscordService) + mockSlackService. + On("SendRegistryMessageToSlack", mock.Anything). + Return(nil) // Do nothing for all slack messsage calls. + mockAlgolia := new(gateways.MockAlgoliaService) + mockAlgolia. + On("IndexNodes", mock.Anything, mock.Anything). + Return(nil) + impl := implementation.NewStrictServerImplementation( + client, &config.Config{}, mockStorageService, mockSlackService, mockDiscordService, mockAlgolia) + + ctx := context.Background() + now := time.Now() + anHourAgo := now.Add(-1 * time.Hour) + avgVram := 2132 + + body := &drip.PostUploadArtifactJSONRequestBody{ + Repo: "github.com/comfy/service", + BranchName: "develop", + CommitHash: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + CommitMessage: "new commit", + CommitTime: anHourAgo.Format(time.RFC3339), + JobId: "018fbe20-88a6-7d31-a194-eee8e2509da3", + RunId: "018fbe37-a7a8-74a3-a377-8d70d54f54d8", + Os: "linux", + WorkflowName: "devops", + CudaVersion: proto.String("1.0.0"), + BucketName: proto.String("comfy-dev-bucket"), + OutputFilesGcsPaths: proto.String("comfy-dev-file"), + ComfyLogsGcsPath: proto.String("comfy-dev-log"), + StartTime: anHourAgo.Unix(), + EndTime: now.Unix(), + PrNumber: "123", + PythonVersion: "3.8", + PytorchVersion: proto.String("1.0.0"), + JobTriggerUser: "comfy", + Author: "robin", + AvgVram: &avgVram, + ComfyRunFlags: proto.String("comfy"), + Status: drip.WorkflowRunStatusStarted, + MachineStats: &drip.MachineStats{ + CpuCapacity: proto.String("2.0"), + InitialCpu: proto.String("1.0"), + InitialDisk: proto.String("1.0"), + DiskCapacity: proto.String("2.0"), + InitialRam: proto.String("1.0"), + MemoryCapacity: proto.String("2.0"), + OsVersion: proto.String("Ubuntu 24.10"), + PipFreeze: proto.String("requests==1.0.0"), + MachineName: proto.String("comfy-dev"), + GpuType: proto.String("NVIDIA Tesla V100"), + }, + } + + t.Run("Post Upload Artifact", func(t *testing.T) { + body := *body + body.JobId = "018fbe4a-2844-7c2e-87f1-311605292452" + body.RunId = "018fbe4a-5b1c-7a51-8e26-53e77961ee06" + res, err := impl.PostUploadArtifact(ctx, drip.PostUploadArtifactRequestObject{Body: &body}) + require.NoError(t, err, "should not return error") + require.IsType(t, drip.PostUploadArtifact200JSONResponse{}, res, "should return 200") + }) + + t.Run("Re Post Upload Artifact", func(t *testing.T) { + res, err := impl.PostUploadArtifact(ctx, drip.PostUploadArtifactRequestObject{Body: body}) + require.NoError(t, err, "should not return error") + require.IsType(t, drip.PostUploadArtifact200JSONResponse{}, res, "should return 200") + }) + + t.Run("Get Git Commit", func(t *testing.T) { + expectedAvgVram := 2132 + expectedPeakVram := 0 + git, err := client.GitCommit.Query().Where(gitcommit.CommitHashEQ(body.CommitHash)).First(ctx) + require.NoError(t, err) + + res, err := impl.GetGitcommit(ctx, drip.GetGitcommitRequestObject{Params: drip.GetGitcommitParams{ + CommitId: proto.String(git.ID.String()), + OperatingSystem: &body.Os, + WorkflowName: &body.WorkflowName, + Branch: &body.BranchName, + RepoName: &body.Repo, + }}) + require.NoError(t, err, "should not return error") + require.IsType(t, drip.GetGitcommit200JSONResponse{}, res) + res200 := res.(drip.GetGitcommit200JSONResponse) + require.Len(t, *res200.JobResults, 1) + assert.Equal(t, *res200.TotalNumberOfPages, 1) + assert.Equal(t, drip.ActionJobResult{ + Id: (*res200.JobResults)[0].Id, + ActionRunId: &body.RunId, + ActionJobId: &body.JobId, + CommitHash: &body.CommitHash, + CommitId: proto.String(git.ID.String()), + CommitMessage: &body.CommitMessage, + CommitTime: proto.Int64(anHourAgo.Unix()), + EndTime: proto.Int64(now.Unix()), + GitRepo: &body.Repo, + OperatingSystem: &body.Os, + StartTime: proto.Int64(anHourAgo.Unix()), + WorkflowName: &body.WorkflowName, + JobTriggerUser: &body.JobTriggerUser, + AvgVram: &expectedAvgVram, + PeakVram: &expectedPeakVram, + PythonVersion: &body.PythonVersion, + Status: &body.Status, + PrNumber: &body.PrNumber, + CudaVersion: body.CudaVersion, + PytorchVersion: body.PytorchVersion, + Author: &body.Author, + ComfyRunFlags: body.ComfyRunFlags, + StorageFile: &drip.StorageFile{ + PublicUrl: proto.String(fmt.Sprintf("https://storage.googleapis.com/%s/%s", *body.BucketName, *body.OutputFilesGcsPaths)), + }, + MachineStats: body.MachineStats, + }, (*res200.JobResults)[0]) + }) + + t.Run("Get invalid Git Commit", func(t *testing.T) { + fakeID, _ := uuid.NewV7() + res, err := impl.GetGitcommit(ctx, drip.GetGitcommitRequestObject{Params: drip.GetGitcommitParams{ + CommitId: proto.String(fakeID.String())}}) + require.NoError(t, err, "should not return error") + require.IsType(t, drip.GetGitcommit200JSONResponse{}, res) + assert.Len(t, *res.(drip.GetGitcommit200JSONResponse).JobResults, 0) + }) + + t.Run("Get Branch", func(t *testing.T) { + res, err := impl.GetBranch(ctx, drip.GetBranchRequestObject{Params: drip.GetBranchParams{ + RepoName: body.Repo, + }}) + require.NoError(t, err, "should not return error") + require.IsType(t, drip.GetBranch200JSONResponse{}, res) + res200 := res.(drip.GetBranch200JSONResponse) + require.Len(t, *res200.Branches, 1, "should return corrent number of branches") + assert.Equal(t, body.BranchName, (*res200.Branches)[0], "should return correct branches") + }) + + t.Run("Get invalid branch", func(t *testing.T) { + res, err := impl.GetBranch(ctx, drip.GetBranchRequestObject{Params: drip.GetBranchParams{ + RepoName: "notexist", + }}) + require.NoError(t, err, "should return error") + assert.IsType(t, drip.GetBranch200JSONResponse{}, res) + assert.Len(t, *res.(drip.GetBranch200JSONResponse).Branches, 0, "should return empty branch") + }) +} diff --git a/integration-tests/registry_integration_test.go b/integration-tests/registry_integration_test.go index 87a7fe1..ec6fa50 100644 --- a/integration-tests/registry_integration_test.go +++ b/integration-tests/registry_integration_test.go @@ -2,15 +2,27 @@ package integration import ( "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" "registry-backend/config" "registry-backend/drip" "registry-backend/ent" + "registry-backend/ent/nodeversion" + "registry-backend/ent/schema" "registry-backend/mock/gateways" "registry-backend/server/implementation" + drip_authorization "registry-backend/server/middleware/authorization" + dripservices_registry "registry-backend/services/registry" "strings" "testing" + "time" - "github.com/rs/zerolog/log" + "github.com/labstack/echo/v4" + strictecho "github.com/oapi-codegen/runtime/strictmiddleware/echo" + + "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -25,642 +37,783 @@ func setUpTest(client *ent.Client) (context.Context, *ent.User) { return ctx, testUser } -func TestRegistry(t *testing.T) { - clientCtx := context.Background() - client, postgresContainer := setupDB(t, clientCtx) - // Cleanup - defer func() { - if err := postgresContainer.Terminate(clientCtx); err != nil { - log.Ctx(clientCtx).Error().Msgf("failed to terminate container: %s", err) - } - }() +func setUpAdminTest(client *ent.Client) (context.Context, *ent.User) { + ctx := context.Background() + testUser := createAdminUser(ctx, client) + ctx = decorateUserInContext(ctx, testUser) + return ctx, testUser +} + +func randomPublisher() *drip.Publisher { + suffix := uuid.New().String() + publisherId := "test-publisher-" + suffix + description := "test-description" + suffix + source_code_repo := "test-source-code-repo" + suffix + website := "test-website" + suffix + support := "test-support" + suffix + logo := "test-logo" + suffix + name := "test-name" + suffix + + return &drip.Publisher{ + Id: &publisherId, + Description: &description, + SourceCodeRepo: &source_code_repo, + Website: &website, + Support: &support, + Logo: &logo, + Name: &name, + } +} + +func randomNode() *drip.Node { + suffix := uuid.New().String() + nodeId := "test-node" + suffix + nodeDescription := "test-node-description" + suffix + nodeAuthor := "test-node-author" + suffix + nodeLicense := "test-node-license" + suffix + nodeName := "test-node-name" + suffix + nodeTags := []string{"test-node-tag"} + icon := "https://wwww.github.com/test-icon-" + suffix + ".svg" + githubUrl := "https://www.github.com/test-github-url-" + suffix + + return &drip.Node{ + Id: &nodeId, + Name: &nodeName, + Description: &nodeDescription, + Author: &nodeAuthor, + License: &nodeLicense, + Tags: &nodeTags, + Icon: &icon, + Repository: &githubUrl, + } +} + +func randomNodeVersion(revision int) *drip.NodeVersion { + suffix := uuid.New().String() + + version := fmt.Sprintf("1.0.%d", revision) + changelog := "test-changelog-" + suffix + dependencies := []string{"test-dependency" + suffix} + return &drip.NodeVersion{ + Version: &version, + Changelog: &changelog, + Dependencies: &dependencies, + } +} + +type mockedImpl struct { + *implementation.DripStrictServerImplementation + + mockStorageService *gateways.MockStorageService + mockSlackService *gateways.MockSlackService + mockDiscordService *gateways.MockDiscordService + mockAlgolia *gateways.MockAlgoliaService +} +func newMockedImpl(client *ent.Client, cfg *config.Config) (impl mockedImpl, authz strictecho.StrictEchoMiddlewareFunc) { // Initialize the Service mockStorageService := new(gateways.MockStorageService) + + mockDiscordService := new(gateways.MockDiscordService) + mockDiscordService.On("SendSecurityCouncilMessage", mock.Anything). + Return(nil) // Do nothing for all discord messsage calls. + mockSlackService := new(gateways.MockSlackService) mockSlackService. On("SendRegistryMessageToSlack", mock.Anything). Return(nil) // Do nothing for all slack messsage calls. - impl := implementation.NewStrictServerImplementation( - client, &config.Config{}, mockStorageService, mockSlackService) - - t.Run("Publisher", func(t *testing.T) { - ctx, testUser := setUpTest(client) - publisherId := "test-publisher" - description := "test-description" - source_code_repo := "test-source-code-repo" - website := "test-website" - support := "test-support" - logo := "test-logo" - name := "test-name" - - t.Run("Create Publisher", func(t *testing.T) { - createPublisherResponse, err := impl.CreatePublisher(ctx, drip.CreatePublisherRequestObject{ - Body: &drip.Publisher{ - Id: &publisherId, - Description: &description, - SourceCodeRepo: &source_code_repo, - Website: &website, - Support: &support, - Logo: &logo, - Name: &name, - }, - }) - require.NoError(t, err, "should return created publisher") - require.NotNil(t, createPublisherResponse, "should return created publisher") - assert.Equal(t, publisherId, *createPublisherResponse.(drip.CreatePublisher201JSONResponse).Id) - assert.Equal(t, description, *createPublisherResponse.(drip.CreatePublisher201JSONResponse).Description) - assert.Equal(t, source_code_repo, *createPublisherResponse.(drip.CreatePublisher201JSONResponse).SourceCodeRepo) - assert.Equal(t, website, *createPublisherResponse.(drip.CreatePublisher201JSONResponse).Website) - assert.Equal(t, support, *createPublisherResponse.(drip.CreatePublisher201JSONResponse).Support) - assert.Equal(t, logo, *createPublisherResponse.(drip.CreatePublisher201JSONResponse).Logo) - }) - t.Run("Validate Publisher", func(t *testing.T) { - res, err := impl.ValidatePublisher(ctx, drip.ValidatePublisherRequestObject{ - Params: drip.ValidatePublisherParams{Username: name}, - }) - require.NoError(t, err, "should not return error") - require.IsType(t, drip.ValidatePublisher200JSONResponse{}, res, "should return 200") - require.True(t, *res.(drip.ValidatePublisher200JSONResponse).IsAvailable, "should be available") - }) + mockAlgolia := new(gateways.MockAlgoliaService) + mockAlgolia. + On("IndexNodes", mock.Anything, mock.Anything). + Return(nil). + On("DeleteNode", mock.Anything, mock.Anything). + Return(nil) + + impl = mockedImpl{ + DripStrictServerImplementation: implementation.NewStrictServerImplementation( + client, cfg, mockStorageService, mockSlackService, mockDiscordService, mockAlgolia), + mockStorageService: mockStorageService, + mockSlackService: mockSlackService, + mockDiscordService: mockDiscordService, + mockAlgolia: mockAlgolia, + } + authz = drip_authorization.NewAuthorizationManager(client, impl.RegistryService). + AuthorizationMiddleware() + return +} - t.Run("Get Publisher", func(t *testing.T) { - getPublisherResponse, err := impl.GetPublisher(ctx, drip.GetPublisherRequestObject{ - PublisherId: publisherId}) - require.NoError(t, err, "should return created publisher") - assert.Equal(t, publisherId, *getPublisherResponse.(drip.GetPublisher200JSONResponse).Id) - assert.Equal(t, description, *getPublisherResponse.(drip.GetPublisher200JSONResponse).Description) - assert.Equal(t, source_code_repo, *getPublisherResponse.(drip.GetPublisher200JSONResponse).SourceCodeRepo) - assert.Equal(t, website, *getPublisherResponse.(drip.GetPublisher200JSONResponse).Website) - assert.Equal(t, support, *getPublisherResponse.(drip.GetPublisher200JSONResponse).Support) - assert.Equal(t, logo, *getPublisherResponse.(drip.GetPublisher200JSONResponse).Logo) - assert.Equal(t, name, *getPublisherResponse.(drip.GetPublisher200JSONResponse).Name) - - // Check the number of members returned - expectedMembersCount := 1 // Adjust to your expected count - assert.Equal(t, expectedMembersCount, - len(*getPublisherResponse.(drip.GetPublisher200JSONResponse).Members), - "should return the correct number of members") - - // Check specific properties of each member, adjust indices accordingly - for _, member := range *getPublisherResponse.(drip.GetPublisher200JSONResponse).Members { - expectedUserId := testUser.ID - expectedUserName := testUser.Name - expectedUserEmail := testUser.Email - - assert.Equal(t, expectedUserId, *member.User.Id, "User ID should match") - assert.Equal(t, expectedUserName, *member.User.Name, "User name should match") - assert.Equal(t, expectedUserEmail, *member.User.Email, "User email should match") - } - }) +func TestRegistryPublisher(t *testing.T) { + client, cleanup := setupDB(t, context.Background()) + defer cleanup() + impl, authz := newMockedImpl(client, &config.Config{}) - t.Run("List Publishers", func(t *testing.T) { - res, err := impl.ListPublishers(ctx, drip.ListPublishersRequestObject{}) - require.NoError(t, err, "should not return error") - require.IsType(t, drip.ListPublishers200JSONResponse{}, res, "should return 200 status code") - res200 := res.(drip.ListPublishers200JSONResponse) - require.Len(t, res200, 1, "should return all stored publlishers") - assert.Equal(t, drip.Publisher{ - Id: &publisherId, - Description: &description, - SourceCodeRepo: &source_code_repo, - Website: &website, - Support: &support, - Logo: &logo, - Name: &name, - - // generated thus ignored in comparison - Members: res200[0].Members, - CreatedAt: res200[0].CreatedAt, - }, res200[0], "should return correct publishers") - }) + ctx, testUser := setUpTest(client) + pub := randomPublisher() - t.Run("Get Non-Exist Publisher", func(t *testing.T) { - res, err := impl.GetPublisher(ctx, drip.GetPublisherRequestObject{PublisherId: publisherId + "invalid"}) - require.NoError(t, err, "should not return error") - assert.IsType(t, drip.GetPublisher404JSONResponse{}, res) - }) + createPublisherResponse, err := withMiddleware(authz, impl.CreatePublisher)(ctx, drip.CreatePublisherRequestObject{ + Body: pub, + }) + require.NoError(t, err, "should return created publisher") + require.NotNil(t, createPublisherResponse, "should return created publisher") + assert.Equal(t, pub.Id, createPublisherResponse.(drip.CreatePublisher201JSONResponse).Id) + assert.Equal(t, pub.Description, createPublisherResponse.(drip.CreatePublisher201JSONResponse).Description) + assert.Equal(t, pub.SourceCodeRepo, createPublisherResponse.(drip.CreatePublisher201JSONResponse).SourceCodeRepo) + assert.Equal(t, pub.Website, createPublisherResponse.(drip.CreatePublisher201JSONResponse).Website) + assert.Equal(t, pub.Support, createPublisherResponse.(drip.CreatePublisher201JSONResponse).Support) + assert.Equal(t, pub.Logo, createPublisherResponse.(drip.CreatePublisher201JSONResponse).Logo) + + t.Run("Reject New Publisher With The Same Name", func(t *testing.T) { + res, err := withMiddleware(authz, impl.CreatePublisher)(ctx, drip.CreatePublisherRequestObject{ + Body: pub, + }) + require.NoError(t, err, "should return error") + assert.IsType(t, drip.CreatePublisher400JSONResponse{}, res) + }) - t.Run("Update Publisher", func(t *testing.T) { - update_description := "update-test-description" - update_source_code_repo := "update-test-source-code-repo" - update_website := "update-test-website" - update_support := "update-test-support" - update_logo := "update-test-logo" - update_name := "update-test-name" - - updatePublisherResponse, err := impl.UpdatePublisher(ctx, drip.UpdatePublisherRequestObject{ - PublisherId: publisherId, - Body: &drip.Publisher{ - Description: &update_description, - SourceCodeRepo: &update_source_code_repo, - Website: &update_website, - Support: &update_support, - Logo: &update_logo, - Name: &update_name, - }, - }) - require.NoError(t, err, "should return created publisher") - assert.Equal(t, publisherId, *updatePublisherResponse.(drip.UpdatePublisher200JSONResponse).Id) - assert.Equal(t, update_description, *updatePublisherResponse.(drip.UpdatePublisher200JSONResponse).Description) - assert.Equal(t, update_source_code_repo, *updatePublisherResponse.(drip.UpdatePublisher200JSONResponse).SourceCodeRepo) - assert.Equal(t, update_website, *updatePublisherResponse.(drip.UpdatePublisher200JSONResponse).Website) - assert.Equal(t, update_support, *updatePublisherResponse.(drip.UpdatePublisher200JSONResponse).Support) - assert.Equal(t, update_logo, *updatePublisherResponse.(drip.UpdatePublisher200JSONResponse).Logo) - - _, err = impl.ListPublishersForUser(ctx, drip.ListPublishersForUserRequestObject{}) - require.NoError(t, err, "should return created publisher") + t.Run("Validate Publisher", func(t *testing.T) { + res, err := withMiddleware(authz, impl.ValidatePublisher)(ctx, drip.ValidatePublisherRequestObject{ + Params: drip.ValidatePublisherParams{Username: *pub.Name}, }) + require.NoError(t, err, "should not return error") + require.IsType(t, drip.ValidatePublisher200JSONResponse{}, res, "should return 200") + require.True(t, *res.(drip.ValidatePublisher200JSONResponse).IsAvailable, "should be available") + }) - t.Run("Reject New Publisher With The Same Name", func(t *testing.T) { - duplicateCreatePublisherResponse, err := impl.CreatePublisher(ctx, drip.CreatePublisherRequestObject{ - Body: &drip.Publisher{ - Id: &publisherId, - Description: &description, - SourceCodeRepo: &source_code_repo, - Website: &website, - Support: &support, - Logo: &logo, - Name: &name, - }, - }) - require.NoError(t, err, "should return error") - assert.IsType(t, drip.CreatePublisher400JSONResponse{}, duplicateCreatePublisherResponse) - }) + t.Run("Get Publisher", func(t *testing.T) { + res, err := withMiddleware(authz, impl.GetPublisher)(ctx, drip.GetPublisherRequestObject{ + PublisherId: *pub.Id}) + require.NoError(t, err, "should return created publisher") + assert.Equal(t, pub.Id, res.(drip.GetPublisher200JSONResponse).Id) + assert.Equal(t, pub.Description, res.(drip.GetPublisher200JSONResponse).Description) + assert.Equal(t, pub.SourceCodeRepo, res.(drip.GetPublisher200JSONResponse).SourceCodeRepo) + assert.Equal(t, pub.Website, res.(drip.GetPublisher200JSONResponse).Website) + assert.Equal(t, pub.Support, res.(drip.GetPublisher200JSONResponse).Support) + assert.Equal(t, pub.Logo, res.(drip.GetPublisher200JSONResponse).Logo) + assert.Equal(t, pub.Name, res.(drip.GetPublisher200JSONResponse).Name) + + // Check the number of members returned + expectedMembersCount := 1 // Adjust to your expected count + assert.Equal(t, expectedMembersCount, + len(*res.(drip.GetPublisher200JSONResponse).Members), + "should return the correct number of members") + + // Check specific properties of each member, adjust indices accordingly + for _, member := range *res.(drip.GetPublisher200JSONResponse).Members { + expectedUserId := testUser.ID + expectedUserName := testUser.Name + expectedUserEmail := testUser.Email + + assert.Equal(t, expectedUserId, *member.User.Id, "User ID should match") + assert.Equal(t, expectedUserName, *member.User.Name, "User name should match") + assert.Equal(t, expectedUserEmail, *member.User.Email, "User email should match") + } + }) - t.Run("Delete Publisher", func(t *testing.T) { - res, err := impl.DeletePublisher(ctx, drip.DeletePublisherRequestObject{PublisherId: publisherId}) - require.NoError(t, err, "should not return error") - require.IsType(t, drip.DeletePublisher204Response{}, res, "should return 204") - }) + t.Run("Get Non-Exist Publisher", func(t *testing.T) { + res, err := withMiddleware(authz, impl.GetPublisher)(ctx, drip.GetPublisherRequestObject{PublisherId: *pub.Id + "invalid"}) + require.NoError(t, err, "should not return error") + assert.IsType(t, drip.GetPublisher404JSONResponse{}, res) }) - t.Run("Personal Access Token", func(t *testing.T) { - ctx, _ := setUpTest(client) - publisherId := "test-publisher-pat" - description := "test-description" - source_code_repo := "test-source-code-repo" - website := "test-website" - support := "test-support" - logo := "test-logo" - name := "test-name" - tokenName := "test-token-name" - tokenDescription := "test-token-description" - - t.Run("Create Publisher", func(t *testing.T) { - createPublisherResponse, err := impl.CreatePublisher(ctx, drip.CreatePublisherRequestObject{ - Body: &drip.Publisher{ - Id: &publisherId, - Description: &description, - SourceCodeRepo: &source_code_repo, - Website: &website, - Support: &support, - Logo: &logo, - Name: &name, - }, - }) + t.Run("List Publishers", func(t *testing.T) { + res, err := withMiddleware(authz, impl.ListPublishers)(ctx, drip.ListPublishersRequestObject{}) + require.NoError(t, err, "should not return error") + require.IsType(t, drip.ListPublishers200JSONResponse{}, res, "should return 200 status code") + res200 := res.(drip.ListPublishers200JSONResponse) + require.Len(t, res200, 1, "should return all stored publlishers") + assert.Equal(t, drip.Publisher{ + Id: pub.Id, + Description: pub.Description, + SourceCodeRepo: pub.SourceCodeRepo, + Website: pub.Website, + Support: pub.Support, + Logo: pub.Logo, + Name: pub.Name, + + // generated thus ignored in comparison + Members: res200[0].Members, + CreatedAt: res200[0].CreatedAt, + Status: res200[0].Status, + }, res200[0], "should return correct publishers") + }) - require.NoError(t, err, "should return created publisher") - require.NotNil(t, createPublisherResponse, "should return created publisher") - }) + t.Run("Update Publisher", func(t *testing.T) { + pubUpdated := randomPublisher() + pubUpdated.Id, pubUpdated.Name = pub.Id, pub.Name + pub = pubUpdated - t.Run("List Personal Access Token Before Create", func(t *testing.T) { - none, err := impl.ListPersonalAccessTokens(ctx, drip.ListPersonalAccessTokensRequestObject{ - PublisherId: publisherId, - }) - require.NoError(t, err, "should not return error") - require.IsType(t, drip.ListPersonalAccessTokens200JSONResponse{}, none, "should return 200") - assert.Empty(t, none.(drip.ListPersonalAccessTokens200JSONResponse)) + res, err := withMiddleware(authz, impl.UpdatePublisher)(ctx, drip.UpdatePublisherRequestObject{ + PublisherId: *pubUpdated.Id, + Body: pubUpdated, }) + require.NoError(t, err, "should return created publisher") + assert.Equal(t, pubUpdated.Id, res.(drip.UpdatePublisher200JSONResponse).Id) + assert.Equal(t, pubUpdated.Description, res.(drip.UpdatePublisher200JSONResponse).Description) + assert.Equal(t, pubUpdated.SourceCodeRepo, res.(drip.UpdatePublisher200JSONResponse).SourceCodeRepo) + assert.Equal(t, pubUpdated.Website, res.(drip.UpdatePublisher200JSONResponse).Website) + assert.Equal(t, pubUpdated.Support, res.(drip.UpdatePublisher200JSONResponse).Support) + assert.Equal(t, pubUpdated.Logo, res.(drip.UpdatePublisher200JSONResponse).Logo) + + _, err = withMiddleware(authz, impl.ListPublishersForUser)(ctx, drip.ListPublishersForUserRequestObject{}) + require.NoError(t, err, "should return created publisher") + }) - t.Run("Create Personal Acccess Token", func(t *testing.T) { - createPersonalAccessTokenResponse, err := impl.CreatePersonalAccessToken( - ctx, drip.CreatePersonalAccessTokenRequestObject{ - PublisherId: publisherId, - Body: &drip.PersonalAccessToken{ - Name: &tokenName, - Description: &tokenDescription, - }, - }) - require.NoError(t, err, "should return created token") - require.NotNil(t, - *createPersonalAccessTokenResponse.(drip.CreatePersonalAccessToken201JSONResponse).Token, - "Token should have a value.") - }) + t.Run("Delete Publisher", func(t *testing.T) { + res, err := withMiddleware(authz, impl.DeletePublisher)(ctx, drip.DeletePublisherRequestObject{PublisherId: *pub.Id}) + require.NoError(t, err, "should not return error") + require.IsType(t, drip.DeletePublisher204Response{}, res, "should return 204") + }) +} - t.Run("List Personal Access Token", func(t *testing.T) { - getPersonalAccessTokenResponse, err := impl.ListPersonalAccessTokens(ctx, drip.ListPersonalAccessTokensRequestObject{ - PublisherId: publisherId, - }) - require.NoError(t, err, "should return created token") - assert.Equal(t, tokenName, - *getPersonalAccessTokenResponse.(drip.ListPersonalAccessTokens200JSONResponse)[0].Name) - assert.Equal(t, tokenDescription, - *getPersonalAccessTokenResponse.(drip.ListPersonalAccessTokens200JSONResponse)[0].Description) - assert.True(t, - isTokenMasked(*getPersonalAccessTokenResponse.(drip.ListPersonalAccessTokens200JSONResponse)[0].Token)) - }) +func TestRegistryPersonalAccessToken(t *testing.T) { + client, cleanup := setupDB(t, context.Background()) + defer cleanup() + impl, authz := newMockedImpl(client, &config.Config{}) + + ctx, _ := setUpTest(client) + pub := randomPublisher() + _, err := withMiddleware(authz, impl.CreatePublisher)(ctx, drip.CreatePublisherRequestObject{ + Body: pub, }) + require.NoError(t, err, "should return created publisher") - t.Run("Node", func(t *testing.T) { - ctx, _ := setUpTest(client) - publisherId := "test-publisher-node" - description := "test-description" - sourceCodeRepo := "test-source-code-repo" - website := "test-website" - support := "test-support" - logo := "test-logo" - name := "test-name" - - createPublisherResponse, err := impl.CreatePublisher(ctx, drip.CreatePublisherRequestObject{ - Body: &drip.Publisher{ - Id: &publisherId, - Description: &description, - SourceCodeRepo: &sourceCodeRepo, - Website: &website, - Support: &support, - Logo: &logo, - Name: &name, + tokenName := "test-token-name" + tokenDescription := "test-token-description" + res, err := withMiddleware(authz, impl.CreatePersonalAccessToken)( + ctx, drip.CreatePersonalAccessTokenRequestObject{ + PublisherId: *pub.Id, + Body: &drip.PersonalAccessToken{ + Name: &tokenName, + Description: &tokenDescription, }, }) - require.NoError(t, err, "should return created publisher") - require.NotNil(t, createPublisherResponse, "should return created publisher") - - nodeId := "test-node" - nodeDescription := "test-node-description" - nodeAuthor := "test-node-author" - nodeLicense := "test-node-license" - nodeName := "test-node-name" - nodeTags := []string{"test-node-tag"} - icon := "https://wwww.github.com/test-icon.svg" - githubUrl := "https://www.github.com/test-github-url" - - var real_node_id *string - t.Run("Create Node", func(t *testing.T) { - createNodeResponse, err := impl.CreateNode(ctx, drip.CreateNodeRequestObject{ - PublisherId: publisherId, - Body: &drip.Node{ - Id: &nodeId, - Name: &nodeName, - Description: &nodeDescription, - Author: &nodeAuthor, - License: &nodeLicense, - Tags: &nodeTags, - Icon: &icon, - Repository: &githubUrl, - }, - }) - require.NoError(t, err, "should return created node") - require.NotNil(t, createNodeResponse, "should return created node") - assert.Equal(t, nodeId, *createNodeResponse.(drip.CreateNode201JSONResponse).Id) - assert.Equal(t, nodeDescription, *createNodeResponse.(drip.CreateNode201JSONResponse).Description) - assert.Equal(t, nodeAuthor, *createNodeResponse.(drip.CreateNode201JSONResponse).Author) - assert.Equal(t, nodeLicense, *createNodeResponse.(drip.CreateNode201JSONResponse).License) - assert.Equal(t, nodeName, *createNodeResponse.(drip.CreateNode201JSONResponse).Name) - assert.Equal(t, nodeTags, *createNodeResponse.(drip.CreateNode201JSONResponse).Tags) - assert.Equal(t, icon, *createNodeResponse.(drip.CreateNode201JSONResponse).Icon) - assert.Equal(t, githubUrl, *createNodeResponse.(drip.CreateNode201JSONResponse).Repository) - real_node_id = createNodeResponse.(drip.CreateNode201JSONResponse).Id + require.NoError(t, err, "should return created token") + require.NotNil(t, + *res.(drip.CreatePersonalAccessToken201JSONResponse).Token, + "Token should have a value.") + t.Run("List Personal Access Token", func(t *testing.T) { + getPersonalAccessTokenResponse, err := withMiddleware(authz, impl.ListPersonalAccessTokens)(ctx, drip.ListPersonalAccessTokensRequestObject{ + PublisherId: *pub.Id, }) + require.NoError(t, err, "should return created token") + assert.Equal(t, tokenName, + *getPersonalAccessTokenResponse.(drip.ListPersonalAccessTokens200JSONResponse)[0].Name) + assert.Equal(t, tokenDescription, + *getPersonalAccessTokenResponse.(drip.ListPersonalAccessTokens200JSONResponse)[0].Description) + assert.True(t, + isTokenMasked(*getPersonalAccessTokenResponse.(drip.ListPersonalAccessTokens200JSONResponse)[0].Token)) + }) +} - t.Run("Get Node", func(t *testing.T) { - res, err := impl.GetNode(ctx, drip.GetNodeRequestObject{NodeId: nodeId}) - require.NoError(t, err, "should not return error") - require.IsType(t, drip.GetNode200JSONResponse{}, res) - res200 := res.(drip.GetNode200JSONResponse) - assert.Equal(t, drip.GetNode200JSONResponse{ - Id: &nodeId, - Name: &nodeName, - Description: &nodeDescription, - Author: &nodeAuthor, - Tags: &nodeTags, - License: &nodeLicense, - Icon: &icon, - Repository: &githubUrl, - }, res200, "should return stored node data") - }) +func TestRegistryNode(t *testing.T) { + client, cleanup := setupDB(t, context.Background()) + defer cleanup() + impl, authz := newMockedImpl(client, &config.Config{}) - t.Run("Get Not Exist Node", func(t *testing.T) { - res, err := impl.GetNode(ctx, drip.GetNodeRequestObject{NodeId: nodeId + "fake"}) - require.NoError(t, err, "should not return error") - require.IsType(t, drip.GetNode404JSONResponse{}, res) - }) + ctx, _ := setUpTest(client) + pub := randomPublisher() - t.Run("Get Publisher Nodes", func(t *testing.T) { - res, err := impl.ListNodesForPublisher(ctx, drip.ListNodesForPublisherRequestObject{ - PublisherId: publisherId, - }) - require.NoError(t, err, "should not return error") - require.IsType(t, drip.ListNodesForPublisher200JSONResponse{}, res) - res200 := res.(drip.ListNodesForPublisher200JSONResponse) - require.Len(t, res200, 1) - assert.Equal(t, drip.Node{ - Id: &nodeId, - Name: &nodeName, - Description: &nodeDescription, - Author: &nodeAuthor, - Tags: &nodeTags, - License: &nodeLicense, - Icon: &icon, - Repository: &githubUrl, - }, res200[0], "should return stored node data") - }) + _, err := withMiddleware(authz, impl.CreatePublisher)(ctx, drip.CreatePublisherRequestObject{ + Body: pub, + }) + require.NoError(t, err, "should return created publisher") - t.Run("Update Node", func(t *testing.T) { - updateNodeDescription := "update_test-node-description" - updateNodeAuthor := "update_test-node-author" - updateNodeLicense := "update_test-node-license" - updateNodeName := "update_test-node-name" - updateNodeTags := []string{"update-test-node-tag"} - updateIcon := "https://wwww.github.com/update-icon.svg" - updateGithubUrl := "https://www.github.com/update-github-url" - - updateNodeResponse, err := impl.UpdateNode(ctx, drip.UpdateNodeRequestObject{ - PublisherId: publisherId, - NodeId: *real_node_id, - Body: &drip.Node{ - Id: &nodeId, - Description: &updateNodeDescription, - Author: &updateNodeAuthor, - License: &updateNodeLicense, - Name: &updateNodeName, - Tags: &updateNodeTags, - Icon: &updateIcon, - Repository: &updateGithubUrl, - }, - }) - require.NoError(t, err, "should return created node") - assert.Equal(t, nodeId, *updateNodeResponse.(drip.UpdateNode200JSONResponse).Id) - assert.Equal(t, updateNodeDescription, *updateNodeResponse.(drip.UpdateNode200JSONResponse).Description) - assert.Equal(t, updateNodeAuthor, *updateNodeResponse.(drip.UpdateNode200JSONResponse).Author) - assert.Equal(t, updateNodeLicense, *updateNodeResponse.(drip.UpdateNode200JSONResponse).License) - assert.Equal(t, updateNodeName, *updateNodeResponse.(drip.UpdateNode200JSONResponse).Name) - assert.Equal(t, updateNodeTags, *updateNodeResponse.(drip.UpdateNode200JSONResponse).Tags) - assert.Equal(t, updateIcon, *updateNodeResponse.(drip.UpdateNode200JSONResponse).Icon) - assert.Equal(t, updateGithubUrl, *updateNodeResponse.(drip.UpdateNode200JSONResponse).Repository) - - resUpdated, err := impl.GetNode(ctx, drip.GetNodeRequestObject{NodeId: nodeId}) - require.NoError(t, err, "should not return error") - require.IsType(t, drip.GetNode200JSONResponse{}, resUpdated) - res200Updated := resUpdated.(drip.GetNode200JSONResponse) - assert.Equal(t, drip.GetNode200JSONResponse{ - Id: &nodeId, - Description: &updateNodeDescription, - Author: &updateNodeAuthor, - License: &updateNodeLicense, - Name: &updateNodeName, - Tags: &updateNodeTags, - Icon: &updateIcon, - Repository: &updateGithubUrl, - }, res200Updated, "should return updated node data") - }) + node := randomNode() + res, err := withMiddleware(authz, impl.CreateNode)(ctx, drip.CreateNodeRequestObject{ + PublisherId: *pub.Id, + Body: node, + }) + require.NoError(t, err, "should return created node") + require.NotNil(t, res, "should return created node") + assert.Equal(t, node.Id, res.(drip.CreateNode201JSONResponse).Id) + assert.Equal(t, node.Description, res.(drip.CreateNode201JSONResponse).Description) + assert.Equal(t, node.Author, res.(drip.CreateNode201JSONResponse).Author) + assert.Equal(t, node.License, res.(drip.CreateNode201JSONResponse).License) + assert.Equal(t, node.Name, res.(drip.CreateNode201JSONResponse).Name) + assert.Equal(t, node.Tags, res.(drip.CreateNode201JSONResponse).Tags) + assert.Equal(t, node.Icon, res.(drip.CreateNode201JSONResponse).Icon) + assert.Equal(t, node.Repository, res.(drip.CreateNode201JSONResponse).Repository) + assert.Equal(t, drip.NodeStatusActive, *res.(drip.CreateNode201JSONResponse).Status) + + t.Run("Get Node", func(t *testing.T) { + res, err := withMiddleware(authz, impl.GetNode)(ctx, drip.GetNodeRequestObject{NodeId: *node.Id}) + require.NoError(t, err, "should not return error") + require.IsType(t, drip.GetNode200JSONResponse{}, res) + res200 := res.(drip.GetNode200JSONResponse) + expDl, expRate := 0, float32(0) + nodeStatus := drip.NodeStatusActive + assert.Equal(t, drip.GetNode200JSONResponse{ + Id: node.Id, + Name: node.Name, + Description: node.Description, + Author: node.Author, + Tags: node.Tags, + License: node.License, + Icon: node.Icon, + Repository: node.Repository, + + Downloads: &expDl, + Rating: &expRate, + Status: &nodeStatus, + StatusDetail: proto.String(""), + Category: proto.String(""), + }, res200, "should return stored node data") + }) - t.Run("Update Not Exist Node", func(t *testing.T) { - res, err := impl.UpdateNode(ctx, drip.UpdateNodeRequestObject{PublisherId: publisherId, NodeId: nodeId + "fake"}) - require.NoError(t, err, "should not return error") - require.IsType(t, drip.UpdateNode404JSONResponse{}, res) - }) + t.Run("Get Publisher Nodes", func(t *testing.T) { + res, err := withMiddleware(authz, impl.ListNodesForPublisher)(ctx, drip.ListNodesForPublisherRequestObject{ + PublisherId: *pub.Id, + }) + require.NoError(t, err, "should not return error") + require.IsType(t, drip.ListNodesForPublisher200JSONResponse{}, res) + res200 := res.(drip.ListNodesForPublisher200JSONResponse) + require.Len(t, res200, 1) + expDl, expRate := 0, float32(0) + nodeStatus := drip.NodeStatusActive + assert.Equal(t, drip.Node{ + Id: node.Id, + Name: node.Name, + Description: node.Description, + Author: node.Author, + Tags: node.Tags, + License: node.License, + Icon: node.Icon, + Repository: node.Repository, + + Downloads: &expDl, + Rating: &expRate, + Status: &nodeStatus, + StatusDetail: proto.String(""), + Category: proto.String(""), + }, res200[0], "should return stored node data") + }) - t.Run("Delete Node", func(t *testing.T) { - res, err := impl.DeleteNode(ctx, drip.DeleteNodeRequestObject{PublisherId: publisherId, NodeId: nodeId}) - require.NoError(t, err, "should not return error") - assert.IsType(t, drip.DeleteNode204Response{}, res) - }) + t.Run("Get Not Exist Node", func(t *testing.T) { + res, err := withMiddleware(authz, impl.GetNode)(ctx, drip.GetNodeRequestObject{NodeId: *node.Id + "fake"}) + require.NoError(t, err, "should not return error") + require.IsType(t, drip.GetNode404JSONResponse{}, res) + }) + + t.Run("Update Node", func(t *testing.T) { + unode := randomNode() + unode.Id = node.Id + node = unode + + updateNodeResponse, err := withMiddleware(authz, impl.UpdateNode)(ctx, drip.UpdateNodeRequestObject{ + PublisherId: *pub.Id, + NodeId: *node.Id, + Body: node, + }) + require.NoError(t, err, "should return created node") + assert.Equal(t, node.Id, updateNodeResponse.(drip.UpdateNode200JSONResponse).Id) + assert.Equal(t, node.Description, updateNodeResponse.(drip.UpdateNode200JSONResponse).Description) + assert.Equal(t, node.Author, updateNodeResponse.(drip.UpdateNode200JSONResponse).Author) + assert.Equal(t, node.License, updateNodeResponse.(drip.UpdateNode200JSONResponse).License) + assert.Equal(t, node.Name, updateNodeResponse.(drip.UpdateNode200JSONResponse).Name) + assert.Equal(t, node.Tags, updateNodeResponse.(drip.UpdateNode200JSONResponse).Tags) + assert.Equal(t, node.Icon, updateNodeResponse.(drip.UpdateNode200JSONResponse).Icon) + assert.Equal(t, node.Repository, updateNodeResponse.(drip.UpdateNode200JSONResponse).Repository) + + resUpdated, err := withMiddleware(authz, impl.GetNode)(ctx, drip.GetNodeRequestObject{NodeId: *node.Id}) + require.NoError(t, err, "should not return error") + require.IsType(t, drip.GetNode200JSONResponse{}, resUpdated) + res200Updated := resUpdated.(drip.GetNode200JSONResponse) + expDl, expRate := 0, float32(0) + nodeStatus := drip.NodeStatusActive + assert.Equal(t, drip.GetNode200JSONResponse{ + Id: node.Id, + Description: node.Description, + Author: node.Author, + License: node.License, + Name: node.Name, + Tags: node.Tags, + Icon: node.Icon, + Repository: node.Repository, + + Downloads: &expDl, + Rating: &expRate, + Status: &nodeStatus, + StatusDetail: proto.String(""), + Category: proto.String(""), + }, res200Updated, "should return updated node data") + }) + + t.Run("Update Not Exist Node", func(t *testing.T) { + res, err := withMiddleware(authz, impl.UpdateNode)(ctx, drip.UpdateNodeRequestObject{PublisherId: *pub.Id, NodeId: *node.Id + "fake", Body: &drip.Node{}}) + require.NoError(t, err, "should not return error") + require.IsType(t, drip.UpdateNode404JSONResponse{}, res) + }) + + t.Run("Index Nodes", func(t *testing.T) { + res, err := withMiddleware(authz, impl.ReindexNodes)(ctx, drip.ReindexNodesRequestObject{}) + require.NoError(t, err, "should not return error") + assert.IsType(t, drip.ReindexNodes200Response{}, res) + }) + + t.Run("Delete Node", func(t *testing.T) { + res, err := withMiddleware(authz, impl.DeleteNode)(ctx, drip.DeleteNodeRequestObject{PublisherId: *pub.Id, NodeId: *node.Id}) + require.NoError(t, err, "should not return error") + assert.IsType(t, drip.DeleteNode204Response{}, res) }) - t.Run("Node Version", func(t *testing.T) { - ctx, _ := setUpTest(client) - publisherId := "test-publisher-node-version" - description := "test-description" - source_code_repo := "test-source-code-repo" - website := "test-website" - support := "test-support" - logo := "test-logo" - name := "test-name" - - createPublisherResponse, err := impl.CreatePublisher(ctx, drip.CreatePublisherRequestObject{ - Body: &drip.Publisher{ - Id: &publisherId, - Description: &description, - SourceCodeRepo: &source_code_repo, - Website: &website, - Support: &support, - Logo: &logo, - Name: &name, + t.Run("Delete Not Exist Node", func(t *testing.T) { + res, err := withMiddleware(authz, impl.DeleteNode)(ctx, drip.DeleteNodeRequestObject{PublisherId: *pub.Id, NodeId: *node.Id + "fake"}) + require.NoError(t, err, "should not return error") + assert.IsType(t, drip.DeleteNode204Response{}, res) + }) +} + +func TestRegistryNodeVersion(t *testing.T) { + client, cleanup := setupDB(t, context.Background()) + defer cleanup() + impl, authz := newMockedImpl(client, &config.Config{}) + + ctx, _ := setUpTest(client) + pub := randomPublisher() + + respub, err := withMiddleware(authz, impl.CreatePublisher)(ctx, drip.CreatePublisherRequestObject{ + Body: pub, + }) + require.NoError(t, err, "should return created publisher") + createdPublisher := (respub.(drip.CreatePublisher201JSONResponse)) + + tokenName := "test-token-name" + tokenDescription := "test-token-description" + respat, err := withMiddleware(authz, impl.CreatePersonalAccessToken)(ctx, drip.CreatePersonalAccessTokenRequestObject{ + PublisherId: *pub.Id, + Body: &drip.PersonalAccessToken{ + Name: &tokenName, + Description: &tokenDescription, + }, + }) + require.NoError(t, err, "should return created token") + token := *respat.(drip.CreatePersonalAccessToken201JSONResponse).Token + + node := randomNode() + nodeVersion := randomNodeVersion(0) + signedUrl := "test-url" + downloadUrl := fmt.Sprintf("https://storage.googleapis.com/comfy-registry/%s/%s/%s/node.tar.gz", *pub.Id, *node.Id, *nodeVersion.Version) + var createdNodeVersion drip.NodeVersion + + impl.mockStorageService.On("GenerateSignedURL", mock.Anything, mock.Anything).Return(signedUrl, nil) + impl.mockStorageService.On("GetFileUrl", mock.Anything, mock.Anything, mock.Anything).Return(signedUrl, nil) + impl.mockDiscordService.On("SendSecurityCouncilMessage", mock.Anything, mock.Anything).Return(nil) + res, err := withMiddleware(authz, impl.PublishNodeVersion)(ctx, drip.PublishNodeVersionRequestObject{ + PublisherId: *pub.Id, + NodeId: *node.Id, + Body: &drip.PublishNodeVersionJSONRequestBody{ + PersonalAccessToken: token, + Node: *node, + NodeVersion: *nodeVersion, + }, + }) + require.NoError(t, err, "should return created node version") + assert.Equal(t, nodeVersion.Version, res.(drip.PublishNodeVersion201JSONResponse).NodeVersion.Version) + require.Equal(t, nodeVersion.Dependencies, res.(drip.PublishNodeVersion201JSONResponse).NodeVersion.Dependencies, "should return pip dependencies") + require.Equal(t, nodeVersion.Changelog, res.(drip.PublishNodeVersion201JSONResponse).NodeVersion.Changelog, "should return changelog") + require.Equal(t, signedUrl, *res.(drip.PublishNodeVersion201JSONResponse).SignedUrl, "should return signed url") + versionStatus := drip.NodeVersionStatusPending + require.Equal(t, versionStatus, *res.(drip.PublishNodeVersion201JSONResponse).NodeVersion.Status, "should return pending status") + createdNodeVersion = *res.(drip.PublishNodeVersion201JSONResponse).NodeVersion // Needed for downstream tests. + + t.Run("Admin Update", func(t *testing.T) { + adminCtx, _ := setUpAdminTest(client) + activeStatus := drip.NodeVersionStatusActive + adminUpdateNodeVersionResp, err := impl.AdminUpdateNodeVersion(adminCtx, drip.AdminUpdateNodeVersionRequestObject{ + NodeId: *node.Id, + VersionNumber: *createdNodeVersion.Version, + Body: &drip.AdminUpdateNodeVersionJSONRequestBody{ + Status: &activeStatus, }, }) - require.NoError(t, err, "should return created publisher") - require.NotNil(t, createPublisherResponse, "should return created publisher") - assert.Equal(t, publisherId, *createPublisherResponse.(drip.CreatePublisher201JSONResponse).Id) + require.NoError(t, err, "should return updated node version") + assert.Equal(t, activeStatus, *adminUpdateNodeVersionResp.(drip.AdminUpdateNodeVersion200JSONResponse).Status) + }) - tokenName := "test-token-name" - tokenDescription := "test-token-description" - createPersonalAccessTokenResponse, err := impl.CreatePersonalAccessToken(ctx, drip.CreatePersonalAccessTokenRequestObject{ - PublisherId: publisherId, - Body: &drip.PersonalAccessToken{ - Name: &tokenName, - Description: &tokenDescription, + t.Run("List Node Version Before Create", func(t *testing.T) { + node := randomNode() + resVersions, err := withMiddleware(authz, impl.ListNodeVersions)(ctx, drip.ListNodeVersionsRequestObject{NodeId: *node.Id}) + require.NoError(t, err, "should return error since node version doesn't exists") + require.IsType(t, drip.ListNodeVersions200JSONResponse{}, resVersions) + assert.Empty(t, resVersions.(drip.ListNodeVersions200JSONResponse), "should not return any node versions") + }) + + t.Run("Create Node Version with Fake Token", func(t *testing.T) { + _, err := withMiddleware(authz, impl.PublishNodeVersion)(ctx, drip.PublishNodeVersionRequestObject{ + PublisherId: *pub.Id, + NodeId: *node.Id, + Body: &drip.PublishNodeVersionJSONRequestBody{ + Node: *node, + NodeVersion: *nodeVersion, + PersonalAccessToken: "faketoken", }, }) - require.NoError(t, err, "should return created token") - require.NotNil(t, *createPersonalAccessTokenResponse.(drip.CreatePersonalAccessToken201JSONResponse).Token, "Token should have a value.") - - nodeId := "test-node1" - nodeDescription := "test-node-description" - nodeAuthor := "test-node-author" - nodeLicense := "test-node-license" - nodeName := "test-node-name" - nodeTags := []string{"test-node-tag"} - nodeVersionLiteral := "1.0.0" - changelog := "test-changelog" - dependencies := []string{"test-dependency"} - downloadUrl := "https://storage.googleapis.com/comfy-registry/test-publisher-node-version/test-node1/1.0.0/node.tar.gz" - - createdPublisher := createPublisherResponse.(drip.CreatePublisher201JSONResponse) - var createdNodeVersion drip.NodeVersion - - t.Run("List Node Version Before Create", func(t *testing.T) { - resVersions, err := impl.ListNodeVersions(ctx, drip.ListNodeVersionsRequestObject{NodeId: nodeId}) - require.NoError(t, err, "should return error since node version doesn't exists") - require.IsType(t, drip.ListNodeVersions200JSONResponse{}, resVersions) - assert.Empty(t, resVersions.(drip.ListNodeVersions200JSONResponse), "should not return any node versions") - }) + require.Error(t, err) + assert.Equal(t, http.StatusBadRequest, err.(*echo.HTTPError).Code, "should return 400 bad request") + }) - t.Run("Create Node Version with Fake Token", func(t *testing.T) { - response, err := impl.PublishNodeVersion(ctx, drip.PublishNodeVersionRequestObject{ - PublisherId: publisherId, - NodeId: nodeId, + t.Run("Create Node Version with invalid node id", func(t *testing.T) { + for _, suffix := range []string{"LOWERCASEONLY", "invalidCharacter&"} { + node := randomNode() + *node.Id = *node.Id + suffix + res, err := withMiddleware(authz, impl.PublishNodeVersion)(ctx, drip.PublishNodeVersionRequestObject{ + PublisherId: *pub.Id, + NodeId: *node.Id, Body: &drip.PublishNodeVersionJSONRequestBody{ - Node: drip.Node{ - Id: &nodeId, - Description: &nodeDescription, - Author: &nodeAuthor, - License: &nodeLicense, - Name: &nodeName, - Tags: &nodeTags, - Repository: &source_code_repo, - }, - NodeVersion: drip.NodeVersion{ - Version: &nodeVersionLiteral, - Changelog: &changelog, - Dependencies: &dependencies, - }, - PersonalAccessToken: "faketoken", + Node: *node, + NodeVersion: *randomNodeVersion(0), + PersonalAccessToken: token, }, }) require.NoError(t, err) - assert.Equal(t, "Invalid personal access token", response.(drip.PublishNodeVersion400JSONResponse).Message, "should return error message") - }) + require.IsType(t, drip.PublishNodeVersion400JSONResponse{}, res) + } + }) - t.Run("Create Node Version", func(t *testing.T) { - mockStorageService.On("GenerateSignedURL", mock.Anything, mock.Anything).Return("test-url", nil) - mockStorageService.On("GetFileUrl", mock.Anything, mock.Anything, mock.Anything).Return("test-url", nil) - createNodeVersionResp, err := impl.PublishNodeVersion(ctx, drip.PublishNodeVersionRequestObject{ - PublisherId: publisherId, - NodeId: nodeId, - Body: &drip.PublishNodeVersionJSONRequestBody{ - Node: drip.Node{ - Id: &nodeId, - Description: &nodeDescription, - Author: &nodeAuthor, - License: &nodeLicense, - Name: &nodeName, - Tags: &nodeTags, - Repository: &source_code_repo, - }, - NodeVersion: drip.NodeVersion{ - Version: &nodeVersionLiteral, - Changelog: &changelog, - Dependencies: &dependencies, - }, - PersonalAccessToken: *createPersonalAccessTokenResponse.(drip.CreatePersonalAccessToken201JSONResponse).Token, - }, - }) - require.NoError(t, err, "should return created node version") - assert.Equal(t, nodeVersionLiteral, *createNodeVersionResp.(drip.PublishNodeVersion201JSONResponse).NodeVersion.Version) - require.Equal(t, "test-url", *createNodeVersionResp.(drip.PublishNodeVersion201JSONResponse).SignedUrl, "should return signed url") - require.Equal(t, dependencies, *createNodeVersionResp.(drip.PublishNodeVersion201JSONResponse).NodeVersion.Dependencies, "should return pip dependencies") - require.Equal(t, changelog, *createNodeVersionResp.(drip.PublishNodeVersion201JSONResponse).NodeVersion.Changelog, "should return changelog") - createdNodeVersion = *createNodeVersionResp.(drip.PublishNodeVersion201JSONResponse).NodeVersion - }) + t.Run("Get not exist Node Version ", func(t *testing.T) { + res, err := withMiddleware(authz, impl.GetNodeVersion)(ctx, drip.GetNodeVersionRequestObject{NodeId: *node.Id + "fake", VersionId: *nodeVersion.Version}) + require.NoError(t, err, "should not return error") + assert.IsType(t, drip.GetNodeVersion404JSONResponse{}, res) + }) - t.Run("Get not exist Node Version ", func(t *testing.T) { - res, err := impl.GetNodeVersion(ctx, drip.GetNodeVersionRequestObject{NodeId: nodeId + "fake", VersionId: nodeVersionLiteral}) - require.NoError(t, err, "should not return error") - assert.IsType(t, drip.GetNodeVersion404JSONResponse{}, res) + t.Run("Create Node Version of Not Exist Node", func(t *testing.T) { + _, err := withMiddleware(authz, impl.PublishNodeVersion)(ctx, drip.PublishNodeVersionRequestObject{ + PublisherId: *pub.Id, + NodeId: *node.Id + "fake", + Body: &drip.PublishNodeVersionJSONRequestBody{}, }) + require.Error(t, err) + assert.Equal(t, http.StatusBadRequest, err.(*echo.HTTPError).Code, "should return 400 bad request") + }) - t.Run("Create Node Version of Not Exist Node", func(t *testing.T) { - response, err := impl.PublishNodeVersion(ctx, drip.PublishNodeVersionRequestObject{ - PublisherId: publisherId, - NodeId: nodeId + "fake", - Body: &drip.PublishNodeVersionJSONRequestBody{}, - }) - require.NoError(t, err) - assert.Equal(t, "Invalid personal access token", response.(drip.PublishNodeVersion400JSONResponse).Message, "should return error message") - }) + t.Run("List Node Versions", func(t *testing.T) { + resVersions, err := withMiddleware(authz, impl.ListNodeVersions)(ctx, drip.ListNodeVersionsRequestObject{NodeId: *node.Id}) + require.NoError(t, err, "should not return error") + require.IsType(t, drip.ListNodeVersions200JSONResponse{}, resVersions, "should return 200") + resVersions200 := resVersions.(drip.ListNodeVersions200JSONResponse) + require.Len(t, resVersions200, 1, "should return only one version") + nodeVersionStatus := drip.NodeVersionStatusActive + t.Log("Download URL: ", *resVersions200[0].DownloadUrl) + t.Log("Download URL: ", downloadUrl) + t.Log("Status: ", *resVersions200[0].Status) + t.Log("Status: ", nodeVersionStatus) + assert.Equal(t, drip.NodeVersion{ + // generated attribute + Id: resVersions200[0].Id, + CreatedAt: resVersions200[0].CreatedAt, + + Deprecated: proto.Bool(false), + Version: nodeVersion.Version, + Changelog: nodeVersion.Changelog, + Dependencies: nodeVersion.Dependencies, + DownloadUrl: &downloadUrl, + Status: &nodeVersionStatus, + StatusReason: proto.String(""), + }, resVersions200[0], "should be equal") + }) - t.Run("List Node Versions", func(t *testing.T) { - resVersions, err := impl.ListNodeVersions(ctx, drip.ListNodeVersionsRequestObject{NodeId: nodeId}) - require.NoError(t, err, "should not return error") - require.IsType(t, drip.ListNodeVersions200JSONResponse{}, resVersions, "should return 200") - resVersions200 := resVersions.(drip.ListNodeVersions200JSONResponse) - require.Len(t, resVersions200, 1, "should return only one version") - assert.Equal(t, drip.NodeVersion{ - // generated attribute - Id: resVersions200[0].Id, - CreatedAt: resVersions200[0].CreatedAt, - - Deprecated: proto.Bool(false), - Version: &nodeVersionLiteral, - Changelog: &changelog, - Dependencies: &dependencies, - DownloadUrl: &downloadUrl, - }, resVersions200[0], "should be equal") + t.Run("Update Node Version", func(t *testing.T) { + updatedChangelog := "test-changelog-2" + resUNV, err := withMiddleware(authz, impl.UpdateNodeVersion)(ctx, drip.UpdateNodeVersionRequestObject{ + PublisherId: *pub.Id, + NodeId: *node.Id, + VersionId: *createdNodeVersion.Id, + Body: &drip.NodeVersionUpdateRequest{ + Changelog: &updatedChangelog, + Deprecated: proto.Bool(true), + }, }) + require.NoError(t, err, "should not return error") + require.IsType(t, drip.UpdateNodeVersion200JSONResponse{}, resUNV, "should return 200") + + res, err := withMiddleware(authz, impl.ListNodeVersions)(ctx, drip.ListNodeVersionsRequestObject{NodeId: *node.Id}) + require.NoError(t, err, "should not return error") + require.IsType(t, drip.ListNodeVersions200JSONResponse{}, res, "should return 200") + res200 := res.(drip.ListNodeVersions200JSONResponse) + require.Len(t, res200, 1, "should return only one version") + status := drip.NodeVersionStatusActive + updatedNodeVersion := drip.NodeVersion{ + // generated attribute + Id: res200[0].Id, + CreatedAt: res200[0].CreatedAt, + + Deprecated: proto.Bool(true), + Version: nodeVersion.Version, + Dependencies: nodeVersion.Dependencies, + Changelog: &updatedChangelog, + DownloadUrl: &downloadUrl, + Status: &status, + StatusReason: proto.String(""), + } + assert.Equal(t, updatedNodeVersion, res200[0], "should be equal") + createdNodeVersion = res200[0] + }) - t.Run("Update Node Version", func(t *testing.T) { - updatedChangelog := "test-changelog-2" - resUNV, err := impl.UpdateNodeVersion(ctx, drip.UpdateNodeVersionRequestObject{ - PublisherId: publisherId, - NodeId: nodeId, - VersionId: *createdNodeVersion.Id, - Body: &drip.NodeVersionUpdateRequest{ - Changelog: &updatedChangelog, - Deprecated: proto.Bool(true), - }, - }) - require.NoError(t, err, "should not return error") - require.IsType(t, drip.UpdateNodeVersion200JSONResponse{}, resUNV, "should return 200") + t.Run("List Nodes", func(t *testing.T) { + nodeIDs := map[string]*drip.NodeVersion{ + *node.Id: &createdNodeVersion, + *node.Id + "-1": nil, + *node.Id + "-2": nil, + } - res, err := impl.ListNodeVersions(ctx, drip.ListNodeVersionsRequestObject{NodeId: nodeId}) - require.NoError(t, err, "should not return error") - require.IsType(t, drip.ListNodeVersions200JSONResponse{}, res, "should return 200") - res200 := res.(drip.ListNodeVersions200JSONResponse) - require.Len(t, res200, 1, "should return only one version") - updatedNodeVersion := drip.NodeVersion{ - // generated attribute - Id: res200[0].Id, - CreatedAt: res200[0].CreatedAt, - - Deprecated: proto.Bool(true), - Version: &nodeVersionLiteral, - Changelog: &updatedChangelog, - Dependencies: &dependencies, - DownloadUrl: &downloadUrl, + for nodeId := range nodeIDs { + for i := 0; i < 2; i++ { + version := fmt.Sprintf("2.0.%d", i) + res, err := withMiddleware(authz, impl.PublishNodeVersion)(ctx, drip.PublishNodeVersionRequestObject{ + PublisherId: *pub.Id, + NodeId: nodeId, + Body: &drip.PublishNodeVersionJSONRequestBody{ + Node: drip.Node{ + Id: &nodeId, + Description: node.Description, + Author: node.Author, + License: node.License, + Name: node.Name, + Tags: node.Tags, + Repository: node.Repository, + }, + NodeVersion: drip.NodeVersion{ + Version: &version, + Changelog: createdNodeVersion.Changelog, + Dependencies: createdNodeVersion.Dependencies, + }, + PersonalAccessToken: *respat.(drip.CreatePersonalAccessToken201JSONResponse).Token, + }, + }) + require.NoError(t, err, "should return created node version") + require.IsType(t, drip.PublishNodeVersion201JSONResponse{}, res) + res200 := res.(drip.PublishNodeVersion201JSONResponse) + nodeIDs[nodeId] = res200.NodeVersion } - assert.Equal(t, updatedNodeVersion, res200[0], "should be equal") - createdNodeVersion = res200[0] - }) + } - t.Run("List Nodes", func(t *testing.T) { - resNodes, err := impl.ListAllNodes(ctx, drip.ListAllNodesRequestObject{}) - require.NoError(t, err, "should not return error") - require.IsType(t, drip.ListAllNodes200JSONResponse{}, resNodes, "should return 200 server response") - resNodes200 := resNodes.(drip.ListAllNodes200JSONResponse) - assert.Len(t, *resNodes200.Nodes, 1, "should only contain 1 node") + resNodes, err := withMiddleware(authz, impl.ListAllNodes)(ctx, drip.ListAllNodesRequestObject{}) + require.NoError(t, err, "should not return error") + require.IsType(t, drip.ListAllNodes200JSONResponse{}, resNodes, "should return 200 server response") + resNodes200 := resNodes.(drip.ListAllNodes200JSONResponse) + assert.Len(t, *resNodes200.Nodes, len(nodeIDs), "should only contain 1 node") + for _, node := range *resNodes200.Nodes { + expDl, expRate := 0, float32(0) + nodeStatus := drip.NodeStatusActive expectedNode := drip.Node{ - Id: &nodeId, - Name: &nodeName, - Repository: &source_code_repo, - Description: &nodeDescription, - Author: &nodeAuthor, - License: &nodeLicense, - Tags: &nodeTags, - LatestVersion: &createdNodeVersion, - Icon: proto.String(""), + Id: node.Id, + Name: node.Name, + Repository: node.Repository, + Description: node.Description, + Author: node.Author, + License: node.License, + Tags: node.Tags, + LatestVersion: nodeIDs[*node.Id], + Icon: node.Icon, Publisher: (*drip.Publisher)(&createdPublisher), + Downloads: &expDl, + Rating: &expRate, + Status: &nodeStatus, + StatusDetail: proto.String(""), + Category: proto.String(""), } - expectedNode.LatestVersion.DownloadUrl = (*resNodes200.Nodes)[0].LatestVersion.DownloadUrl // generated - expectedNode.LatestVersion.Deprecated = (*resNodes200.Nodes)[0].LatestVersion.Deprecated // generated - expectedNode.Publisher.CreatedAt = (*resNodes200.Nodes)[0].Publisher.CreatedAt - assert.Equal(t, expectedNode, (*resNodes200.Nodes)[0]) - }) + expectedNode.LatestVersion.DownloadUrl = node.LatestVersion.DownloadUrl // generated + expectedNode.LatestVersion.Deprecated = node.LatestVersion.Deprecated // generated + expectedNode.LatestVersion.CreatedAt = node.LatestVersion.CreatedAt // generated + expectedNode.Publisher.CreatedAt = node.Publisher.CreatedAt + assert.Equal(t, expectedNode, node) + } + }) - t.Run("Install Node", func(t *testing.T) { - resIns, err := impl.InstallNode(ctx, drip.InstallNodeRequestObject{NodeId: nodeId}) - require.NoError(t, err, "should not return error") - require.IsType(t, drip.InstallNode200JSONResponse{}, resIns, "should return 200") + t.Run("Node Installation", func(t *testing.T) { + resIns, err := withMiddleware(authz, impl.InstallNode)(ctx, drip.InstallNodeRequestObject{NodeId: *node.Id}) + require.NoError(t, err, "should not return error") + require.IsType(t, drip.InstallNode200JSONResponse{}, resIns, "should return 200") - resIns, err = impl.InstallNode(ctx, drip.InstallNodeRequestObject{ - NodeId: nodeId, Params: drip.InstallNodeParams{Version: &nodeVersionLiteral}}) + resIns, err = withMiddleware(authz, impl.InstallNode)(ctx, drip.InstallNodeRequestObject{ + NodeId: *node.Id, Params: drip.InstallNodeParams{Version: createdNodeVersion.Version}}) + require.NoError(t, err, "should not return error") + require.IsType(t, drip.InstallNode200JSONResponse{}, resIns, "should return 200") + + t.Run("Get Total Install", func(t *testing.T) { + res, err := withMiddleware(authz, impl.GetNode)(ctx, drip.GetNodeRequestObject{ + NodeId: *node.Id, + }) require.NoError(t, err, "should not return error") - require.IsType(t, drip.InstallNode200JSONResponse{}, resIns, "should return 200") + require.IsType(t, drip.GetNode200JSONResponse{}, res) + assert.Equal(t, int(2), *res.(drip.GetNode200JSONResponse).Downloads) }) - t.Run("Install Node Version on not exist node or version", func(t *testing.T) { - resIns, err := impl.InstallNode(ctx, drip.InstallNodeRequestObject{NodeId: nodeId + "fake"}) - require.NoError(t, err, "should not return error") - require.IsType(t, drip.InstallNode404JSONResponse{}, resIns, "should return 404") - resIns, err = impl.InstallNode(ctx, drip.InstallNodeRequestObject{ - NodeId: nodeId, Params: drip.InstallNodeParams{Version: proto.String(nodeVersionLiteral + "fake")}}) - require.NoError(t, err, "should not return error") - require.IsType(t, drip.InstallNode404JSONResponse{}, resIns, "should return 404") + t.Run("Add review", func(t *testing.T) { + res, err := withMiddleware(authz, impl.PostNodeReview)(ctx, drip.PostNodeReviewRequestObject{ + NodeId: *node.Id, + Params: drip.PostNodeReviewParams{Star: 5}, + }) + require.NoError(t, err) + require.IsType(t, drip.PostNodeReview200JSONResponse{}, res) + res200 := res.(drip.PostNodeReview200JSONResponse) + assert.Equal(t, float32(5), *res200.Rating) }) }) + + t.Run("Node installation on not exist node or version", func(t *testing.T) { + resIns, err := withMiddleware(authz, impl.InstallNode)(ctx, drip.InstallNodeRequestObject{NodeId: *node.Id + "fake"}) + require.NoError(t, err, "should not return error") + require.IsType(t, drip.InstallNode404JSONResponse{}, resIns, "should return 404") + + resIns, err = withMiddleware(authz, impl.InstallNode)(ctx, drip.InstallNodeRequestObject{ + NodeId: *node.Id, Params: drip.InstallNodeParams{Version: proto.String(*createdNodeVersion.Version + "fake")}}) + require.NoError(t, err, "should not return error") + require.IsType(t, drip.InstallNode404JSONResponse{}, resIns, "should return 404") + }) + + t.Run("Scan Node", func(t *testing.T) { + node := randomNode() + nodeVersion := randomNodeVersion(0) + downloadUrl := fmt.Sprintf("https://storage.googleapis.com/comfy-registry/%s/%s/%s/node.tar.gz", *pub.Id, *node.Id, *nodeVersion.Version) + + impl.mockStorageService.On("GenerateSignedURL", mock.Anything, mock.Anything).Return("test-url", nil) + impl.mockStorageService.On("GetFileUrl", mock.Anything, mock.Anything, mock.Anything).Return("test-url", nil) + impl.mockDiscordService.On("SendSecurityCouncilMessage", mock.Anything, mock.Anything).Return(nil) + _, err := withMiddleware(authz, impl.PublishNodeVersion)(ctx, drip.PublishNodeVersionRequestObject{ + PublisherId: *pub.Id, + NodeId: *node.Id, + Body: &drip.PublishNodeVersionJSONRequestBody{ + Node: *node, + NodeVersion: *nodeVersion, + PersonalAccessToken: *respat.(drip.CreatePersonalAccessToken201JSONResponse).Token, + }, + }) + require.NoError(t, err, "should return created node version") + + nodesToScans, err := client.NodeVersion.Query().Where(nodeversion.StatusEQ(schema.NodeVersionStatusPending)).Count(ctx) + require.NoError(t, err) + + newNodeScanned := false + nodesScanned := 0 + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + + req := dripservices_registry.ScanRequest{} + require.NoError(t, json.NewDecoder(r.Body).Decode(&req)) + if downloadUrl == req.URL { + newNodeScanned = true + } + nodesScanned++ + })) + t.Cleanup(s.Close) + + impl, authz := newMockedImpl(client, &config.Config{SecretScannerURL: s.URL}) + dur := time.Duration(0) + scanres, err := withMiddleware(authz, impl.SecurityScan)(ctx, drip.SecurityScanRequestObject{ + Params: drip.SecurityScanParams{ + MinAge: &dur, + }, + }) + require.NoError(t, err) + require.IsType(t, drip.SecurityScan200Response{}, scanres) + assert.True(t, newNodeScanned) + assert.Equal(t, nodesToScans, nodesScanned) + }) + } func isTokenMasked(token string) bool { diff --git a/integration-tests/test_util.go b/integration-tests/test_util.go index ecdccd5..5dad3de 100644 --- a/integration-tests/test_util.go +++ b/integration-tests/test_util.go @@ -4,14 +4,21 @@ import ( "context" "fmt" "net" + "net/http" + "net/http/httptest" + "reflect" + "registry-backend/drip" + auth "registry-backend/server/middleware/authentication" + "runtime" + "strings" "registry-backend/ent" "registry-backend/ent/migrate" - auth "registry-backend/server/middleware" "testing" "time" "github.com/google/uuid" + "github.com/labstack/echo/v4" "github.com/rs/zerolog/log" "github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go/modules/postgres" @@ -29,6 +36,16 @@ func createTestUser(ctx context.Context, client *ent.Client) *ent.User { SaveX(ctx) } +func createAdminUser(ctx context.Context, client *ent.Client) *ent.User { + return client.User.Create(). + SetID(uuid.New().String()). + SetIsApproved(true). + SetIsAdmin(true). + SetName("admin"). + SetEmail("admin@gmail.com"). + SaveX(ctx) +} + func decorateUserInContext(ctx context.Context, user *ent.User) context.Context { return context.WithValue(ctx, auth.UserContextKey, &auth.UserDetails{ ID: user.ID, @@ -37,7 +54,7 @@ func decorateUserInContext(ctx context.Context, user *ent.User) context.Context }) } -func setupDB(t *testing.T, ctx context.Context) (*ent.Client, *postgres.PostgresContainer) { +func setupDB(t *testing.T, ctx context.Context) (client *ent.Client, cleanup func()) { // Define Postgres container request postgresContainer, err := postgres.RunContainer(ctx, testcontainers.WithImage("docker.io/postgres:15.2-alpine"), @@ -69,7 +86,7 @@ func setupDB(t *testing.T, ctx context.Context) (*ent.Client, *postgres.Postgres t.Fatalf("Failed to start container: %s", err) } - client, err := ent.Open("postgres", databaseURL) + client, err = ent.Open("postgres", databaseURL) if err != nil { log.Ctx(ctx).Fatal().Err(err).Msg("failed opening connection to postgres") } @@ -81,7 +98,13 @@ func setupDB(t *testing.T, ctx context.Context) (*ent.Client, *postgres.Postgres } println("Schema created") - return client, postgresContainer + + cleanup = func() { + if err := postgresContainer.Terminate(ctx); err != nil { + log.Ctx(ctx).Error().Msgf("failed to terminate container: %s", err) + } + } + return } func waitPortOpen(t *testing.T, host string, port string, timeout time.Duration) { @@ -107,5 +130,27 @@ func waitPortOpen(t *testing.T, host string, port string, timeout time.Duration) conn.Close() return } +} + +func withMiddleware[R any, S any](mw drip.StrictMiddlewareFunc, h func(ctx context.Context, req R) (res S, err error)) func(ctx context.Context, req R) (res S, err error) { + handler := func(ctx echo.Context, request interface{}) (interface{}, error) { + return h(ctx.Request().Context(), request.(R)) + } + nameA := strings.Split(runtime.FuncForPC(reflect.ValueOf(h).Pointer()).Name(), ".") + nameA = strings.Split(nameA[len(nameA)-1], "-") + opname := nameA[0] + + return func(ctx context.Context, req R) (res S, err error) { + fakeReq := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx) + fakeRes := httptest.NewRecorder() + fakeCtx := echo.New().NewContext(fakeReq, fakeRes) + + f := mw(handler, opname) + r, err := f(fakeCtx, req) + if r == nil { + return *new(S), err + } + return r.(S), err + } } diff --git a/main.go b/main.go index 9dbd537..5d94045 100644 --- a/main.go +++ b/main.go @@ -21,8 +21,12 @@ func main() { connection_string := os.Getenv("DB_CONNECTION_STRING") config := config.Config{ - ProjectID: os.Getenv("PROJECT_ID"), - DripEnv: os.Getenv("DRIP_ENV"), + ProjectID: os.Getenv("PROJECT_ID"), + DripEnv: os.Getenv("DRIP_ENV"), + SlackRegistryChannelWebhook: os.Getenv("SLACK_REGISTRY_CHANNEL_WEBHOOK"), + JWTSecret: os.Getenv("JWT_SECRET"), + SecretScannerURL: os.Getenv("SECRET_SCANNER_URL"), + DiscordSecurityChannelWebhook: os.Getenv("SECURITY_COUNCIL_DISCORD_WEBHOOK"), } var dsn string @@ -39,7 +43,7 @@ func main() { } defer client.Close() // Run the auto migration tool for localdev. - if os.Getenv("DRIP_ENV") == "localdev" || os.Getenv("DRIP_ENV") == "staging" { + if os.Getenv("DRIP_ENV") == "localdev" { log.Info().Msg("Running migrations") if err := client.Schema.Create(context.Background(), migrate.WithDropIndex(true), migrate.WithDropColumn(true)); err != nil { @@ -48,5 +52,5 @@ func main() { } server := server.NewServer(client, &config) - server.Start() + log.Fatal().Err(server.Start()).Msg("Server stopped") } diff --git a/mapper/context.go b/mapper/context.go index 4c0d6c0..b1fe4db 100644 --- a/mapper/context.go +++ b/mapper/context.go @@ -3,7 +3,7 @@ package mapper import ( "context" "errors" - auth "registry-backend/server/middleware" + auth "registry-backend/server/middleware/authentication" ) func GetUserIDFromContext(ctx context.Context) (string, error) { diff --git a/mapper/error.go b/mapper/error.go new file mode 100644 index 0000000..2fc3faf --- /dev/null +++ b/mapper/error.go @@ -0,0 +1,41 @@ +package mapper + +import "errors" + +type errCode uint8 + +const ( + errCodeUnknown errCode = iota + errCodeBadRequest +) + +type codedError struct { + code errCode + msg string + err []error +} + +func (e codedError) Error() string { + return e.msg +} + +func (e codedError) Unwrap() []error { + return e.err +} + +func isCodedError(err error, code errCode) bool { + var e codedError + if !errors.As(err, &e) { + return false + } + + return e.code == code +} + +func NewErrorBadRequest(msg string, err ...error) error { + return codedError{code: errCodeBadRequest, msg: msg, err: err} +} + +func IsErrorBadRequest(err error) (y bool) { + return isCodedError(err, errCodeBadRequest) +} diff --git a/mapper/node.go b/mapper/node.go index 118151d..1a5aecf 100644 --- a/mapper/node.go +++ b/mapper/node.go @@ -1,10 +1,10 @@ package mapper import ( - "fmt" "regexp" "registry-backend/drip" "registry-backend/ent" + "registry-backend/ent/schema" "strings" ) @@ -27,6 +27,9 @@ func ApiCreateNodeToDb(publisherId string, node *drip.Node, client *ent.Client) if node.Name != nil { newNode.SetName(*node.Name) } + if node.Category != nil { + newNode.SetCategory(*node.Category) + } if node.Tags != nil { newNode.SetTags(*node.Tags) } @@ -57,6 +60,9 @@ func ApiUpdateNodeToUpdateFields(nodeID string, node *drip.Node, client *ent.Cli if node.Tags != nil { update.SetTags(*node.Tags) } + if node.Category != nil { + update.SetCategory(*node.Category) + } if node.Repository != nil { update.SetRepositoryURL(*node.Repository) } @@ -70,25 +76,39 @@ func ApiUpdateNodeToUpdateFields(nodeID string, node *drip.Node, client *ent.Cli func ValidateNode(node *drip.Node) error { if node.Id != nil { if len(*node.Id) > 100 { - return fmt.Errorf("node id is too long") + return NewErrorBadRequest("node id is too long") } - if !IsValidNodeID(*node.Id) { - return fmt.Errorf("invalid node id") + isValid, msg := IsValidNodeID(*node.Id) + if !isValid { + return NewErrorBadRequest(msg) + } + } + if node.Description != nil { + if len(*node.Description) > 1000 { + return NewErrorBadRequest("description is too long") } } return nil } -func IsValidNodeID(nodeID string) bool { +func IsValidNodeID(nodeID string) (bool, string) { if len(nodeID) == 0 || len(nodeID) > 50 { - return false + return false, "node id must be between 1 and 50 characters" + } + // Check if there are capital letters in the string + if strings.ToLower(nodeID) != nodeID { + return false, "Node ID can only contain lowercase letters" } // Regular expression pattern for Node ID validation (lowercase letters only) pattern := `^[a-z][a-z0-9-_]+(\.[a-z0-9-_]+)*$` // Compile the regular expression pattern regex := regexp.MustCompile(pattern) // Check if the string matches the pattern - return regex.MatchString(nodeID) + matches := regex.MatchString(nodeID) + if !matches { + return false, "Node ID can only contain lowercase letters, numbers, hyphens, underscores, and dots. Dots cannot be consecutive or be at the start or end of the id." + } + return true, "" } func DbNodeToApiNode(node *ent.Node) *drip.Node { @@ -96,14 +116,42 @@ func DbNodeToApiNode(node *ent.Node) *drip.Node { return nil } + downloads := int(node.TotalInstall) + rate := float32(0) + if node.TotalReview > 0 { + rate = float32(node.TotalStar) / float32(node.TotalReview) + } + return &drip.Node{ - Author: &node.Author, - Description: &node.Description, - Id: &node.ID, - License: &node.License, - Name: &node.Name, - Tags: &node.Tags, - Repository: &node.RepositoryURL, - Icon: &node.IconURL, + Author: &node.Author, + Description: &node.Description, + Category: &node.Category, + Id: &node.ID, + License: &node.License, + Name: &node.Name, + Tags: &node.Tags, + Repository: &node.RepositoryURL, + Icon: &node.IconURL, + Downloads: &downloads, + Rating: &rate, + Status: DbNodeStatusToApiNodeStatus(node.Status), + StatusDetail: &node.StatusDetail, } } + +func DbNodeStatusToApiNodeStatus(status schema.NodeStatus) *drip.NodeStatus { + var nodeStatus drip.NodeStatus + + switch status { + case schema.NodeStatusActive: + nodeStatus = drip.NodeStatusActive + case schema.NodeStatusBanned: + nodeStatus = drip.NodeStatusBanned + case schema.NodeStatusDeleted: + nodeStatus = drip.NodeStatusDeleted + default: + nodeStatus = "" + } + + return &nodeStatus +} diff --git a/mapper/node_test.go b/mapper/node_test.go index 1bcb0fd..29da250 100644 --- a/mapper/node_test.go +++ b/mapper/node_test.go @@ -8,6 +8,7 @@ import ( // TestIsValidNodeID tests the isValidNodeID function with various inputs. func TestIsValidNodeID(t *testing.T) { + regexErrorMessage := "Node ID can only contain lowercase letters, numbers, hyphens, underscores, and dots. Dots cannot be consecutive or be at the start or end of the id." testCases := []struct { name string node *drip.Node @@ -26,7 +27,7 @@ func TestIsValidNodeID(t *testing.T) { { name: "Invalid Node ID", node: &drip.Node{Id: stringPtr("123")}, - expectedError: "invalid node id", + expectedError: regexErrorMessage, }, { @@ -42,38 +43,33 @@ func TestIsValidNodeID(t *testing.T) { { name: "Invalid with uppercase", node: &drip.Node{Id: stringPtr("Node")}, - expectedError: "invalid node id", + expectedError: "Node ID can only contain lowercase letters", }, { name: "Invalid with special characters", node: &drip.Node{Id: stringPtr("node_@")}, - expectedError: "invalid node id", + expectedError: regexErrorMessage, }, { name: "Invalid start with number", node: &drip.Node{Id: stringPtr("1node")}, - expectedError: "invalid node id", + expectedError: regexErrorMessage, }, { name: "Invalid start with dash", node: &drip.Node{Id: stringPtr("-node")}, - expectedError: "invalid node id", + expectedError: regexErrorMessage, }, { name: "Empty input", node: &drip.Node{Id: stringPtr("")}, - expectedError: "invalid node id", + expectedError: "node id must be between 1 and 50 characters", }, { name: "Valid all lowercase letters", node: &drip.Node{Id: stringPtr("abcdefghijklmnopqrstuvwxyz")}, expectedError: "", }, - { - name: "Valid all uppercase letters", - node: &drip.Node{Id: stringPtr("ABCD")}, - expectedError: "invalid node id", - }, { name: "Valid containing underscore", node: &drip.Node{Id: stringPtr("comfy_ui")}, @@ -97,17 +93,17 @@ func TestIsValidNodeID(t *testing.T) { { name: "Invalid ID with number first", node: &drip.Node{Id: stringPtr("1invalidnodeid")}, - expectedError: "invalid node id", + expectedError: regexErrorMessage, }, { name: "Invalid ID with consecutive dots", node: &drip.Node{Id: stringPtr("invalid..nodeid")}, - expectedError: "invalid node id", + expectedError: regexErrorMessage, }, { name: "Invalid ID with special character first", node: &drip.Node{Id: stringPtr("-invalidnodeid")}, - expectedError: "invalid node id", + expectedError: regexErrorMessage, }, { name: "Valid complex ID", @@ -117,17 +113,17 @@ func TestIsValidNodeID(t *testing.T) { { name: "Invalid ID with special characters only", node: &drip.Node{Id: stringPtr("$$$$")}, - expectedError: "invalid node id", + expectedError: regexErrorMessage, }, { name: "Invalid ID with leading dot", node: &drip.Node{Id: stringPtr(".invalidnodeid")}, - expectedError: "invalid node id", + expectedError: regexErrorMessage, }, { name: "Invalid ID with ending dot", node: &drip.Node{Id: stringPtr("invalidnodeid.")}, - expectedError: "invalid node id", + expectedError: regexErrorMessage, }, } diff --git a/mapper/node_version.go b/mapper/node_version.go index 6931453..a7f822f 100644 --- a/mapper/node_version.go +++ b/mapper/node_version.go @@ -4,6 +4,7 @@ import ( "fmt" "registry-backend/drip" "registry-backend/ent" + "registry-backend/ent/schema" "github.com/Masterminds/semver/v3" "github.com/google/uuid" @@ -50,31 +51,86 @@ func DbNodeVersionToApiNodeVersion(dbNodeVersion *ent.NodeVersion) *drip.NodeVer if dbNodeVersion == nil { return nil } - id := dbNodeVersion.ID.String() - if dbNodeVersion.Edges.StorageFile == nil { - return &drip.NodeVersion{ - Id: &id, - Version: &dbNodeVersion.Version, - Changelog: &dbNodeVersion.Changelog, - Deprecated: &dbNodeVersion.Deprecated, - Dependencies: &dbNodeVersion.PipDependencies, - CreatedAt: &dbNodeVersion.CreateTime, - } + id := dbNodeVersion.ID.String() + var downloadUrl string + status := DbNodeVersionStatusToApiNodeVersionStatus(dbNodeVersion.Status) + if dbNodeVersion.Edges.StorageFile != nil { + downloadUrl = dbNodeVersion.Edges.StorageFile.FileURL } - return &drip.NodeVersion{ + apiVersion := &drip.NodeVersion{ Id: &id, Version: &dbNodeVersion.Version, Changelog: &dbNodeVersion.Changelog, - DownloadUrl: &dbNodeVersion.Edges.StorageFile.FileURL, Deprecated: &dbNodeVersion.Deprecated, Dependencies: &dbNodeVersion.PipDependencies, CreatedAt: &dbNodeVersion.CreateTime, + Status: status, + StatusReason: &dbNodeVersion.StatusReason, + DownloadUrl: &downloadUrl, } + return apiVersion } func CheckValidSemv(version string) bool { _, err := semver.NewVersion(version) return err == nil } + +func DbNodeVersionStatusToApiNodeVersionStatus(status schema.NodeVersionStatus) *drip.NodeVersionStatus { + var nodeVersionStatus drip.NodeVersionStatus + + switch status { + case schema.NodeVersionStatusActive: + nodeVersionStatus = drip.NodeVersionStatusActive + case schema.NodeVersionStatusBanned: + nodeVersionStatus = drip.NodeVersionStatusBanned + case schema.NodeVersionStatusDeleted: + nodeVersionStatus = drip.NodeVersionStatusDeleted + case schema.NodeVersionStatusPending: + nodeVersionStatus = drip.NodeVersionStatusPending + case schema.NodeVersionStatusFlagged: + nodeVersionStatus = drip.NodeVersionStatusFlagged + default: + nodeVersionStatus = "" + } + + return &nodeVersionStatus +} + +func ApiNodeVersionStatusesToDbNodeVersionStatuses(status *[]drip.NodeVersionStatus) []schema.NodeVersionStatus { + var nodeVersionStatus []schema.NodeVersionStatus + + if status == nil { + return nodeVersionStatus + } + + for _, s := range *status { + dbNodeVersion := ApiNodeVersionStatusToDbNodeVersionStatus(s) + nodeVersionStatus = append(nodeVersionStatus, dbNodeVersion) + } + + return nodeVersionStatus +} + +func ApiNodeVersionStatusToDbNodeVersionStatus(status drip.NodeVersionStatus) schema.NodeVersionStatus { + var nodeVersionStatus schema.NodeVersionStatus + + switch status { + case drip.NodeVersionStatusActive: + nodeVersionStatus = schema.NodeVersionStatusActive + case drip.NodeVersionStatusBanned: + nodeVersionStatus = schema.NodeVersionStatusBanned + case drip.NodeVersionStatusDeleted: + nodeVersionStatus = schema.NodeVersionStatusDeleted + case drip.NodeVersionStatusPending: + nodeVersionStatus = schema.NodeVersionStatusPending + case drip.NodeVersionStatusFlagged: + nodeVersionStatus = schema.NodeVersionStatusFlagged + default: + nodeVersionStatus = "" + } + + return nodeVersionStatus +} diff --git a/mapper/publisher.go b/mapper/publisher.go index 6a02a00..b290b5b 100644 --- a/mapper/publisher.go +++ b/mapper/publisher.go @@ -5,6 +5,7 @@ import ( "regexp" "registry-backend/drip" "registry-backend/ent" + "registry-backend/ent/schema" ) func ApiCreatePublisherToDb(publisher *drip.Publisher, client *ent.Client) (*ent.PublisherCreate, error) { @@ -112,9 +113,23 @@ func DbPublisherToApiPublisher(publisher *ent.Publisher, public bool) *drip.Publ Website: &publisher.Website, CreatedAt: &publisher.CreateTime, Members: &members, + Status: DbPublisherStatusToApiPublisherStatus(publisher.Status), } } func ToStringPointer(s string) *string { return &s } + +func DbPublisherStatusToApiPublisherStatus(status schema.PublisherStatusType) *drip.PublisherStatus { + var apiStatus drip.PublisherStatus + switch status { + case schema.PublisherStatusTypeActive: + apiStatus = drip.PublisherStatusActive + case schema.PublisherStatusTypeBanned: + apiStatus = drip.PublisherStatusBanned + default: + apiStatus = "" + } + return &apiStatus +} diff --git a/mapper/workflow_run.go b/mapper/workflow_run.go new file mode 100644 index 0000000..88ebf8d --- /dev/null +++ b/mapper/workflow_run.go @@ -0,0 +1,153 @@ +package mapper + +import ( + "fmt" + "registry-backend/drip" + "registry-backend/ent" + "registry-backend/ent/schema" +) + +func CiWorkflowResultsToActionJobResults(results []*ent.CIWorkflowResult) ([]drip.ActionJobResult, error) { + var jobResultsData []drip.ActionJobResult + + for _, result := range results { + jobResultData, err := CiWorkflowResultToActionJobResult(result) + if err != nil { + return nil, err + } + jobResultsData = append(jobResultsData, *jobResultData) + } + return jobResultsData, nil +} + +func CiWorkflowResultToActionJobResult(result *ent.CIWorkflowResult) (*drip.ActionJobResult, error) { + var storageFileData *drip.StorageFile + + // Check if the StorageFile slice is not empty before accessing + if len(result.Edges.StorageFile) > 0 { + storageFileData = &drip.StorageFile{ + PublicUrl: &result.Edges.StorageFile[0].FileURL, + } + } + commitId := result.Edges.Gitcommit.ID.String() + commitUnixTime := result.Edges.Gitcommit.CommitTimestamp.Unix() + apiStatus, err := DbWorkflowRunStatusToApi(result.Status) + if err != nil { + return nil, err + } + + machineStats, err := MapToMachineStats(result.Metadata) + if err != nil { + return nil, err + } + + return &drip.ActionJobResult{ + Id: &result.ID, + WorkflowName: &result.WorkflowName, + OperatingSystem: &result.OperatingSystem, + PythonVersion: &result.PythonVersion, + PytorchVersion: &result.PytorchVersion, + CudaVersion: &result.CudaVersion, + StorageFile: storageFileData, + CommitHash: &result.Edges.Gitcommit.CommitHash, + CommitId: &commitId, + CommitTime: &commitUnixTime, + CommitMessage: &result.Edges.Gitcommit.CommitMessage, + GitRepo: &result.Edges.Gitcommit.RepoName, + ActionRunId: &result.RunID, + ActionJobId: &result.JobID, + StartTime: &result.StartTime, + EndTime: &result.EndTime, + JobTriggerUser: &result.JobTriggerUser, + AvgVram: &result.AvgVram, + PeakVram: &result.PeakVram, + ComfyRunFlags: &result.ComfyRunFlags, + + Status: &apiStatus, + PrNumber: &result.Edges.Gitcommit.PrNumber, + Author: &result.Edges.Gitcommit.Author, + MachineStats: machineStats, + }, nil +} + +func ApiWorkflowRunStatusToDb(status drip.WorkflowRunStatus) (schema.WorkflowRunStatusType, error) { + switch status { + case drip.WorkflowRunStatusStarted: + return schema.WorkflowRunStatusTypeStarted, nil + case drip.WorkflowRunStatusCompleted: + return schema.WorkflowRunStatusTypeCompleted, nil + case drip.WorkflowRunStatusFailed: + return schema.WorkflowRunStatusTypeFailed, nil + default: + // Throw an error + return "", fmt.Errorf("unsupported workflow status: %v", status) + + } +} + +func DbWorkflowRunStatusToApi(status schema.WorkflowRunStatusType) (drip.WorkflowRunStatus, error) { + switch status { + case schema.WorkflowRunStatusTypeStarted: + return drip.WorkflowRunStatusStarted, nil + case schema.WorkflowRunStatusTypeCompleted: + return drip.WorkflowRunStatusCompleted, nil + case schema.WorkflowRunStatusTypeFailed: + return drip.WorkflowRunStatusFailed, nil + default: + // Throw an error + return "", fmt.Errorf("unsupported workflow status: %v", status) + } +} + +func MachineStatsToMap(ms *drip.MachineStats) map[string]interface{} { + return map[string]interface{}{ + "CpuCapacity": ms.CpuCapacity, + "DiskCapacity": ms.DiskCapacity, + "InitialCpu": ms.InitialCpu, + "InitialDisk": ms.InitialDisk, + "InitialRam": ms.InitialRam, + "MemoryCapacity": ms.MemoryCapacity, + "OsVersion": ms.OsVersion, + "PipFreeze": ms.PipFreeze, + "VramTimeSeries": ms.VramTimeSeries, + "MachineName": ms.MachineName, + "GpuType": ms.GpuType, + } +} + +func MapToMachineStats(data map[string]interface{}) (*drip.MachineStats, error) { + var ms drip.MachineStats + + if data == nil { + return nil, nil + } + + // Helper function to get string pointers from the map + getStringPtr := func(key string) *string { + if val, exists := data[key]; exists { + if strVal, ok := val.(string); ok { + return &strVal + } + } + return nil // Return nil if the key does not exist or type assertion fails + } + + ms.CpuCapacity = getStringPtr("CpuCapacity") + ms.DiskCapacity = getStringPtr("DiskCapacity") + ms.InitialCpu = getStringPtr("InitialCpu") + ms.InitialDisk = getStringPtr("InitialDisk") + ms.InitialRam = getStringPtr("InitialRam") + ms.MemoryCapacity = getStringPtr("MemoryCapacity") + ms.OsVersion = getStringPtr("OsVersion") + ms.PipFreeze = getStringPtr("PipFreeze") + ms.MachineName = getStringPtr("MachineName") + ms.GpuType = getStringPtr("GpuType") + + if val, exists := data["VramTimeSeries"]; exists { + if vram, ok := val.(map[string]interface{}); ok { + ms.VramTimeSeries = &vram + } + } + + return &ms, nil +} diff --git a/mock/gateways/mock_algolia_service.go b/mock/gateways/mock_algolia_service.go new file mode 100644 index 0000000..8f93b9f --- /dev/null +++ b/mock/gateways/mock_algolia_service.go @@ -0,0 +1,33 @@ +package gateways + +import ( + "context" + "registry-backend/ent" + "registry-backend/gateways/algolia" + + "github.com/stretchr/testify/mock" +) + +var _ algolia.AlgoliaService = &MockAlgoliaService{} + +type MockAlgoliaService struct { + mock.Mock +} + +// DeleteNode implements algolia.AlgoliaService. +func (m *MockAlgoliaService) DeleteNode(ctx context.Context, n *ent.Node) error { + args := m.Called(ctx, n) + return args.Error(0) +} + +// IndexNodes implements algolia.AlgoliaService. +func (m *MockAlgoliaService) IndexNodes(ctx context.Context, n ...*ent.Node) error { + args := m.Called(ctx, n) + return args.Error(0) +} + +// SearchNodes implements algolia.AlgoliaService. +func (m *MockAlgoliaService) SearchNodes(ctx context.Context, query string, opts ...interface{}) (nodes []*ent.Node, err error) { + args := m.Called(ctx, query, opts) + return args.Get(0).([]*ent.Node), args.Error(1) +} diff --git a/mock/gateways/mock_discord_service.go b/mock/gateways/mock_discord_service.go new file mode 100644 index 0000000..4ce420a --- /dev/null +++ b/mock/gateways/mock_discord_service.go @@ -0,0 +1,14 @@ +package gateways + +import ( + "github.com/stretchr/testify/mock" +) + +type MockDiscordService struct { + mock.Mock +} + +func (m *MockDiscordService) SendSecurityCouncilMessage(msg string) error { + args := m.Called(msg) + return args.Error(0) +} diff --git a/openapi.yml b/openapi.yml index 90327c4..50d8ee9 100644 --- a/openapi.yml +++ b/openapi.yml @@ -57,6 +57,9 @@ paths: comfy_logs_gcs_path: type: string description: The path to ComfyUI logs. eg. gs://bucket-name/logs + comfy_run_flags: + type: string + description: The flags used in the comfy run commit_hash: type: string commit_time: @@ -78,12 +81,36 @@ paths: type: integer format: int64 description: The end time of the job as a Unix timestamp. + avg_vram: + type: integer + description: The average amount of VRAM used in the run. + peak_vram: + type: integer + description: The peak amount of VRAM used in the run. + pr_number: + type: string + description: The pull request number + author: + type: string + description: The author of the commit + job_trigger_user: + type: string + description: The user who triggered the job + python_version: + type: string + description: The python version used in the run + pytorch_version: + type: string + description: The pytorch version used in the run + machine_stats: + $ref: "#/components/schemas/MachineStats" + status: + $ref: "#/components/schemas/WorkflowRunStatus" required: - repo - job_id - run_id - os - - gpu_type - commit_hash - commit_time - commit_message @@ -91,6 +118,11 @@ paths: - workflow_name - start_time - end_time + - pr_number + - python_version + - job_trigger_user + - author + - status responses: '200': @@ -174,6 +206,35 @@ paths: description: Commit not found '500': description: Internal server error + /workflowresult/{workflowResultId}: + get: + summary: Retrieve a specific commit by ID + operationId: getWorkflowResult + parameters: + - in: path + name: workflowResultId + required: true + schema: + type: string + responses: + '200': + description: Commit details + content: + application/json: + schema: + $ref: '#/components/schemas/ActionJobResult' + '404': + description: Commit not found + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /branch: get: summary: Retrieve all distinct branches for a given repo @@ -480,6 +541,40 @@ paths: schema: $ref: '#/components/schemas/ErrorResponse' + /publishers/{publisherId}/ban: + post: + summary: Ban a publisher + operationId: BanPublisher + parameters: + - in: path + name: publisherId + required: true + schema: + type: string + responses: + '204': + description: Publisher Banned Successfully + '401': + description: Unauthorized + '403': + description: Forbidden + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Publisher not found + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + /publishers/{publisherId}/nodes: post: summary: Create a new custom node @@ -534,6 +629,12 @@ paths: required: true schema: type: string + - in: query + name: include_banned + description: Number of nodes to return per page + required: false + schema: + type: boolean responses: '200': description: List of all nodes @@ -654,7 +755,7 @@ paths: schema: $ref: '#/components/schemas/ErrorResponse' - + /publishers/{publisherId}/nodes/{nodeId}/permissions: get: summary: Retrieve permissions the user has for a given publisher @@ -790,12 +891,24 @@ paths: responses: '204': description: Version unpublished (deleted) successfully + '403': + description: Version not found + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' '404': description: Version not found content: application/json: schema: $ref: '#/components/schemas/Error' + '500': + description: Version not found + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' put: summary: Update changelog and deprecation status of a node version operationId: updateNodeVersion @@ -860,6 +973,45 @@ paths: schema: $ref: '#/components/schemas/ErrorResponse' + /publishers/{publisherId}/nodes/{nodeId}/ban: + post: + summary: Ban a publisher's Node + operationId: BanPublisherNode + parameters: + - in: path + name: publisherId + required: true + schema: + type: string + - in: path + name: nodeId + required: true + schema: + type: string + responses: + '204': + description: Node Banned Successfully + '401': + description: Unauthorized + '403': + description: Forbidden + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Publisher or Node not found + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + /publishers/{publisherId}/tokens: post: summary: Create a new personal access token @@ -993,6 +1145,95 @@ paths: schema: $ref: '#/components/schemas/ErrorResponse' + /nodes/search: + get: + summary: Retrieves a list of nodes + description: Returns a paginated list of nodes across all publishers. + operationId: searchNodes + tags: + - Nodes + parameters: + - in: query + name: page + description: Page number of the nodes list + required: false + schema: + type: integer + default: 1 + - in: query + name: limit + description: Number of nodes to return per page + required: false + schema: + type: integer + default: 10 + - in: query + name: search + description: Keyword to search the nodes + required: false + schema: + type: string + - in: query + name: include_banned + description: Number of nodes to return per page + required: false + schema: + type: boolean + responses: + '200': + description: A paginated list of nodes + content: + application/json: + schema: + type: object + properties: + total: + type: integer + description: Total number of nodes available + nodes: + type: array + items: + $ref: '#/components/schemas/Node' + page: + type: integer + description: Current page number + limit: + type: integer + description: Maximum number of nodes per page + totalPages: + type: integer + description: Total number of pages available + '400': + description: Invalid input, object invalid + '404': + description: Not found + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + /nodes/reindex: + post: + summary: Reindex all nodes for searching. + operationId: reindexNodes + tags: + - Nodes + responses: + '200': + description: Reindex completed successfully. + '400': + description: Bad request. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /nodes: get: summary: Retrieves a list of nodes @@ -1015,6 +1256,12 @@ paths: schema: type: integer default: 10 + - in: query + name: include_banned + description: Number of nodes to return per page + required: false + schema: + type: boolean responses: '200': description: A paginated list of nodes @@ -1088,6 +1335,46 @@ paths: schema: $ref: '#/components/schemas/ErrorResponse' + /nodes/{nodeId}/reviews: + post: + summary: Add review to a specific version of a node + operationId: postNodeReview + tags: + - Nodes + parameters: + - in: path + name: nodeId + required: true + schema: + type: string + - in: query + name: star + description: number of star given to the node version + required: true + schema: + type: integer + responses: + '200': + description: Detailed information about a specific node + content: + application/json: + schema: + $ref: '#/components/schemas/Node' + '400': + description: Bad Request + '404': + description: Node version not found + content: + application/json: + schema: + $ref: '#/components/schemas/Error' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + /nodes/{nodeId}/install: get: summary: Returns a node version to be installed. @@ -1153,6 +1440,13 @@ paths: required: true schema: type: string + - in: query + name: statuses + required: false + schema: + type: array + items: + $ref: '#/components/schemas/NodeVersionStatus' responses: '200': description: List of all node versions @@ -1162,6 +1456,12 @@ paths: type: array items: $ref: '#/components/schemas/NodeVersion' + '403': + description: Node banned + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' '404': description: Node not found content: @@ -1211,6 +1511,185 @@ paths: application/json: schema: $ref: '#/components/schemas/ErrorResponse' + /versions: + get: + summary: List all node versions given some filters. + operationId: listAllNodeVersions + tags: + - Versions + parameters: + - in: query + name: nodeId + required: false + schema: + type: string + - in: query + name: statuses + required: false + style: form + explode: true + schema: + type: array + items: + $ref: '#/components/schemas/NodeVersionStatus' + - in: query + name: page + required: false + schema: + type: integer + default: 1 + description: The page number to retrieve. + - in: query + name: pageSize + required: false + schema: + type: integer + default: 10 + description: The number of items to include per page. + responses: + '200': + description: List of all node versions + content: + application/json: + schema: + type: object + properties: + total: + type: integer + description: Total number of node versions available + versions: + type: array + items: + $ref: '#/components/schemas/NodeVersion' + page: + type: integer + description: Current page number + pageSize: + type: integer + description: Maximum number of node versions per page. Maximum is 100. + totalPages: + type: integer + description: Total number of pages available + + '403': + description: Node banned + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + + /admin/nodes/{nodeId}/versions/{versionNumber}: + put: + summary: Admin Update Node Version Status + operationId: adminUpdateNodeVersion + description: Only admins can approve a node version. + tags: + - Versions + security: + - BearerAuth: [ ] + parameters: + - in: path + name: nodeId + required: true + schema: + type: string + - in: path + name: versionNumber + required: true + schema: + type: string + requestBody: + required: true + content: + application/json: + schema: + type: object + properties: + status: + $ref: "#/components/schemas/NodeVersionStatus" + status_reason: + type: string + description: The reason for the status change. + + responses: + '200': + description: Version updated successfully + content: + application/json: + schema: + $ref: '#/components/schemas/NodeVersion' + '400': + description: Bad request, invalid input data. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized + '403': + description: Forbidden + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Version not found + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + /security-scan: + get: + summary: Security Scan + operationId: securityScan + description: Pull all pending node versions and conduct security scans. + parameters: + - in: query + name: minAge + required: false + schema: + type: string + x-go-type: time.Duration + - in: query + name: maxNodes + required: false + schema: + type: integer + responses: + '200': + description: Scan completed successfully + '400': + description: Bad request, invalid input data. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized + '403': + description: Forbidden + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' components: schemas: PersonalAccessToken: @@ -1309,15 +1788,21 @@ components: operating_system: type: string description: Operating system used - gpu_type: + python_version: type: string - description: GPU type used + description: PyTorch version used pytorch_version: type: string description: PyTorch version used action_run_id: type: string description: Identifier of the run this result belongs to + action_job_id: + type: string + description: Identifier of the job this result belongs to + cuda_version: + type: string + description: CUDA version used commit_hash: type: string description: The hash of the commit @@ -1331,9 +1816,15 @@ components: commit_message: type: string description: The message of the commit + comfy_run_flags: + type: string + description: The comfy run flags. E.g. `--low-vram` git_repo: type: string description: The repository name + pr_number: + type: string + description: The pull request number start_time: type: integer format: int64 @@ -1342,6 +1833,22 @@ components: type: integer format: int64 description: The end time of the job as a Unix timestamp. + avg_vram: + type: integer + description: The average VRAM used by the job + peak_vram: + type: integer + description: The peak VRAM used by the job + job_trigger_user: + type: string + description: The user who triggered the job. + author: + type: string + description: The author of the commit + machine_stats: + $ref: "#/components/schemas/MachineStats" + status: + $ref: "#/components/schemas/WorkflowRunStatus" storage_file: $ref: "#/components/schemas/StorageFile" StorageFile: @@ -1385,6 +1892,9 @@ components: items: $ref: "#/components/schemas/PublisherMember" description: A list of members in the publisher. + status: + $ref: "#/components/schemas/PublisherStatus" + description: The status of the publisher. PublisherMember: type: object properties: @@ -1406,6 +1916,9 @@ components: name: type: string description: The display name of the node. + category: + type: string + description: The category of the node. description: type: string author: @@ -1435,6 +1948,12 @@ components: publisher: $ref: "#/components/schemas/Publisher" description: The publisher of the node. + status: + $ref: "#/components/schemas/NodeStatus" + description: The status of the node. + status_detail: + type: string + description: The status detail of the node. NodeVersion: type: object properties: @@ -1461,6 +1980,12 @@ components: deprecated: type: boolean description: Indicates if this version is deprecated. + status: + $ref: "#/components/schemas/NodeVersionStatus" + description: The status of the node version. + status_reason: + type: string + description: The reason for the status change. Error: type: object properties: @@ -1482,6 +2007,69 @@ components: deprecated: type: boolean description: Whether the version is deprecated. + # Enum of Node Status + NodeStatus: + type: string + enum: + - NodeStatusActive + - NodeStatusDeleted + - NodeStatusBanned + # Enum of Node Version Status + NodeVersionStatus: + type: string + enum: + - NodeVersionStatusActive + - NodeVersionStatusDeleted + - NodeVersionStatusBanned + - NodeVersionStatusPending + - NodeVersionStatusFlagged + PublisherStatus: + type: string + enum: + - PublisherStatusActive + - PublisherStatusBanned + WorkflowRunStatus: + type: string + enum: + - WorkflowRunStatusStarted + - WorkflowRunStatusFailed + - WorkflowRunStatusCompleted + MachineStats: + type: object + properties: + machine_name: + type: string + description: Name of the machine. + os_version: + type: string + description: The operating system version. eg. Ubuntu Linux 20.04 + gpu_type: + type: string + description: The GPU type. eg. NVIDIA Tesla K80 + cpu_capacity: + type: string + description: Total CPU on the machine. + initial_cpu: + type: string + description: Initial CPU available before the job starts. + memory_capacity: + type: string + description: Total memory on the machine. + initial_ram: + type: string + description: Initial RAM available before the job starts. + vram_time_series: + type: object + description: Time series of VRAM usage. + disk_capacity: + type: string + description: Total disk capacity on the machine. + initial_disk: + type: string + description: Initial disk available before the job starts. + pip_freeze: + type: string + description: The pip freeze output securitySchemes: BearerAuth: type: http diff --git a/run-service-prod.yaml b/run-service-prod.yaml index 878ea7c..2a09b70 100644 --- a/run-service-prod.yaml +++ b/run-service-prod.yaml @@ -22,8 +22,38 @@ spec: secretKeyRef: key: 1 name: PROD_SUPABASE_CONNECTION_STRING + - name: JWT_SECRET + valueFrom: + secretKeyRef: + key: 1 + name: PROD_JWT_SECRET + - name: SLACK_REGISTRY_CHANNEL_WEBHOOK + valueFrom: + secretKeyRef: + key: 1 + name: PROD_SLACK_REGISTRY_CHANNEL_WEBHOOK - name: PROJECT_ID value: dreamboothy # TODO(robinhuang): Switch to a list of strings - name: CORS_ORIGIN - value: https://comfyregistry.org \ No newline at end of file + value: https://comfyregistry.org + - name: SECRET_SCANNER_URL + valueFrom: + secretKeyRef: + key: 1 + name: SECURITY_SCANNER_CLOUD_FUNCTION_URL + - name: SECURITY_COUNCIL_DISCORD_WEBHOOK + valueFrom: + secretKeyRef: + key: 1 + name: SECURITY_COUNCIL_DISCORD_WEBHOOK + - name: ALGOLIA_APP_ID + valueFrom: + secretKeyRef: + key: 2 + name: PROD_ALGOLIA_APP_ID + - name: ALGOLIA_API_KEY + valueFrom: + secretKeyRef: + key: 2 + name: PROD_ALGOLIA_API_KEY diff --git a/run-service-staging.yaml b/run-service-staging.yaml index a882731..c3e76a4 100644 --- a/run-service-staging.yaml +++ b/run-service-staging.yaml @@ -24,8 +24,33 @@ spec: secretKeyRef: key: 1 name: STAGING_SUPABASE_CONNECTION_STRING + - name: JWT_SECRET + valueFrom: + secretKeyRef: + key: 1 + name: STAGING_JWT_SECRET - name: PROJECT_ID value: dreamboothy # TODO(robinhuang): Switch to a list of strings - name: CORS_ORIGIN - value: https://staging.comfyregistry.org \ No newline at end of file + value: https://staging.comfyregistry.org + - name: SECRET_SCANNER_URL + valueFrom: + secretKeyRef: + key: 1 + name: SECURITY_SCANNER_CLOUD_FUNCTION_URL + - name: SECURITY_COUNCIL_DISCORD_WEBHOOK + valueFrom: + secretKeyRef: + key: 1 + name: SECURITY_COUNCIL_DISCORD_WEBHOOK + - name: ALGOLIA_APP_ID + valueFrom: + secretKeyRef: + key: 1 + name: STAGING_ALGOLIA_APP_ID + - name: ALGOLIA_API_KEY + valueFrom: + secretKeyRef: + key: 1 + name: STAGING_ALGOLIA_API_KEY \ No newline at end of file diff --git a/scripts/README.md b/scripts/README.md new file mode 100644 index 0000000..3773e95 --- /dev/null +++ b/scripts/README.md @@ -0,0 +1,38 @@ +# Scripts + +## Create JWT Token + +This script is used to create JWT Token representing an existing user which can then be used to invoke the API Server as an alternative authentication method to firebase auth. + +```bash +export JWT_SECRET +go run ./scripts/create-jwt-token --user-id "" +``` + +Notes: + +1. You need to use the same values as the [JWT_SECRET](./../docker-compose.yml#L20) environment variables defined for the API Server. +2. By default, the token will [expire in 30 days](./create-jwt-token/main.go#L14), which can be overriden using `--expiry` flag. + +## Ban a Publisher + +This script is used to invoke ban publisher API using jwt token as authentication. + +```bash +go run ./scripts/ban-publisher \ + --base-url "" \ + --token "" \ + --publisher-id "" +``` + +## Ban a Node + +This script is used to invoke ban publisher API using jwt token as authentication. + +```bash +go run ./scripts/ban-node \ + --base-url "" \ + --token "" \ + --publisher-id "" + --node-id "" +``` diff --git a/scripts/ban-node/main.go b/scripts/ban-node/main.go new file mode 100644 index 0000000..98a0ba6 --- /dev/null +++ b/scripts/ban-node/main.go @@ -0,0 +1,59 @@ +package main + +import ( + "flag" + "fmt" + "io" + "log" + "net/http" + "net/url" + "os" +) + +var pubID = flag.String("publisher-id", "", "Pubisher ID of the node") +var nodeID = flag.String("node-id", "", "Node ID to be banned") +var token = flag.String("token", os.Getenv("JWT_TOKEN"), "JWT token to use for authentication") +var baseURL = flag.String("base-url", os.Getenv("REGISTRY_BASE_URL"), "Base url of registry service") + +func main() { + flag.Parse() + + if pubID == nil || *pubID == "" { + flag.PrintDefaults() + log.Fatalf("Flag `--publisher-id` must be set to non-empty string.\n") + } + if nodeID == nil || *nodeID == "" { + flag.PrintDefaults() + log.Fatalf("Flag `--node-id` must be set to non-empty string.\n") + } + if token == nil || *token == "" { + flag.PrintDefaults() + log.Fatalf("Flag `--token` or environment variable `JWT_TOKEN` must be set to non-empty string.\n") + } + if baseURL == nil || *baseURL == "" { + flag.PrintDefaults() + log.Fatalf("Flag `--base-url` or environment variable `BASE_URL` must be set to non-empty string.\n") + } + + u, err := url.Parse(*baseURL) + if err != nil { + log.Fatalf("Invalid base url :%v .\n", err) + } + u = u.JoinPath("publishers", *pubID, "nodes", *nodeID, "ban") + + req, _ := http.NewRequest(http.MethodPost, u.String(), nil) + req.Header.Set("Authorization", "Bearer "+*token) + res, err := http.DefaultClient.Do(req) + if err != nil { + log.Fatalf("Failed to send HTTP request :%v\n", err) + } + b, err := io.ReadAll(res.Body) + if err != nil { + log.Fatalf("Failed to read HTTP response :%v\n", err) + } + if res.StatusCode > 299 { + log.Fatalf("Received non-success response: \nStatus: %d\nBody: \n\n%v\n", res.StatusCode, string(b)) + } + fmt.Printf("Publisher's node '%s' has been banned\n", *pubID) + +} diff --git a/scripts/ban-publisher/main.go b/scripts/ban-publisher/main.go new file mode 100644 index 0000000..acff32c --- /dev/null +++ b/scripts/ban-publisher/main.go @@ -0,0 +1,54 @@ +package main + +import ( + "flag" + "fmt" + "io" + "log" + "net/http" + "net/url" + "os" +) + +var pubID = flag.String("publisher-id", "", "Pubisher ID that is represented by the JWT token") +var token = flag.String("token", os.Getenv("JWT_TOKEN"), "JWT token to use for authentication") +var baseURL = flag.String("base-url", os.Getenv("REGISTRY_BASE_URL"), "Base url of registry service") + +func main() { + flag.Parse() + + if pubID == nil || *pubID == "" { + flag.PrintDefaults() + log.Fatalf("Flag `--publisher-id` must be set to non-empty string.\n") + } + if token == nil || *token == "" { + flag.PrintDefaults() + log.Fatalf("Flag `--token` or environment variable `JWT_TOKEN` must be set to non-empty string.\n") + } + if baseURL == nil || *baseURL == "" { + flag.PrintDefaults() + log.Fatalf("Flag `--base-url` or environment variable `BASE_URL` must be set to non-empty string.\n") + } + + u, err := url.Parse(*baseURL) + if err != nil { + log.Fatalf("Invalid base url :%v .\n", err) + } + u = u.JoinPath("publishers", *pubID, "ban") + + req, _ := http.NewRequest(http.MethodPost, u.String(), nil) + req.Header.Set("Authorization", "Bearer "+*token) + res, err := http.DefaultClient.Do(req) + if err != nil { + log.Fatalf("Failed to send HTTP request :%v\n", err) + } + b, err := io.ReadAll(res.Body) + if err != nil { + log.Fatalf("Failed to read HTTP response :%v\n", err) + } + if res.StatusCode > 299 { + log.Fatalf("Received non-success response: \nStatus: %d\nBody: \n\n%v\n", res.StatusCode, string(b)) + } + fmt.Printf("Publisher '%s' has been banned\n", *pubID) + +} diff --git a/scripts/create-jwt-token/main.go b/scripts/create-jwt-token/main.go new file mode 100644 index 0000000..de57cbe --- /dev/null +++ b/scripts/create-jwt-token/main.go @@ -0,0 +1,40 @@ +package main + +import ( + "flag" + "fmt" + "log" + "os" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +var userID = flag.String("user-id", "", "User ID that is represented by the JWT token") +var expire = flag.Duration("expire", 30*24*time.Hour, "Expiry time of the JWT token") + +func main() { + secret, ok := os.LookupEnv("JWT_SECRET") + if !ok { + log.Fatalf("Environment variablel `JWT_SECRET` must be defined.\n") + } + + flag.Parse() + if userID == nil || *userID == "" { + flag.PrintDefaults() + log.Fatalf("Flag `--user-id` must be set to non-empty string.\n") + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{ + Subject: *userID, + ExpiresAt: jwt.NewNumericDate(time.Now().Add(*expire)), + NotBefore: jwt.NewNumericDate(time.Now()), + }) + + tokenString, err := token.SignedString([]byte(secret)) + if err != nil { + log.Fatalf("Fail to create jwt token: %v .\n", err) + } + + fmt.Println(tokenString) +} diff --git a/server/implementation/api.implementation.go b/server/implementation/api.implementation.go index bbccb9f..28e5c80 100644 --- a/server/implementation/api.implementation.go +++ b/server/implementation/api.implementation.go @@ -3,23 +3,24 @@ package implementation import ( "registry-backend/config" "registry-backend/ent" + "registry-backend/gateways/algolia" + "registry-backend/gateways/discord" gateway "registry-backend/gateways/slack" "registry-backend/gateways/storage" dripservices_comfyci "registry-backend/services/comfy_ci" - dripservices_registry "registry-backend/services/registry" + dripservices "registry-backend/services/registry" ) type DripStrictServerImplementation struct { Client *ent.Client ComfyCIService *dripservices_comfyci.ComfyCIService - RegistryService *dripservices_registry.RegistryService + RegistryService *dripservices.RegistryService } -func NewStrictServerImplementation(client *ent.Client, config *config.Config, storageService storage.StorageService, slackService gateway.SlackService) *DripStrictServerImplementation { - +func NewStrictServerImplementation(client *ent.Client, config *config.Config, storageService storage.StorageService, slackService gateway.SlackService, discordService discord.DiscordService, algolia algolia.AlgoliaService) *DripStrictServerImplementation { return &DripStrictServerImplementation{ Client: client, ComfyCIService: dripservices_comfyci.NewComfyCIService(config), - RegistryService: dripservices_registry.NewRegistryService(storageService, slackService), + RegistryService: dripservices.NewRegistryService(storageService, slackService, discordService, algolia, config), } } diff --git a/server/implementation/cicd.go b/server/implementation/cicd.go index 7e52045..5f4504f 100644 --- a/server/implementation/cicd.go +++ b/server/implementation/cicd.go @@ -3,9 +3,9 @@ package implementation import ( "context" "registry-backend/drip" - "registry-backend/ent" "registry-backend/ent/ciworkflowresult" "registry-backend/ent/gitcommit" + "registry-backend/mapper" "strings" "entgo.io/ent/dialect/sql" @@ -61,7 +61,7 @@ func (impl *DripStrictServerImplementation) GetGitcommit(ctx context.Context, re // Conditionally add the commitId filter if commitId != uuid.Nil { - log.Ctx(ctx).Info().Msgf("Filtering git commit by commit hash %s", commitId) + log.Ctx(ctx).Info().Msgf("Filtering git commit by db commit id %s", commitId) query.Where(ciworkflowresult.HasGitcommitWith(gitcommit.IDEQ(commitId))) } @@ -84,6 +84,7 @@ func (impl *DripStrictServerImplementation) GetGitcommit(ctx context.Context, re count, err := query.Count(ctx) log.Ctx(ctx).Info().Msgf("Got %d runs", count) if err != nil { + log.Ctx(ctx).Error().Msgf("Error retrieving count of git commits w/ err: %v", err) return drip.GetGitcommit500Response{}, err } @@ -103,43 +104,44 @@ func (impl *DripStrictServerImplementation) GetGitcommit(ctx context.Context, re // Execute the query runs, err := query.All(ctx) if err != nil { + log.Ctx(ctx).Error().Msgf("Error retrieving git commits w/ err: %v", err) return drip.GetGitcommit500Response{}, err } - results := mapRunsToResponse(runs) + results, err := mapper.CiWorkflowResultsToActionJobResults(runs) + if err != nil { + log.Ctx(ctx).Error().Msgf("Error mapping git commits to action job results w/ err: %v", err) + return drip.GetGitcommit500Response{}, err + } + log.Ctx(ctx).Info().Msgf("Git commits retrieved successfully") return drip.GetGitcommit200JSONResponse{ JobResults: &results, TotalNumberOfPages: &numberOfPages, }, nil } -func mapRunsToResponse(results []*ent.CIWorkflowResult) []drip.ActionJobResult { - var jobResultsData []drip.ActionJobResult - - for _, result := range results { - storageFileData := drip.StorageFile{ - PublicUrl: &result.Edges.StorageFile.FileURL, - } - commitId := result.Edges.Gitcommit.ID.String() - commitUnixTime := result.Edges.Gitcommit.CommitTimestamp.Unix() - jobResultData := drip.ActionJobResult{ - WorkflowName: &result.WorkflowName, - OperatingSystem: &result.OperatingSystem, - GpuType: &result.GpuType, - PytorchVersion: &result.PytorchVersion, - StorageFile: &storageFileData, - CommitHash: &result.Edges.Gitcommit.CommitHash, - CommitId: &commitId, - CommitTime: &commitUnixTime, - CommitMessage: &result.Edges.Gitcommit.CommitMessage, - GitRepo: &result.Edges.Gitcommit.RepoName, - ActionRunId: &result.RunID, - StartTime: &result.StartTime, - EndTime: &result.EndTime, - } - jobResultsData = append(jobResultsData, jobResultData) - } - return jobResultsData +func (impl *DripStrictServerImplementation) GetWorkflowResult(ctx context.Context, request drip.GetWorkflowResultRequestObject) (drip.GetWorkflowResultResponseObject, error) { + log.Ctx(ctx).Info().Msgf("Getting workflow result with ID %s", request.WorkflowResultId) + workflowId := uuid.MustParse(request.WorkflowResultId) + workflow, err := impl.Client.CIWorkflowResult.Query().WithGitcommit().WithStorageFile().Where(ciworkflowresult.IDEQ(workflowId)).First(ctx) + + if err != nil { + log.Ctx(ctx).Error().Msgf("Error retrieving workflow result w/ err: %v", err) + return drip.GetWorkflowResult500JSONResponse{ + Message: err.Error(), + }, nil + } + + result, err := mapper.CiWorkflowResultToActionJobResult(workflow) + if err != nil { + log.Ctx(ctx).Error().Msgf("Error mapping workflow result to action job result w/ err: %v", err) + return drip.GetWorkflowResult500JSONResponse{ + Message: err.Error(), + }, nil + } + + log.Ctx(ctx).Info().Msgf("Workflow result retrieved successfully") + return drip.GetWorkflowResult200JSONResponse(*result), nil } func (impl *DripStrictServerImplementation) GetBranch(ctx context.Context, request drip.GetBranchRequestObject) (drip.GetBranchResponseObject, error) { @@ -151,17 +153,21 @@ func (impl *DripStrictServerImplementation) GetBranch(ctx context.Context, reque GroupBy(gitcommit.FieldBranchName). Strings(ctx) if err != nil { + log.Ctx(ctx).Error().Msgf("Error retrieving git's branchs w/ err: %v", err) return drip.GetBranch500Response{}, err } + log.Ctx(ctx).Info().Msgf("Git branches from '%s' repo retrieved successfully", request.Params.RepoName) return drip.GetBranch200JSONResponse{Branches: &branches}, nil } func (impl *DripStrictServerImplementation) PostUploadArtifact(ctx context.Context, request drip.PostUploadArtifactRequestObject) (drip.PostUploadArtifactResponseObject, error) { err := impl.ComfyCIService.ProcessCIRequest(ctx, impl.Client, &request) if err != nil { - log.Ctx(ctx).Error().Err(err).Msg("failed to process CI request") + log.Ctx(ctx).Error().Msgf("Error processiong CI request w/ err: %v", err) return drip.PostUploadArtifact500Response{}, err } + + log.Ctx(ctx).Info().Msgf("CI request with job id '%s' processed successfully", request.Body.JobId) return drip.PostUploadArtifact200JSONResponse{}, nil } diff --git a/server/implementation/registry.go b/server/implementation/registry.go index 7aa22ff..ba60887 100644 --- a/server/implementation/registry.go +++ b/server/implementation/registry.go @@ -8,6 +8,7 @@ import ( "registry-backend/ent/schema" "registry-backend/mapper" drip_services "registry-backend/services/registry" + "time" "github.com/google/uuid" "github.com/mixpanel/mixpanel-go" @@ -181,31 +182,6 @@ func (s *DripStrictServerImplementation) GetPublisher( func (s *DripStrictServerImplementation) UpdatePublisher( ctx context.Context, request drip.UpdatePublisherRequestObject) (drip.UpdatePublisherResponseObject, error) { log.Ctx(ctx).Info().Msgf("UpdatePublisher called with publisher ID: %s", request.PublisherId) - userId, err := mapper.GetUserIDFromContext(ctx) - if err != nil { - log.Ctx(ctx).Error().Msgf("Failed to get user ID from context w/ err: %v", err) - return drip.UpdatePublisher400JSONResponse{Message: "Invalid user ID"}, err - } - - log.Ctx(ctx).Info().Msgf("Checking if user ID %s has permission to update publisher ID %s", userId, request.PublisherId) - err = s.RegistryService.AssertPublisherPermissions( - ctx, s.Client, request.PublisherId, userId, []schema.PublisherPermissionType{schema.PublisherPermissionTypeOwner}) - switch { - case ent.IsNotFound(err): - log.Ctx(ctx).Info().Msgf("Publisher with ID %s not found", request.PublisherId) - return drip.UpdatePublisher404JSONResponse{Message: "Publisher not found"}, nil - - case drip_services.IsPermissionError(err): - log.Ctx(ctx).Error().Msgf("Permission denied for user ID %s on "+ - "publisher ID %s w/ err: %v", userId, request.PublisherId, err) - return drip.UpdatePublisher401Response{}, err - - case err != nil: - log.Ctx(ctx).Error().Msgf( - "Failed to assert publisher permission %s w/ err: %v", request.PublisherId, err) - return drip.UpdatePublisher500JSONResponse{ - Message: "Failed to assert publisher permission", Error: err.Error()}, err - } log.Ctx(ctx).Info().Msgf("Updating publisher with ID %s", request.PublisherId) updateOne := mapper.ApiUpdatePublisherToUpdateFields(request.PublisherId, request.Body, s.Client) @@ -222,34 +198,13 @@ func (s *DripStrictServerImplementation) UpdatePublisher( func (s *DripStrictServerImplementation) CreateNode( ctx context.Context, request drip.CreateNodeRequestObject) (drip.CreateNodeResponseObject, error) { log.Ctx(ctx).Info().Msgf("CreateNode called with publisher ID: %s", request.PublisherId) - userId, err := mapper.GetUserIDFromContext(ctx) - if err != nil { - log.Ctx(ctx).Error().Msgf("Failed to get user ID from context w/ err: %v", err) - return drip.CreateNode400JSONResponse{Message: "Invalid user ID"}, err - } - log.Ctx(ctx).Info().Msgf( - "Checking if user ID %s has permission to create node for publisher ID %s", userId, request.PublisherId) - err = s.RegistryService.AssertPublisherPermissions( - ctx, s.Client, request.PublisherId, userId, []schema.PublisherPermissionType{schema.PublisherPermissionTypeOwner}) - switch { - case ent.IsNotFound(err): - log.Ctx(ctx).Info().Msgf("Publisher with ID %s not found", request.PublisherId) - return drip.CreateNode400JSONResponse{Message: "Publisher not found"}, nil - - case drip_services.IsPermissionError(err): - log.Ctx(ctx).Error().Msgf( - "Permission denied for user ID %s on publisher ID %s w/ err: %v", userId, request.PublisherId, err) - return drip.CreateNode401Response{}, err - - case err != nil: + node, err := s.RegistryService.CreateNode(ctx, s.Client, request.PublisherId, request.Body) + if mapper.IsErrorBadRequest(err) || ent.IsConstraintError(err) { log.Ctx(ctx).Error().Msgf( - "Failed to assert publisher permission %s w/ err: %v", request.PublisherId, err) - return drip.CreateNode500JSONResponse{ - Message: "Failed to assert publisher permission", Error: err.Error()}, err + "Failed to create node for publisher ID %s w/ err: %v", request.PublisherId, err) + return drip.CreateNode400JSONResponse{Message: "The node already exists", Error: err.Error()}, err } - - node, err := s.RegistryService.CreateNode(ctx, s.Client, request.PublisherId, request.Body) if err != nil { log.Ctx(ctx).Error().Msgf("Failed to create node for publisher ID %s w/ err: %v", request.PublisherId, err) return drip.CreateNode500JSONResponse{Message: "Failed to create node", Error: err.Error()}, err @@ -290,6 +245,7 @@ func (s *DripStrictServerImplementation) ListNodesForPublisher( func (s *DripStrictServerImplementation) ListAllNodes( ctx context.Context, request drip.ListAllNodesRequestObject) (drip.ListAllNodesResponseObject, error) { + log.Ctx(ctx).Info().Msg("ListAllNodes request received") // Set default values for pagination parameters @@ -302,13 +258,20 @@ func (s *DripStrictServerImplementation) ListAllNodes( limit = *request.Params.Limit } + // Initialize the node filter + filter := &drip_services.NodeFilter{} + if request.Params.IncludeBanned != nil { + filter.IncludeBanned = *request.Params.IncludeBanned + } + // List nodes from the registry service - nodeResults, err := s.RegistryService.ListNodes(ctx, s.Client, page, limit, &drip_services.NodeFilter{}) + nodeResults, err := s.RegistryService.ListNodes(ctx, s.Client, page, limit, filter) if err != nil { log.Ctx(ctx).Error().Msgf("Failed to list nodes w/ err: %v", err) return drip.ListAllNodes500JSONResponse{Message: "Failed to list nodes", Error: err.Error()}, err } + // Handle case when no nodes are found if len(nodeResults.Nodes) == 0 { log.Ctx(ctx).Info().Msg("No nodes found") return drip.ListAllNodes200JSONResponse{ @@ -320,13 +283,17 @@ func (s *DripStrictServerImplementation) ListAllNodes( }, nil } + // Convert database nodes to API nodes apiNodes := make([]drip.Node, 0, len(nodeResults.Nodes)) for _, dbNode := range nodeResults.Nodes { apiNode := mapper.DbNodeToApiNode(dbNode) - if dbNode.Edges.Versions != nil && len(dbNode.Edges.Versions) > 0 { - latestVersion := dbNode.Edges.Versions[0] - apiNode.LatestVersion = mapper.DbNodeVersionToApiNodeVersion(latestVersion) + + // attach information of latest version if available + if len(dbNode.Edges.Versions) > 0 { + apiNode.LatestVersion = mapper.DbNodeVersionToApiNodeVersion(dbNode.Edges.Versions[0]) } + + // Map publisher information apiNode.Publisher = mapper.DbPublisherToApiPublisher(dbNode.Edges.Publisher, false) apiNodes = append(apiNodes, *apiNode) } @@ -341,51 +308,77 @@ func (s *DripStrictServerImplementation) ListAllNodes( }, nil } -func (s *DripStrictServerImplementation) DeleteNode( - ctx context.Context, request drip.DeleteNodeRequestObject) (drip.DeleteNodeResponseObject, error) { +// SearchNodes implements drip.StrictServerInterface. +func (s *DripStrictServerImplementation) SearchNodes(ctx context.Context, request drip.SearchNodesRequestObject) (drip.SearchNodesResponseObject, error) { + log.Ctx(ctx).Info().Msg("SearchNodes request received") - log.Ctx(ctx).Info().Msgf("DeleteNode request received for node ID: %s", request.NodeId) + // Set default values for pagination parameters + page := 1 + if request.Params.Page != nil { + page = *request.Params.Page + } + limit := 10 + if request.Params.Limit != nil { + limit = *request.Params.Limit + } - userId, err := mapper.GetUserIDFromContext(ctx) + f := &drip_services.NodeFilter{} + if request.Params.Search != nil { + f.Search = *request.Params.Search + } + if request.Params.IncludeBanned != nil { + f.IncludeBanned = *request.Params.IncludeBanned + } + // List nodes from the registry service + nodeResults, err := s.RegistryService.ListNodes(ctx, s.Client, page, limit, f) if err != nil { - log.Ctx(ctx).Error().Msgf("Failed to get user ID from context w/ err: %v", err) - return drip.DeleteNode404JSONResponse{Message: "Invalid user ID"}, err + log.Ctx(ctx).Error().Msgf("Failed to search nodes w/ err: %v", err) + return drip.SearchNodes500JSONResponse{Message: "Failed to search nodes", Error: err.Error()}, err } - err = s.RegistryService.AssertPublisherPermissions( - ctx, s.Client, request.PublisherId, userId, []schema.PublisherPermissionType{schema.PublisherPermissionTypeOwner}) - switch { - case ent.IsNotFound(err): - log.Ctx(ctx).Info().Msgf("Publisher with ID %s not found", request.PublisherId) - return drip.DeleteNode404JSONResponse{Message: "Publisher not found"}, nil - - case drip_services.IsPermissionError(err): - log.Ctx(ctx).Error().Msgf( - "Permission denied for user ID %s on publisher ID %s w/ err: %v", userId, request.PublisherId, err) - return drip.DeleteNode403JSONResponse{}, err + if len(nodeResults.Nodes) == 0 { + log.Ctx(ctx).Info().Msg("No nodes found") + return drip.SearchNodes200JSONResponse{ + Nodes: &[]drip.Node{}, + Total: &nodeResults.Total, + Page: &nodeResults.Page, + Limit: &nodeResults.Limit, + TotalPages: &nodeResults.TotalPages, + }, nil + } - case err != nil: - log.Ctx(ctx).Error().Msgf("Failed to assert publisher permission %s w/ err: %v", request.PublisherId, err) - return drip.DeleteNode500JSONResponse{Message: "Failed to assert publisher permission", Error: err.Error()}, err + apiNodes := make([]drip.Node, 0, len(nodeResults.Nodes)) + for _, dbNode := range nodeResults.Nodes { + apiNode := mapper.DbNodeToApiNode(dbNode) + if dbNode.Edges.Versions != nil && len(dbNode.Edges.Versions) > 0 { + latestVersion, err := s.RegistryService.GetLatestNodeVersion(ctx, s.Client, dbNode.ID) + if err == nil { + apiNode.LatestVersion = mapper.DbNodeVersionToApiNodeVersion(latestVersion) + } else { + log.Ctx(ctx).Error().Msgf("Failed to get latest version for node %s w/ err: %v", dbNode.ID, err) + } + } + apiNode.Publisher = mapper.DbPublisherToApiPublisher(dbNode.Edges.Publisher, false) + apiNodes = append(apiNodes, *apiNode) } - err = s.RegistryService.AssertNodeBelongsToPublisher(ctx, s.Client, request.PublisherId, request.NodeId) - switch { - case ent.IsNotFound(err): - log.Ctx(ctx).Info().Msgf("Publisher with ID %s not found", request.PublisherId) - return drip.DeleteNode404JSONResponse{Message: "Publisher not found"}, nil + log.Ctx(ctx).Info().Msgf("Found %d nodes", len(apiNodes)) + return drip.SearchNodes200JSONResponse{ + Nodes: &apiNodes, + Total: &nodeResults.Total, + Page: &nodeResults.Page, + Limit: &nodeResults.Limit, + TotalPages: &nodeResults.TotalPages, + }, nil +} - case drip_services.IsPermissionError(err): - log.Ctx(ctx).Error().Msgf( - "Permission denied for user ID %s on publisher ID %s w/ err: %v", userId, request.PublisherId, err) - return drip.DeleteNode403JSONResponse{}, err +func (s *DripStrictServerImplementation) DeleteNode( + ctx context.Context, request drip.DeleteNodeRequestObject) (drip.DeleteNodeResponseObject, error) { - case err != nil: - return drip.DeleteNode500JSONResponse{Message: "Failed to assert publisher permission"}, err - } + log.Ctx(ctx).Info().Msgf("DeleteNode request received for node ID: %s", request.NodeId) - err = s.RegistryService.DeleteNode(ctx, s.Client, request.NodeId) - if err != nil { + err := s.RegistryService.DeleteNode(ctx, s.Client, request.NodeId) + if err != nil && !ent.IsNotFound(err) { log.Ctx(ctx).Error().Msgf("Failed to delete node %s w/ err: %v", request.NodeId, err) return drip.DeleteNode500JSONResponse{Message: "Internal server error"}, err } @@ -427,41 +420,14 @@ func (s *DripStrictServerImplementation) UpdateNode( log.Ctx(ctx).Info().Msgf("UpdateNode request received for node ID: %s", request.NodeId) - userId, err := mapper.GetUserIDFromContext(ctx) - if err != nil { - log.Ctx(ctx).Error().Msgf("Failed to get user ID from context w/ err: %v", err) - return drip.UpdateNode404JSONResponse{Message: "Invalid user ID"}, err + updateOneFunc := func(client *ent.Client) *ent.NodeUpdateOne { + return mapper.ApiUpdateNodeToUpdateFields(request.NodeId, request.Body, client) } - - err = s.RegistryService.AssertPublisherPermissions( - ctx, s.Client, request.PublisherId, userId, []schema.PublisherPermissionType{schema.PublisherPermissionTypeOwner}) - switch { - case ent.IsNotFound(err): - log.Ctx(ctx).Info().Msgf("Publisher with ID %s not found", request.PublisherId) - return drip.UpdateNode404JSONResponse{Message: "Publisher not found"}, nil - - case drip_services.IsPermissionError(err): - log.Ctx(ctx).Error().Msgf( - "Permission denied for user ID %s on publisher ID %s w/ err: %v", userId, request.PublisherId, err) - return drip.UpdateNode403JSONResponse{}, err - - case err != nil: - log.Ctx(ctx).Error().Msgf("Failed to assert publisher permission %s w/ err: %v", request.PublisherId, err) - return drip.UpdateNode500JSONResponse{Message: "Failed to assert publisher permission", Error: err.Error()}, err - } - - err = s.RegistryService.AssertNodeBelongsToPublisher(ctx, s.Client, request.PublisherId, request.NodeId) + updatedNode, err := s.RegistryService.UpdateNode(ctx, s.Client, updateOneFunc) if ent.IsNotFound(err) { log.Ctx(ctx).Error().Msgf("Node %s not found w/ err: %v", request.NodeId, err) return drip.UpdateNode404JSONResponse{Message: "Not Found"}, nil - } else if err != nil { - log.Ctx(ctx).Error().Msgf("Node %s does not belong to publisher "+ - "%s w/ err: %v", request.NodeId, request.PublisherId, err) - return drip.UpdateNode403JSONResponse{Message: "Forbidden"}, err } - - updateOne := mapper.ApiUpdateNodeToUpdateFields(request.NodeId, request.Body, s.Client) - updatedNode, err := s.RegistryService.UpdateNode(ctx, s.Client, updateOne) if err != nil { log.Ctx(ctx).Error().Msgf("Failed to update node %s w/ err: %v", request.NodeId, err) return drip.UpdateNode500JSONResponse{Message: "Failed to update node", Error: err.Error()}, err @@ -473,15 +439,19 @@ func (s *DripStrictServerImplementation) UpdateNode( func (s *DripStrictServerImplementation) ListNodeVersions( ctx context.Context, request drip.ListNodeVersionsRequestObject) (drip.ListNodeVersionsResponseObject, error) { - log.Ctx(ctx).Info().Msgf("ListNodeVersions request received for node ID: %s", request.NodeId) - nodeVersions, err := s.RegistryService.ListNodeVersions(ctx, s.Client, request.NodeId) + apiStatus := mapper.ApiNodeVersionStatusesToDbNodeVersionStatuses(request.Params.Statuses) + + nodeVersionsResult, err := s.RegistryService.ListNodeVersions(ctx, s.Client, &drip_services.NodeVersionFilter{ + NodeId: request.NodeId, + Status: apiStatus, + }) if err != nil { log.Ctx(ctx).Error().Msgf("Failed to list node versions for node %s w/ err: %v", request.NodeId, err) return drip.ListNodeVersions500JSONResponse{Message: "Failed to list node versions", Error: err.Error()}, err } - + nodeVersions := nodeVersionsResult.NodeVersions apiNodeVersions := make([]drip.NodeVersion, 0, len(nodeVersions)) for _, dbNodeVersion := range nodeVersions { apiNodeVersions = append(apiNodeVersions, *mapper.DbNodeVersionToApiNodeVersion(dbNodeVersion)) @@ -495,18 +465,6 @@ func (s *DripStrictServerImplementation) PublishNodeVersion( ctx context.Context, request drip.PublishNodeVersionRequestObject) (drip.PublishNodeVersionResponseObject, error) { log.Ctx(ctx).Info().Msgf("PublishNodeVersion request received for node ID: %s", request.NodeId) - tokenValid, err := s.RegistryService.IsPersonalAccessTokenValidForPublisher( - ctx, s.Client, request.PublisherId, request.Body.PersonalAccessToken) - if err != nil { - log.Ctx(ctx).Error().Msgf("Token validation failed w/ err: %v", err) - return drip.PublishNodeVersion400JSONResponse{Message: "Failed to validate token", Error: err.Error()}, nil - } - if !tokenValid { - errMessage := "Invalid personal access token" - log.Ctx(ctx).Error().Msg(errMessage) - return drip.PublishNodeVersion400JSONResponse{Message: errMessage}, nil - } - // Check if node exists, create if not node, err := s.RegistryService.GetNode(ctx, s.Client, request.NodeId) if err != nil && !ent.IsNotFound(err) { @@ -515,6 +473,10 @@ func (s *DripStrictServerImplementation) PublishNodeVersion( return drip.PublishNodeVersion500JSONResponse{}, err } else if err != nil { node, err = s.RegistryService.CreateNode(ctx, s.Client, request.PublisherId, &request.Body.Node) + if mapper.IsErrorBadRequest(err) || ent.IsConstraintError(err) { + log.Ctx(ctx).Error().Msgf("Node creation failed w/ err: %v", err) + return drip.PublishNodeVersion400JSONResponse{Message: "Failed to create node", Error: err.Error()}, nil + } if err != nil { log.Ctx(ctx).Error().Msgf("Node creation failed w/ err: %v", err) return drip.PublishNodeVersion500JSONResponse{Message: "Failed to create node", Error: err.Error()}, nil @@ -524,14 +486,10 @@ func (s *DripStrictServerImplementation) PublishNodeVersion( } else { // TODO(james): distinguish between not found vs. nodes that belong to other publishers // If node already exists, validate ownership - err = s.RegistryService.AssertNodeBelongsToPublisher(ctx, s.Client, request.PublisherId, node.ID) - if err != nil { - errMessage := "Node does not belong to Publisher." - log.Ctx(ctx).Error().Msgf("Node ownership validation failed w/ err: %v", err) - return drip.PublishNodeVersion403JSONResponse{Message: errMessage}, err + updateOneFunc := func(client *ent.Client) *ent.NodeUpdateOne { + return mapper.ApiUpdateNodeToUpdateFields(node.ID, &request.Body.Node, s.Client) } - updateOne := mapper.ApiUpdateNodeToUpdateFields(node.ID, &request.Body.Node, s.Client) - _, err = s.RegistryService.UpdateNode(ctx, s.Client, updateOne) + _, err = s.RegistryService.UpdateNode(ctx, s.Client, updateOneFunc) if err != nil { errMessage := "Failed to update node: " + err.Error() log.Ctx(ctx).Error().Msgf("Node update failed w/ err: %v", err) @@ -543,6 +501,9 @@ func (s *DripStrictServerImplementation) PublishNodeVersion( // Create node version nodeVersionCreation, err := s.RegistryService.CreateNodeVersion(ctx, s.Client, request.PublisherId, node.ID, &request.Body.NodeVersion) if err != nil { + if ent.IsConstraintError(err) { + return drip.PublishNodeVersion400JSONResponse{Message: "The node version already exists"}, nil + } errMessage := "Failed to create node version: " + err.Error() log.Ctx(ctx).Error().Msgf("Node version creation failed w/ err: %v", err) return drip.PublishNodeVersion400JSONResponse{Message: errMessage}, err @@ -562,51 +523,13 @@ func (s *DripStrictServerImplementation) UpdateNodeVersion( log.Ctx(ctx).Info().Msgf("UpdateNodeVersion request received for node ID: "+ "%s, version ID: %s", request.NodeId, request.VersionId) - // Retrieve user ID from context - userId, err := mapper.GetUserIDFromContext(ctx) - if err != nil { - log.Ctx(ctx).Error().Msgf("Failed to get user ID from context w/ err: %v", err) - return drip.UpdateNodeVersion404JSONResponse{Message: "Invalid user ID"}, err - } - - // Assert publisher permissions - err = s.RegistryService.AssertPublisherPermissions( - ctx, s.Client, request.PublisherId, userId, []schema.PublisherPermissionType{schema.PublisherPermissionTypeOwner}) - switch { - case ent.IsNotFound(err): - log.Ctx(ctx).Info().Msgf("Publisher with ID %s not found", request.PublisherId) - return drip.UpdateNodeVersion404JSONResponse{Message: "Publisher not found"}, nil - - case drip_services.IsPermissionError(err): - log.Ctx(ctx).Error().Msgf( - "Permission denied for user ID %s on publisher ID %s w/ err: %v", userId, request.PublisherId, err) - return drip.UpdateNodeVersion403JSONResponse{}, err - - case err != nil: - log.Ctx(ctx).Error().Msgf("Failed to assert publisher permission %s w/ err: %v", request.PublisherId, err) - return drip.UpdateNodeVersion500JSONResponse{Message: "Failed to assert publisher permission", Error: err.Error()}, err - } - - // Assert node belongs to publisher - err = s.RegistryService.AssertNodeBelongsToPublisher(ctx, s.Client, request.PublisherId, request.NodeId) - switch { - case ent.IsNotFound(err): - log.Ctx(ctx).Info().Msgf("Publisher with ID %s not found", request.PublisherId) - return drip.UpdateNodeVersion404JSONResponse{Message: "Publisher not found"}, nil - - case drip_services.IsPermissionError(err): - errMessage := "Node does not belong to Publisher." - log.Ctx(ctx).Error().Msgf("Node ownership validation failed w/ err: %v", err) - return drip.UpdateNodeVersion404JSONResponse{Message: errMessage}, err - - case err != nil: - log.Ctx(ctx).Error().Msgf("Failed to assert publisher permission %s w/ err: %v", request.PublisherId, err) - return drip.UpdateNodeVersion500JSONResponse{Message: "Failed to assert publisher permission", Error: err.Error()}, err - } - // Update node version updateOne := mapper.ApiUpdateNodeVersionToUpdateFields(request.VersionId, request.Body, s.Client) version, err := s.RegistryService.UpdateNodeVersion(ctx, s.Client, updateOne) + if ent.IsNotFound(err) { + log.Ctx(ctx).Error().Msgf("Node %s or it's version not found w/ err: %v", request.NodeId, err) + return drip.UpdateNodeVersion404JSONResponse{Message: "Not Found"}, nil + } if err != nil { errMessage := "Failed to update node version" log.Ctx(ctx).Error().Msgf("Node version update failed w/ err: %v", err) @@ -620,6 +543,34 @@ func (s *DripStrictServerImplementation) UpdateNodeVersion( }, nil } +// PostNodeVersionReview implements drip.StrictServerInterface. +func (s *DripStrictServerImplementation) PostNodeReview(ctx context.Context, request drip.PostNodeReviewRequestObject) (drip.PostNodeReviewResponseObject, error) { + log.Ctx(ctx).Info().Msgf("PostNodeReview request received for "+ + "node ID: %s", request.NodeId) + + if request.Params.Star < 1 || request.Params.Star > 5 { + log.Ctx(ctx).Error().Msgf("Invalid star received: %d", request.Params.Star) + return drip.PostNodeReview400Response{}, nil + } + + userId, err := mapper.GetUserIDFromContext(ctx) + if err != nil { + log.Ctx(ctx).Error().Msgf("Failed to get user ID from context w/ err: %v", err) + return drip.PostNodeReview404JSONResponse{}, err + } + + nv, err := s.RegistryService.AddNodeReview(ctx, s.Client, request.NodeId, userId, request.Params.Star) + if ent.IsNotFound(err) { + log.Ctx(ctx).Error().Msgf("Error retrieving node version w/ err: %v", err) + return drip.PostNodeReview404JSONResponse{}, nil + } + + node := mapper.DbNodeToApiNode(nv) + log.Ctx(ctx).Info().Msgf("Node review for %s stored successfully", request.NodeId) + return drip.PostNodeReview200JSONResponse(*node), nil + +} + func (s *DripStrictServerImplementation) DeleteNodeVersion( ctx context.Context, request drip.DeleteNodeVersionRequestObject) (drip.DeleteNodeVersionResponseObject, error) { log.Ctx(ctx).Info().Msgf("DeleteNodeVersion request received for node ID: "+ @@ -638,7 +589,7 @@ func (s *DripStrictServerImplementation) GetNodeVersion( log.Ctx(ctx).Info().Msgf("GetNodeVersion request received for "+ "node ID: %s, version ID: %s", request.NodeId, request.VersionId) - nodeVersion, err := s.RegistryService.GetNodeVersion(ctx, s.Client, request.NodeId, request.VersionId) + nodeVersion, err := s.RegistryService.GetNodeVersionByVersion(ctx, s.Client, request.NodeId, request.VersionId) if ent.IsNotFound(err) { log.Ctx(ctx).Error().Msgf("Error retrieving node version w/ err: %v", err) return drip.GetNodeVersion404JSONResponse{}, nil @@ -661,22 +612,6 @@ func (s *DripStrictServerImplementation) ListPersonalAccessTokens( ctx context.Context, request drip.ListPersonalAccessTokensRequestObject) (drip.ListPersonalAccessTokensResponseObject, error) { log.Ctx(ctx).Info().Msgf("ListPersonalAccessTokens request received for publisher ID: %s", request.PublisherId) - // Retrieve user ID from context - userId, err := mapper.GetUserIDFromContext(ctx) - if err != nil { - log.Ctx(ctx).Error().Msgf("Failed to get user ID from context w/ err: %v", err) - return drip.ListPersonalAccessTokens404JSONResponse{Message: "Invalid user ID"}, err - } - - // Assert publisher permissions - err = s.RegistryService.AssertPublisherPermissions( - ctx, s.Client, request.PublisherId, userId, []schema.PublisherPermissionType{schema.PublisherPermissionTypeOwner}) - if err != nil { - errMessage := "User does not have the necessary permissions: " + err.Error() - log.Ctx(ctx).Error().Msgf("Permission assertion failed w/ err: %v", err) - return drip.ListPersonalAccessTokens403JSONResponse{Message: errMessage}, err - } - // List personal access tokens personalAccessTokens, err := s.RegistryService.ListPersonalAccessTokens(ctx, s.Client, request.PublisherId) if err != nil { @@ -701,31 +636,6 @@ func (s *DripStrictServerImplementation) CreatePersonalAccessToken( log.Ctx(ctx).Info().Msgf("CreatePersonalAccessToken request received "+ "for publisher ID: %s", request.PublisherId) - // Retrieve user ID from context - userId, err := mapper.GetUserIDFromContext(ctx) - if err != nil { - log.Ctx(ctx).Error().Msgf("Failed to get user ID from context w/ err: %v", err) - return drip.CreatePersonalAccessToken400JSONResponse{Message: "Invalid user ID"}, err - } - - // Assert publisher permissions - err = s.RegistryService.AssertPublisherPermissions(ctx, s.Client, request.PublisherId, - userId, []schema.PublisherPermissionType{schema.PublisherPermissionTypeOwner}) - switch { - case ent.IsNotFound(err): - log.Ctx(ctx).Info().Msgf("Publisher with ID %s not found", request.PublisherId) - return drip.CreatePersonalAccessToken400JSONResponse{Message: "Publisher not found"}, nil - - case drip_services.IsPermissionError(err): - log.Ctx(ctx).Error().Msgf( - "Permission denied for user ID %s on publisher ID %s w/ err: %v", userId, request.PublisherId, err) - return drip.CreatePersonalAccessToken403JSONResponse{}, err - - case err != nil: - log.Ctx(ctx).Error().Msgf("Failed to assert publisher permission %s w/ err: %v", request.PublisherId, err) - return drip.CreatePersonalAccessToken500JSONResponse{Message: "Failed to assert publisher permission", Error: err.Error()}, err - } - // Create personal access token description := "" if request.Body.Description != nil { @@ -759,24 +669,6 @@ func (s *DripStrictServerImplementation) DeletePersonalAccessToken( return drip.DeletePersonalAccessToken404JSONResponse{Message: "Invalid user ID"}, err } - // Assert publisher permissions - err = s.RegistryService.AssertPublisherPermissions( - ctx, s.Client, request.PublisherId, userId, []schema.PublisherPermissionType{schema.PublisherPermissionTypeOwner}) - switch { - case ent.IsNotFound(err): - log.Ctx(ctx).Info().Msgf("Publisher with ID %s not found", request.PublisherId) - return drip.DeletePersonalAccessToken404JSONResponse{Message: "Publisher not found"}, nil - - case drip_services.IsPermissionError(err): - log.Ctx(ctx).Error().Msgf( - "Permission denied for user ID %s on publisher ID %s w/ err: %v", userId, request.PublisherId, err) - return drip.DeletePersonalAccessToken403JSONResponse{}, err - - case err != nil: - log.Ctx(ctx).Error().Msgf("Failed to assert publisher permission %s w/ err: %v", request.PublisherId, err) - return drip.DeletePersonalAccessToken500JSONResponse{Message: "Failed to assert publisher permission", Error: err.Error()}, err - } - // Assert access token belongs to publisher err = s.RegistryService.AssertAccessTokenBelongsToPublisher(ctx, s.Client, request.PublisherId, uuid.MustParse(request.TokenId)) switch { @@ -813,7 +705,7 @@ func (s *DripStrictServerImplementation) InstallNode( log.Ctx(ctx).Info().Msgf("InstallNode request received for node ID: %s", request.NodeId) // Get node - _, err := s.RegistryService.GetNode(ctx, s.Client, request.NodeId) + node, err := s.RegistryService.GetNode(ctx, s.Client, request.NodeId) if ent.IsNotFound(err) { log.Ctx(ctx).Error().Msgf("Error retrieving node w/ err: %v", err) return drip.InstallNode404JSONResponse{Message: "Node not found"}, nil @@ -835,6 +727,12 @@ func (s *DripStrictServerImplementation) InstallNode( log.Ctx(ctx).Error().Msgf("Error retrieving latest node version w/ err: %v", err) return drip.InstallNode500JSONResponse{Message: errMessage}, err } + _, err = s.RegistryService.RecordNodeInstalation(ctx, s.Client, node) + if err != nil { + errMessage := "Failed to get increment number of node version install: " + err.Error() + log.Ctx(ctx).Error().Msgf("Error incrementing number of latest node version install w/ err: %v", err) + return drip.InstallNode500JSONResponse{Message: errMessage}, err + } mp.Track(ctx, []*mixpanel.Event{ mp.NewEvent("Install Node Latest", "", map[string]any{ "Node ID": request.NodeId, @@ -845,7 +743,7 @@ func (s *DripStrictServerImplementation) InstallNode( *mapper.DbNodeVersionToApiNodeVersion(nodeVersion), ), nil } else { - nodeVersion, err := s.RegistryService.GetNodeVersion(ctx, s.Client, request.NodeId, *request.Params.Version) + nodeVersion, err := s.RegistryService.GetNodeVersionByVersion(ctx, s.Client, request.NodeId, *request.Params.Version) if ent.IsNotFound(err) { log.Ctx(ctx).Error().Msgf("Error retrieving node version w/ err: %v", err) return drip.InstallNode404JSONResponse{Message: "Not found"}, nil @@ -855,6 +753,12 @@ func (s *DripStrictServerImplementation) InstallNode( log.Ctx(ctx).Error().Msgf("Error retrieving node version w/ err: %v", err) return drip.InstallNode500JSONResponse{Message: errMessage}, err } + _, err = s.RegistryService.RecordNodeInstalation(ctx, s.Client, node) + if err != nil { + errMessage := "Failed to get increment number of node version install: " + err.Error() + log.Ctx(ctx).Error().Msgf("Error incrementing number of latest node version install w/ err: %v", err) + return drip.InstallNode500JSONResponse{Message: errMessage}, err + } mp.Track(ctx, []*mixpanel.Event{ mp.NewEvent("Install Node", "", map[string]any{ "Node ID": request.NodeId, @@ -869,38 +773,235 @@ func (s *DripStrictServerImplementation) InstallNode( func (s *DripStrictServerImplementation) GetPermissionOnPublisherNodes( ctx context.Context, request drip.GetPermissionOnPublisherNodesRequestObject) (drip.GetPermissionOnPublisherNodesResponseObject, error) { - userId, err := mapper.GetUserIDFromContext(ctx) + + err := s.RegistryService.AssertNodeBelongsToPublisher(ctx, s.Client, request.PublisherId, request.NodeId) if err != nil { return drip.GetPermissionOnPublisherNodes200JSONResponse{CanEdit: proto.Bool(false)}, nil } - err = s.RegistryService.AssertPublisherPermissions( - ctx, s.Client, request.PublisherId, userId, []schema.PublisherPermissionType{schema.PublisherPermissionTypeOwner}) + return drip.GetPermissionOnPublisherNodes200JSONResponse{CanEdit: proto.Bool(true)}, nil +} + +func (s *DripStrictServerImplementation) GetPermissionOnPublisher( + ctx context.Context, request drip.GetPermissionOnPublisherRequestObject) (drip.GetPermissionOnPublisherResponseObject, error) { + + return drip.GetPermissionOnPublisher200JSONResponse{CanEdit: proto.Bool(true)}, nil +} + +// BanPublisher implements drip.StrictServerInterface. +func (s *DripStrictServerImplementation) BanPublisher(ctx context.Context, request drip.BanPublisherRequestObject) (drip.BanPublisherResponseObject, error) { + userId, err := mapper.GetUserIDFromContext(ctx) if err != nil { - return drip.GetPermissionOnPublisherNodes200JSONResponse{CanEdit: proto.Bool(false)}, nil + log.Ctx(ctx).Error().Msgf("Failed to get user ID from context w/ err: %v", err) + return drip.BanPublisher401Response{}, nil + } + user, err := s.Client.User.Get(ctx, userId) + if err != nil { + log.Ctx(ctx).Error().Msgf("Failed to get user ID from context w/ err: %v", err) + return drip.BanPublisher401Response{}, nil + } + if !user.IsAdmin { + log.Ctx(ctx).Error().Msgf("User is not admin w/ err") + return drip.BanPublisher403JSONResponse{ + Message: "User is not admin", + }, nil } - err = s.RegistryService.AssertNodeBelongsToPublisher(ctx, s.Client, request.PublisherId, request.NodeId) + err = s.RegistryService.BanPublisher(ctx, s.Client, request.PublisherId) + if ent.IsNotFound(err) { + log.Ctx(ctx).Error().Msgf("Publisher '%s' not found w/ err: %v", request.PublisherId, err) + return drip.BanPublisher404JSONResponse{ + Message: "Publisher not found", + }, nil + } if err != nil { - return drip.GetPermissionOnPublisherNodes200JSONResponse{CanEdit: proto.Bool(false)}, nil + log.Ctx(ctx).Error().Msgf("Error banning publisher w/ err: %v", err) + return drip.BanPublisher500JSONResponse{ + Message: "Error banning publisher", + Error: err.Error(), + }, nil } + return drip.BanPublisher204Response{}, nil +} + +// BanPublisherNode implements drip.StrictServerInterface. +func (s *DripStrictServerImplementation) BanPublisherNode(ctx context.Context, request drip.BanPublisherNodeRequestObject) (drip.BanPublisherNodeResponseObject, error) { + userId, err := mapper.GetUserIDFromContext(ctx) + if err != nil { + log.Ctx(ctx).Error().Msgf("Failed to get user ID from context w/ err: %v", err) + return drip.BanPublisherNode401Response{}, nil + } + user, err := s.Client.User.Get(ctx, userId) + if err != nil { + log.Ctx(ctx).Error().Msgf("Failed to get user ID from context w/ err: %v", err) + return drip.BanPublisherNode401Response{}, nil + } + if !user.IsAdmin { + log.Ctx(ctx).Error().Msgf("User is not admin w/ err") + return drip.BanPublisherNode403JSONResponse{ + Message: "User is not admin", + }, nil + } + + err = s.RegistryService.BanNode(ctx, s.Client, request.PublisherId, request.NodeId) + if ent.IsNotFound(err) { + log.Ctx(ctx).Error().Msgf("Publisher '%s' or node '%s' not found w/ err: %v", request.PublisherId, request.NodeId, err) + return drip.BanPublisherNode404JSONResponse{ + Message: "Publisher or Node not found", + }, nil + } + if err != nil { + log.Ctx(ctx).Error().Msgf("Error banning node w/ err: %v", err) + return drip.BanPublisherNode500JSONResponse{ + Message: "Error banning node", + Error: err.Error(), + }, nil + } + return drip.BanPublisherNode204Response{}, nil - return drip.GetPermissionOnPublisherNodes200JSONResponse{CanEdit: proto.Bool(true)}, nil } -func (s *DripStrictServerImplementation) GetPermissionOnPublisher( - ctx context.Context, request drip.GetPermissionOnPublisherRequestObject) (drip.GetPermissionOnPublisherResponseObject, error) { +func (s *DripStrictServerImplementation) AdminUpdateNodeVersion( + ctx context.Context, request drip.AdminUpdateNodeVersionRequestObject) (drip.AdminUpdateNodeVersionResponseObject, error) { userId, err := mapper.GetUserIDFromContext(ctx) if err != nil { log.Ctx(ctx).Error().Msgf("Failed to get user ID from context w/ err: %v", err) - return drip.GetPermissionOnPublisher200JSONResponse{CanEdit: proto.Bool(false)}, err + return drip.AdminUpdateNodeVersion401Response{}, nil + } + user, err := s.Client.User.Get(ctx, userId) + if err != nil { + log.Ctx(ctx).Error().Msgf("Failed to get user ID from context w/ err: %v", err) + return drip.AdminUpdateNodeVersion401Response{}, nil + } + if !user.IsAdmin { + log.Ctx(ctx).Error().Msgf("User is not admin w/ err") + return drip.AdminUpdateNodeVersion403JSONResponse{ + Message: "User is not admin", + }, nil + } + + nodeVersion, err := s.RegistryService.GetNodeVersionByVersion(ctx, s.Client, request.NodeId, request.VersionNumber) + if err != nil { + log.Ctx(ctx).Error().Msgf("Error retrieving node version w/ err: %v", err) + if ent.IsNotFound(err) { + return drip.AdminUpdateNodeVersion404JSONResponse{}, nil + } + return drip.AdminUpdateNodeVersion500JSONResponse{}, err } - err = s.RegistryService.AssertPublisherPermissions( - ctx, s.Client, request.PublisherId, userId, []schema.PublisherPermissionType{schema.PublisherPermissionTypeOwner}) + dbNodeVersion := mapper.ApiNodeVersionStatusToDbNodeVersionStatus(*request.Body.Status) + statusReason := "" + if request.Body.StatusReason != nil { + statusReason = *request.Body.StatusReason + } + err = nodeVersion.Update().SetStatus(dbNodeVersion).SetStatusReason(statusReason).Exec(ctx) if err != nil { - return drip.GetPermissionOnPublisher200JSONResponse{CanEdit: proto.Bool(false)}, nil + log.Ctx(ctx).Error().Msgf("Failed to update node version w/ err: %v", err) + return drip.AdminUpdateNodeVersion500JSONResponse{}, err } - return drip.GetPermissionOnPublisher200JSONResponse{CanEdit: proto.Bool(true)}, nil + log.Ctx(ctx).Info().Msgf("Node version %s updated successfully", request.VersionNumber) + return drip.AdminUpdateNodeVersion200JSONResponse{ + Status: request.Body.Status, + }, nil +} + +func (s *DripStrictServerImplementation) SecurityScan( + ctx context.Context, request drip.SecurityScanRequestObject) (drip.SecurityScanResponseObject, error) { + minAge := 30 * time.Minute + if request.Params.MinAge != nil { + minAge = *request.Params.MinAge + } + maxNodes := 50 + if request.Params.MaxNodes != nil { + maxNodes = *request.Params.MaxNodes + } + nodeVersionsResult, err := s.RegistryService.ListNodeVersions(ctx, s.Client, &drip_services.NodeVersionFilter{ + Status: []schema.NodeVersionStatus{schema.NodeVersionStatusPending}, + MinAge: minAge, + PageSize: maxNodes, + Page: 1, + }) + nodeVersions := nodeVersionsResult.NodeVersions + + log.Ctx(ctx).Info().Msgf("Found %d node versions to scan", len(nodeVersions)) + + if err != nil { + log.Ctx(ctx).Error().Msgf("Failed to list node versions w/ err: %v", err) + return drip.SecurityScan500JSONResponse{}, err + } + + for _, nodeVersion := range nodeVersions { + err := s.RegistryService.PerformSecurityCheck(ctx, s.Client, nodeVersion) + if err != nil { + log.Ctx(ctx).Error().Msgf("Failed to perform security scan w/ err: %v", err) + } + } + return drip.SecurityScan200Response{}, nil +} + +func (s *DripStrictServerImplementation) ListAllNodeVersions( + ctx context.Context, request drip.ListAllNodeVersionsRequestObject) (drip.ListAllNodeVersionsResponseObject, error) { + log.Ctx(ctx).Info().Msg("ListAllNodeVersions request received") + + page := 1 + if request.Params.Page != nil { + page = *request.Params.Page + } + pageSize := 10 + if request.Params.PageSize != nil && *request.Params.PageSize < 100 { + pageSize = *request.Params.PageSize + } + f := &drip_services.NodeVersionFilter{ + Page: page, + PageSize: pageSize, + } + + if request.Params.Statuses != nil { + f.Status = mapper.ApiNodeVersionStatusesToDbNodeVersionStatuses(request.Params.Statuses) + } + + // List nodes from the registry service + nodeVersionResults, err := s.RegistryService.ListNodeVersions(ctx, s.Client, f) + if err != nil { + log.Ctx(ctx).Error().Msgf("Failed to list node versions w/ err: %v", err) + return drip.ListAllNodeVersions500JSONResponse{Message: "Failed to list node versions", Error: err.Error()}, nil + } + + if len(nodeVersionResults.NodeVersions) == 0 { + log.Ctx(ctx).Info().Msg("No node versions found") + return drip.ListAllNodeVersions200JSONResponse{ + Versions: &[]drip.NodeVersion{}, + Total: &nodeVersionResults.Total, + Page: &nodeVersionResults.Page, + PageSize: &nodeVersionResults.Limit, + TotalPages: &nodeVersionResults.TotalPages, + }, nil + } + + apiNodeVersions := make([]drip.NodeVersion, 0, len(nodeVersionResults.NodeVersions)) + for _, dbNodeVersion := range nodeVersionResults.NodeVersions { + apiNodeVersions = append(apiNodeVersions, *mapper.DbNodeVersionToApiNodeVersion(dbNodeVersion)) + } + + log.Ctx(ctx).Info().Msgf("Found %d node versions", nodeVersionResults.Total) + return drip.ListAllNodeVersions200JSONResponse{ + Versions: &apiNodeVersions, + Total: &nodeVersionResults.Total, + Page: &nodeVersionResults.Page, + PageSize: &nodeVersionResults.Limit, + TotalPages: &nodeVersionResults.TotalPages, + }, nil +} + +func (s *DripStrictServerImplementation) ReindexNodes(ctx context.Context, request drip.ReindexNodesRequestObject) (res drip.ReindexNodesResponseObject, err error) { + log.Ctx(ctx).Info().Msg("ReindexNodes request received") + err = s.RegistryService.ReindexAllNodes(ctx, s.Client) + if err != nil { + log.Ctx(ctx).Error().Msgf("Failed to list node versions w/ err: %v", err) + return drip.ReindexNodes500JSONResponse{Message: "Failed to reindex nodes", Error: err.Error()}, nil + } + + log.Ctx(ctx).Info().Msgf("Reindex nodes successful") + return drip.ReindexNodes200Response{}, nil } diff --git a/server/middleware/firebase_auth.go b/server/middleware/authentication/firebase_auth.go similarity index 78% rename from server/middleware/firebase_auth.go rename to server/middleware/authentication/firebase_auth.go index 8d75a25..a97e4d0 100644 --- a/server/middleware/firebase_auth.go +++ b/server/middleware/authentication/firebase_auth.go @@ -1,4 +1,4 @@ -package drip_middleware +package drip_authentication import ( "context" @@ -15,28 +15,36 @@ import ( "github.com/labstack/echo/v4" ) -// TODO(robinhuang): Have this middleware only validate and extract the user details. Move all authorization logic to another middleware. -func FirebaseMiddleware(entClient *ent.Client) echo.MiddlewareFunc { +// FirebaseAuthMiddleware validates and extracts user details from the Firebase token. +// Certain endpoints are allow-listed and bypass this middleware. +func FirebaseAuthMiddleware(entClient *ent.Client) echo.MiddlewareFunc { // Handlers in here should bypass this middleware. var allowlist = map[*regexp.Regexp][]string{ regexp.MustCompile(`^/openapi$`): {"GET"}, + regexp.MustCompile(`^/security-scan$`): {"GET"}, regexp.MustCompile(`^/users/sessions$`): {"DELETE"}, regexp.MustCompile(`^/vm$`): {"ANY"}, regexp.MustCompile(`^/health$`): {"GET"}, regexp.MustCompile(`^/upload-artifact$`): {"POST"}, regexp.MustCompile(`^/gitcommit$`): {"POST", "GET"}, + regexp.MustCompile(`^/workflowresult/[^/]+$`): {"GET"}, regexp.MustCompile(`^/branch$`): {"GET"}, regexp.MustCompile(`^/publishers/[^/]+/nodes/[^/]+/versions$`): {"POST"}, regexp.MustCompile(`^/publishers/[^/]+/nodes$`): {"GET"}, regexp.MustCompile(`^/publishers/[^/]+$`): {"GET"}, regexp.MustCompile(`^/nodes$`): {"GET"}, + regexp.MustCompile(`^/versions$`): {"GET"}, regexp.MustCompile(`^/nodes/[^/]+$`): {"GET"}, regexp.MustCompile(`^/nodes/[^/]+/versions$`): {"GET"}, regexp.MustCompile(`^/nodes/[^/]+/install$`): {"GET"}, + regexp.MustCompile(`^/nodes/reindex$`): {"POST"}, + regexp.MustCompile(`^/publishers/[^/]+/ban$`): {"POST"}, + regexp.MustCompile(`^/publishers/[^/]+/nodes/[^/]+/ban$`): {"POST"}, } + return func(next echo.HandlerFunc) echo.HandlerFunc { return func(ctx echo.Context) error { - // Check if the request is in the allowlist. + // Check if the request is in the allow list. reqPath := ctx.Request().URL.Path reqMethod := ctx.Request().Method for basePathRegex, methods := range allowlist { @@ -54,7 +62,6 @@ func FirebaseMiddleware(entClient *ent.Client) echo.MiddlewareFunc { // If header is present, extract the token and verify it. header := ctx.Request().Header.Get("Authorization") if header != "" { - // Extract the JWT token from the header splitToken := strings.Split(header, "Bearer ") if len(splitToken) != 2 { @@ -83,12 +90,16 @@ func FirebaseMiddleware(entClient *ent.Client) echo.MiddlewareFunc { userDetails := extractUserDetails(token) log.Ctx(ctx.Request().Context()).Debug().Msg("Authenticated user " + userDetails.Email) - authdCtx := context.WithValue(ctx.Request().Context(), UserContextKey, userDetails) - ctx.SetRequest(ctx.Request().WithContext(authdCtx)) - newUserError := db.UpsertUser(ctx.Request().Context(), entClient, token.UID, userDetails.Email, userDetails.Name) + + authContext := context.WithValue(ctx.Request().Context(), UserContextKey, userDetails) + ctx.SetRequest(ctx.Request().WithContext(authContext)) + + newUserError := db.UpsertUser( + ctx.Request().Context(), entClient, token.UID, userDetails.Email, userDetails.Name) if newUserError != nil { log.Ctx(ctx.Request().Context()).Error().Err(newUserError).Msg("error User upserted successfully.") } + return next(ctx) } diff --git a/server/middleware/firebase_auth_test.go b/server/middleware/authentication/firebase_auth_test.go similarity index 92% rename from server/middleware/firebase_auth_test.go rename to server/middleware/authentication/firebase_auth_test.go index 715b941..556908f 100644 --- a/server/middleware/firebase_auth_test.go +++ b/server/middleware/authentication/firebase_auth_test.go @@ -1,10 +1,9 @@ -package drip_middleware_test +package drip_authentication import ( "net/http" "net/http/httptest" "registry-backend/ent" - drip_middleware "registry-backend/server/middleware" "testing" "github.com/labstack/echo/v4" @@ -20,7 +19,7 @@ func TestAllowlist(t *testing.T) { // Mock ent.Client mockEntClient := &ent.Client{} - middleware := drip_middleware.FirebaseMiddleware(mockEntClient) + middleware := FirebaseAuthMiddleware(mockEntClient) tests := []struct { name string @@ -40,6 +39,7 @@ func TestAllowlist(t *testing.T) { {"Node Version Path POST", "/publishers/pub123/nodes/node456/versions", "POST", true}, {"Publisher POST", "/publishers", "POST", false}, {"Unauthorized Path", "/nonexistent", "GET", false}, + {"Reindex Nodes", "/nodes/reindex", "POST", true}, {"Get All Nodes", "/nodes", "GET", true}, {"Install Nodes", "/nodes/node-id/install", "GET", true}, } diff --git a/server/middleware/authentication/jwt_admin_auth.go b/server/middleware/authentication/jwt_admin_auth.go new file mode 100644 index 0000000..0b1b886 --- /dev/null +++ b/server/middleware/authentication/jwt_admin_auth.go @@ -0,0 +1,101 @@ +package drip_authentication + +import ( + "context" + "fmt" + "net/http" + "regexp" + "registry-backend/ent" + "strings" + + "github.com/golang-jwt/jwt/v5" + "github.com/labstack/echo/v4" +) + +// JWTAdminAuthMiddleware checks for a JWT token in the Authorization header, +// verifies it using the provided secret, and adds user details to the context if valid. +// +// This check is only performed for specific admin protected endpoints. +func JWTAdminAuthMiddleware(entClient *ent.Client, secret string) echo.MiddlewareFunc { + // Key function to validate the JWT token + keyfunc := func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + return []byte(secret), nil + } + + // Define the regex patterns for the protected endpoints + protectedEndpoints := []*regexp.Regexp{ + regexp.MustCompile(`^/publishers/[^/]+/ban$`), + regexp.MustCompile(`^/publishers/[^/]+/nodes/[^/]+/ban$`), + } + + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + reqPath := c.Request().URL.Path + + // Check if the request path matches any of the protected endpoints + isProtected := false + for _, pattern := range protectedEndpoints { + if pattern.MatchString(reqPath) { + isProtected = true + break + } + } + + if !isProtected { + // If the request is not for a protected endpoint, skip this middleware + return next(c) + } + + // Get the Authorization header + header := c.Request().Header.Get("Authorization") + if header == "" { + return echo.NewHTTPError(http.StatusUnauthorized, "invalid jwt token") + } + + // Extract the JWT token from the header + splitToken := strings.Split(header, "Bearer ") + if len(splitToken) != 2 { + return echo.NewHTTPError(http.StatusUnauthorized, "invalid jwt token") + } + token := splitToken[1] + + // Parse and validate the JWT token + tokenData, err := jwt.Parse(token, keyfunc) + if err != nil { + return echo.NewHTTPError(http.StatusUnauthorized, "invalid jwt token") + } + + // Extract claims from the token + claims, ok := tokenData.Claims.(jwt.MapClaims) + if !ok || !tokenData.Valid { + return echo.NewHTTPError(http.StatusUnauthorized, "invalid jwt token") + } + + // Get the subject (user ID) from the claims + sub, err := claims.GetSubject() + if err != nil { + return echo.NewHTTPError(http.StatusUnauthorized, "missing sub claim") + } + + // Retrieve the user from the database + user, err := entClient.User.Get(c.Request().Context(), sub) + if err != nil { + return echo.NewHTTPError(http.StatusUnauthorized, "invalid user") + } + + // Add user details to the request context + authContext := context.WithValue(c.Request().Context(), UserContextKey, &UserDetails{ + ID: user.ID, + Email: user.Email, + Name: user.Name, + }) + c.SetRequest(c.Request().WithContext(authContext)) + + // Call the next handler + return next(c) + } + } +} diff --git a/server/middleware/authentication/jwt_admin_auth_test.go b/server/middleware/authentication/jwt_admin_auth_test.go new file mode 100644 index 0000000..10c05fe --- /dev/null +++ b/server/middleware/authentication/jwt_admin_auth_test.go @@ -0,0 +1,71 @@ +package drip_authentication + +import ( + "net/http" + "net/http/httptest" + "registry-backend/ent" + "testing" + + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" +) + +func TestJWTAdminAllowlist(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + // Mock ent.Client + mockEntClient := &ent.Client{} + + middleware := JWTAdminAuthMiddleware(mockEntClient, "secret") + + tests := []struct { + name string + path string + method string + allowed bool + }{ + {"OpenAPI GET", "/openapi", "GET", true}, + {"Session DELETE", "/users/sessions", "DELETE", true}, + {"Health GET", "/health", "GET", true}, + {"VM ANY", "/vm", "POST", true}, + {"VM ANY GET", "/vm", "GET", true}, + {"Artifact POST", "/upload-artifact", "POST", true}, + {"Git Commit POST", "/gitcommit", "POST", true}, + {"Git Commit GET", "/gitcommit", "GET", true}, + {"Branch GET", "/branch", "GET", true}, + {"Node Version Path POST", "/publishers/pub123/nodes/node456/versions", "POST", true}, + {"Publisher POST", "/publishers", "POST", true}, + {"Unauthorized Path", "/nonexistent", "GET", true}, + {"Get All Nodes", "/nodes", "GET", true}, + {"Install Nodes", "/nodes/node-id/install", "GET", true}, + + {"Ban Publisher", "/publishers/publisher-id/ban", "POST", false}, + {"Ban Node", "/publishers/publisher-id/nodes/node-id/ban", "POST", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(tt.method, tt.path, nil) + c.SetRequest(req) + handled := false + next := echo.HandlerFunc(func(c echo.Context) error { + handled = true + return nil + }) + err := middleware(next)(c) + if tt.allowed { + assert.True(t, handled, "Request should be allowed through") + assert.Nil(t, err) + } else { + assert.False(t, handled, "Request should not be allowed through") + assert.NotNil(t, err) + httpError, ok := err.(*echo.HTTPError) + assert.True(t, ok, "Error should be HTTPError") + assert.Equal(t, http.StatusUnauthorized, httpError.Code) + } + }) + } +} diff --git a/server/middleware/authentication/service_account_auth.go b/server/middleware/authentication/service_account_auth.go new file mode 100644 index 0000000..799bb83 --- /dev/null +++ b/server/middleware/authentication/service_account_auth.go @@ -0,0 +1,69 @@ +package drip_authentication + +import ( + "net/http" + "strings" + + "github.com/labstack/echo/v4" + "github.com/rs/zerolog/log" + "google.golang.org/api/idtoken" +) + +func ServiceAccountAuthMiddleware() echo.MiddlewareFunc { + // Handlers in here should be checked by this middleware. + var checklist = map[string][]string{ + "/security-scan": {"GET"}, + "/nodes/reindex": {"POST"}, + } + + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(ctx echo.Context) error { + // Check if the request path and method are in the checklist + path := ctx.Request().URL.Path + method := ctx.Request().Method + + methods, ok := checklist[path] + if !ok { + return next(ctx) + } + + for _, m := range methods { + if method == m { + ok = true + break + } + } + if !ok { + return next(ctx) + } + + // validate token + authHeader := ctx.Request().Header.Get("Authorization") + token := "" + if strings.HasPrefix(authHeader, "Bearer ") { + token = authHeader[7:] // Skip the "Bearer " part + } + + if token == "" { + return echo.NewHTTPError(http.StatusUnauthorized, "Missing token") + } + + log.Ctx(ctx.Request().Context()).Info().Msgf("Validating google id token %s for path %s and method %s", token, path, method) + + payload, err := idtoken.Validate(ctx.Request().Context(), token, "https://api.comfy.org") + if err != nil { + log.Ctx(ctx.Request().Context()).Error().Err(err).Msg("Invalid token") + return ctx.JSON(http.StatusUnauthorized, "Invalid token") + } + + email, _ := payload.Claims["email"].(string) + if email != "cloud-scheduler@dreamboothy.iam.gserviceaccount.com" { + log.Ctx(ctx.Request().Context()).Error().Err(err).Msg("Invalid email") + return ctx.JSON(http.StatusUnauthorized, "Invalid email") + } + + log.Ctx(ctx.Request().Context()).Info().Msgf("Service Account Email: %s", email) + return next(ctx) + } + } +} diff --git a/server/middleware/authentication/service_account_auth_test.go b/server/middleware/authentication/service_account_auth_test.go new file mode 100644 index 0000000..048d1a4 --- /dev/null +++ b/server/middleware/authentication/service_account_auth_test.go @@ -0,0 +1,67 @@ +package drip_authentication + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" +) + +func TestServiceAccountAllowList(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + middleware := ServiceAccountAuthMiddleware() + + tests := []struct { + name string + path string + method string + allowed bool + }{ + {"OpenAPI GET", "/openapi", "GET", true}, + {"Session DELETE", "/users/sessions", "DELETE", true}, + {"Health GET", "/health", "GET", true}, + {"VM ANY", "/vm", "POST", true}, + {"VM ANY GET", "/vm", "GET", true}, + {"Artifact POST", "/upload-artifact", "POST", true}, + {"Git Commit POST", "/gitcommit", "POST", true}, + {"Git Commit GET", "/gitcommit", "GET", true}, + {"Branch GET", "/branch", "GET", true}, + {"Node Version Path POST", "/publishers/pub123/nodes/node456/versions", "POST", true}, + {"Publisher POST", "/publishers", "POST", true}, + {"Unauthorized Path", "/nonexistent", "GET", true}, + {"Get All Nodes", "/nodes", "GET", true}, + {"Install Nodes", "/nodes/node-id/install", "GET", true}, + + {"Reindex Nodes", "/nodes/reindex", "POST", false}, + {"Reindex Nodes", "/security-scan", "GET", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(tt.method, tt.path, nil) + c.SetRequest(req) + handled := false + next := echo.HandlerFunc(func(c echo.Context) error { + handled = true + return nil + }) + err := middleware(next)(c) + if tt.allowed { + assert.True(t, handled, "Request should be allowed through") + assert.Nil(t, err) + } else { + assert.False(t, handled, "Request should not be allowed through") + assert.NotNil(t, err) + httpError, ok := err.(*echo.HTTPError) + assert.True(t, ok, "Error should be HTTPError") + assert.Equal(t, http.StatusUnauthorized, httpError.Code) + } + }) + } +} diff --git a/server/middleware/authorization/authorization_manager.go b/server/middleware/authorization/authorization_manager.go new file mode 100644 index 0000000..74cf1cd --- /dev/null +++ b/server/middleware/authorization/authorization_manager.go @@ -0,0 +1,210 @@ +package drip_authorization + +import ( + "context" + "net/http" + "registry-backend/drip" + "registry-backend/ent" + "registry-backend/ent/schema" + drip_authentication "registry-backend/server/middleware/authentication" + drip_services "registry-backend/services/registry" + + "github.com/google/uuid" + "github.com/labstack/echo/v4" + strictecho "github.com/oapi-codegen/runtime/strictmiddleware/echo" + "github.com/rs/zerolog/log" +) + +type Assertor interface { + AssertPublisherBanned(ctx context.Context, client *ent.Client, publisherID string) error + AssertPublisherPermissions(ctx context.Context, + client *ent.Client, + publisherID string, + userID string, + permissions []schema.PublisherPermissionType) (err error) + IsPersonalAccessTokenValidForPublisher(ctx context.Context, + client *ent.Client, + publisherID string, + accessToken string, + ) (bool, error) + AssertNodeBelongsToPublisher(ctx context.Context, client *ent.Client, publisherID string, nodeID string) error + AssertAccessTokenBelongsToPublisher(ctx context.Context, client *ent.Client, publisherID string, tokenId uuid.UUID) error + AssertNodeBanned(ctx context.Context, client *ent.Client, nodeID string) error +} + +// AuthorizationManager manages authorization-related tasks +type AuthorizationManager struct { + EntClient *ent.Client + Assertor Assertor +} + +// NewAuthorizationManager creates a new instance of AuthorizationManager +func NewAuthorizationManager( + entClient *ent.Client, assertor Assertor) *AuthorizationManager { + return &AuthorizationManager{ + EntClient: entClient, + Assertor: assertor, + } +} + +// assertUserBanned checks if the user is banned +func (m *AuthorizationManager) assertUserBanned() drip.StrictMiddlewareFunc { + return func(f strictecho.StrictEchoHandlerFunc, operationID string) strictecho.StrictEchoHandlerFunc { + return func(c echo.Context, request interface{}) (response interface{}, err error) { + ctx := c.Request().Context() + v := ctx.Value(drip_authentication.UserContextKey) + userDetails, ok := v.(*drip_authentication.UserDetails) + if !ok { + return nil, echo.NewHTTPError(http.StatusUnauthorized, "user not found") + } + + u, err := m.EntClient.User.Get(ctx, userDetails.ID) + if err != nil { + return nil, echo.NewHTTPError(http.StatusUnauthorized, "user not found") + } + + if u.Status == schema.UserStatusTypeBanned { + return nil, echo.NewHTTPError(http.StatusForbidden, "user/publisher is banned") + } + + return f(c, request) + } + } +} + +// assertPublisherPermission checks if the user has the required permissions for the publisher +func (m *AuthorizationManager) assertPublisherPermission( + permissions []schema.PublisherPermissionType, extractor func(req interface{}) string) drip.StrictMiddlewareFunc { + return func(f strictecho.StrictEchoHandlerFunc, operationID string) strictecho.StrictEchoHandlerFunc { + return func(c echo.Context, request interface{}) (response interface{}, err error) { + ctx := c.Request().Context() + v := ctx.Value(drip_authentication.UserContextKey) + userDetails, ok := v.(*drip_authentication.UserDetails) + if !ok { + return nil, echo.NewHTTPError(http.StatusUnauthorized, "user not found") + } + publisherID := extractor(request) + + log.Ctx(ctx).Info().Msgf("Checking if user ID %s has permission "+ + "to update publisher ID %s", userDetails.ID, publisherID) + err = m.Assertor.AssertPublisherPermissions(ctx, m.EntClient, publisherID, userDetails.ID, permissions) + switch { + case ent.IsNotFound(err): + log.Ctx(ctx).Info().Msgf("Publisher with ID %s not found", publisherID) + return nil, echo.NewHTTPError(http.StatusNotFound, "Publisher Not Found") + + case drip_services.IsPermissionError(err): + log.Ctx(ctx).Error().Msgf("Permission denied for user ID %s on "+ + "publisher ID %s w/ err: %v", userDetails.ID, publisherID, err) + return nil, echo.NewHTTPError(http.StatusForbidden, "Permission denied") + + case err != nil: + log.Ctx(ctx).Error().Msgf("Failed to assert publisher "+ + "permission %s w/ err: %v", publisherID, err) + return nil, err + } + + return f(c, request) + } + } +} + +// assertNodeBanned checks if the node is banned +func (m *AuthorizationManager) assertNodeBanned(extractor func(req interface{}) string) drip.StrictMiddlewareFunc { + return func(f strictecho.StrictEchoHandlerFunc, operationID string) strictecho.StrictEchoHandlerFunc { + return func(c echo.Context, request interface{}) (response interface{}, err error) { + ctx := c.Request().Context() + nodeID := extractor(request) + err = m.Assertor.AssertNodeBanned(ctx, m.EntClient, nodeID) + switch { + case drip_services.IsPermissionError(err): + log.Ctx(ctx).Error().Msgf("Node %s banned", nodeID) + return nil, echo.NewHTTPError(http.StatusForbidden, "Node Banned") + + case err != nil: + log.Ctx(ctx).Error().Msgf("Failed to assert node ban status %s w/ err: %v", nodeID, err) + return nil, err + } + + return f(c, request) + } + } +} + +// assertPublisherBanned checks if the publisher is banned +func (m *AuthorizationManager) assertPublisherBanned(extractor func(req interface{}) string) drip.StrictMiddlewareFunc { + return func(f strictecho.StrictEchoHandlerFunc, operationID string) strictecho.StrictEchoHandlerFunc { + return func(c echo.Context, request interface{}) (response interface{}, err error) { + ctx := c.Request().Context() + publisherID := extractor(request) + + switch err = m.Assertor.AssertPublisherBanned(ctx, m.EntClient, publisherID); { + case drip_services.IsPermissionError(err): + log.Ctx(ctx).Error().Msgf("Publisher %s banned", publisherID) + return nil, echo.NewHTTPError(http.StatusForbidden, "Node Banned") + + case err != nil: + log.Ctx(ctx).Error().Msgf("Failed to assert publisher ban status %s w/ err: %v", publisherID, err) + return nil, err + } + + return f(c, request) + } + } +} + +// assertPersonalAccessTokenValid check if personal access token is valid for a publisher +func (m *AuthorizationManager) assertPersonalAccessTokenValid( + extractorPublsherID func(req interface{}) (nodeid string), + extractorPAT func(req interface{}) (pat string), +) drip.StrictMiddlewareFunc { + return func(f strictecho.StrictEchoHandlerFunc, operationID string) strictecho.StrictEchoHandlerFunc { + return func(c echo.Context, request interface{}) (response interface{}, err error) { + ctx := c.Request().Context() + pubID := extractorPublsherID(request) + pat := extractorPAT(request) + tokenValid, err := m.Assertor.IsPersonalAccessTokenValidForPublisher( + ctx, m.EntClient, pubID, pat) + if err != nil { + log.Ctx(ctx).Error().Msgf("Token validation failed w/ err: %v", err) + return nil, echo.NewHTTPError(http.StatusBadRequest, "Failed to validate token") + } + if !tokenValid { + log.Ctx(ctx).Error().Msg("Invalid personal access token") + return nil, echo.NewHTTPError(http.StatusBadRequest, "Invalid personal access token") + } + + return f(c, request) + } + } +} + +// assertNodeBelongsToPublisher check if a node belongs to a publisher +func (m *AuthorizationManager) assertNodeBelongsToPublisher( + extractorPublsherID func(req interface{}) (nodeid string), + extractorNodeID func(req interface{}) (nodeid string), +) drip.StrictMiddlewareFunc { + return func(f strictecho.StrictEchoHandlerFunc, operationID string) strictecho.StrictEchoHandlerFunc { + return func(c echo.Context, request interface{}) (response interface{}, err error) { + ctx := c.Request().Context() + pubID := extractorPublsherID(request) + nodeID := extractorNodeID(request) + + err = m.Assertor.AssertNodeBelongsToPublisher(ctx, m.EntClient, pubID, nodeID) + switch { + case ent.IsNotFound(err): + return f(c, request) + + case drip_services.IsPermissionError(err): + log.Ctx(ctx).Error().Msgf( + "Permission denied for publisher ID %s on node ID %s w/ err: %v", pubID, nodeID, err) + return nil, echo.NewHTTPError(http.StatusForbidden, "Permission denied") + + case err != nil: + return nil, err + } + + return f(c, request) + } + } +} diff --git a/server/middleware/authorization/authorization_middleware.go b/server/middleware/authorization/authorization_middleware.go new file mode 100644 index 0000000..12f8228 --- /dev/null +++ b/server/middleware/authorization/authorization_middleware.go @@ -0,0 +1,216 @@ +package drip_authorization + +import ( + "registry-backend/drip" + "registry-backend/ent/schema" + "slices" + + strictecho "github.com/oapi-codegen/runtime/strictmiddleware/echo" +) + +func (m *AuthorizationManager) AuthorizationMiddleware() drip.StrictMiddlewareFunc { + subMiddlewares := map[string][]drip.StrictMiddlewareFunc{ + "CreatePublisher": { + m.assertUserBanned(), + }, + "UpdatePublisher": { + m.assertUserBanned(), + m.assertPublisherPermission( + []schema.PublisherPermissionType{schema.PublisherPermissionTypeOwner}, + func(req interface{}) (publisherID string) { + return req.(drip.UpdatePublisherRequestObject).PublisherId + }, + ), + }, + "CreateNode": { + m.assertUserBanned(), + m.assertPublisherPermission( + []schema.PublisherPermissionType{schema.PublisherPermissionTypeOwner}, + func(req interface{}) (publisherID string) { + return req.(drip.CreateNodeRequestObject).PublisherId + }, + ), + }, + "DeleteNode": { + m.assertUserBanned(), + m.assertPublisherPermission( + []schema.PublisherPermissionType{schema.PublisherPermissionTypeOwner}, + func(req interface{}) (publisherID string) { + return req.(drip.DeleteNodeRequestObject).PublisherId + }, + ), + m.assertNodeBelongsToPublisher( + func(req interface{}) (publisherID string) { + return req.(drip.DeleteNodeRequestObject).PublisherId + }, + func(req interface{}) (nodeID string) { + return req.(drip.DeleteNodeRequestObject).NodeId + }, + ), + }, + "UpdateNode": { + m.assertUserBanned(), + m.assertNodeBanned( + func(req interface{}) (nodeid string) { + return req.(drip.UpdateNodeRequestObject).NodeId + }, + ), + m.assertPublisherPermission( + []schema.PublisherPermissionType{schema.PublisherPermissionTypeOwner}, + func(req interface{}) (publisherID string) { + return req.(drip.UpdateNodeRequestObject).PublisherId + }, + ), + m.assertNodeBelongsToPublisher( + func(req interface{}) (publisherID string) { + return req.(drip.UpdateNodeRequestObject).PublisherId + }, + func(req interface{}) (nodeID string) { + return req.(drip.UpdateNodeRequestObject).NodeId + }, + ), + }, + "GetNode": { + m.assertNodeBanned( + func(req interface{}) (nodeid string) { + return req.(drip.GetNodeRequestObject).NodeId + }, + ), + }, + "PublishNodeVersion": { + m.assertPublisherBanned( + func(req interface{}) (publisherID string) { + return req.(drip.PublishNodeVersionRequestObject).PublisherId + }), + m.assertNodeBanned( + func(req interface{}) (nodeid string) { + return req.(drip.PublishNodeVersionRequestObject).NodeId + }, + ), + m.assertNodeBelongsToPublisher( + func(req interface{}) (publisherID string) { + return req.(drip.PublishNodeVersionRequestObject).PublisherId + }, + func(req interface{}) (NodeId string) { + return req.(drip.PublishNodeVersionRequestObject).NodeId + }, + ), + m.assertPersonalAccessTokenValid( + func(req interface{}) (publisherID string) { + return req.(drip.PublishNodeVersionRequestObject).PublisherId + }, + func(req interface{}) (pat string) { + return req.(drip.PublishNodeVersionRequestObject).Body.PersonalAccessToken + }, + ), + }, + "UpdateNodeVersion": { + m.assertUserBanned(), + m.assertNodeBanned( + func(req interface{}) (nodeid string) { + return req.(drip.UpdateNodeVersionRequestObject).NodeId + }, + ), + m.assertPublisherPermission( + []schema.PublisherPermissionType{schema.PublisherPermissionTypeOwner}, + func(req interface{}) (publisherID string) { + return req.(drip.UpdateNodeVersionRequestObject).PublisherId + }, + ), + m.assertNodeBelongsToPublisher( + func(req interface{}) (publisherID string) { + return req.(drip.UpdateNodeVersionRequestObject).PublisherId + }, + func(req interface{}) (publisherID string) { + return req.(drip.UpdateNodeVersionRequestObject).NodeId + }, + ), + }, + "DeleteNodeVersion": { + m.assertUserBanned(), + m.assertNodeBanned( + func(req interface{}) (nodeid string) { + return req.(drip.DeleteNodeVersionRequestObject).NodeId + }, + ), + }, + "GetNodeVersion": { + m.assertNodeBanned( + func(req interface{}) (nodeid string) { + return req.(drip.GetNodeVersionRequestObject).NodeId + }, + ), + }, + "ListNodeVersions": { + m.assertNodeBanned( + func(req interface{}) (nodeid string) { + return req.(drip.ListNodeVersionsRequestObject).NodeId + }, + ), + }, + "InstallNode": { + m.assertNodeBanned( + func(req interface{}) (nodeid string) { + return req.(drip.InstallNodeRequestObject).NodeId + }, + ), + }, + "CreatePersonalAccessToken": { + m.assertUserBanned(), + m.assertPublisherPermission( + []schema.PublisherPermissionType{schema.PublisherPermissionTypeOwner}, + func(req interface{}) (publisherID string) { + return req.(drip.CreatePersonalAccessTokenRequestObject).PublisherId + }, + ), + }, + "DeletePersonalAccessToken": { + m.assertUserBanned(), + m.assertPublisherPermission( + []schema.PublisherPermissionType{schema.PublisherPermissionTypeOwner}, + func(req interface{}) (publisherID string) { + return req.(drip.DeletePersonalAccessTokenRequestObject).PublisherId + }, + ), + }, + "ListPersonalAccessTokens": { + m.assertPublisherPermission( + []schema.PublisherPermissionType{schema.PublisherPermissionTypeOwner}, + func(req interface{}) (publisherID string) { + return req.(drip.ListPersonalAccessTokensRequestObject).PublisherId + }, + ), + }, + "GetPermissionOnPublisherNodes": { + m.assertPublisherPermission( + []schema.PublisherPermissionType{schema.PublisherPermissionTypeOwner}, + func(req interface{}) (publisherID string) { + return req.(drip.GetPermissionOnPublisherNodesRequestObject).PublisherId + }, + ), + }, + "GetPermissionOnPublisher": { + m.assertPublisherPermission( + []schema.PublisherPermissionType{schema.PublisherPermissionTypeOwner}, + func(req interface{}) (publisherID string) { + return req.(drip.GetPermissionOnPublisherRequestObject).PublisherId + }, + ), + }, + } + for _, v := range subMiddlewares { + slices.Reverse(v) + } + + return func(f strictecho.StrictEchoHandlerFunc, operationID string) strictecho.StrictEchoHandlerFunc { + middlewares, ok := subMiddlewares[operationID] + if !ok { + return f + } + + for _, mw := range middlewares { + f = mw(f, operationID) + } + return f + } +} diff --git a/server/middleware/authorization/authorization_middleware_test.go b/server/middleware/authorization/authorization_middleware_test.go new file mode 100644 index 0000000..40f055c --- /dev/null +++ b/server/middleware/authorization/authorization_middleware_test.go @@ -0,0 +1,91 @@ +package drip_authorization + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "registry-backend/drip" + "registry-backend/ent" + "registry-backend/ent/schema" + "testing" + + "github.com/google/uuid" + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" +) + +var _ Assertor = mockAlwayErrorAssertor{} +var errMockAssertor = errors.New("assertion failed") + +type mockAlwayErrorAssertor struct{} + +// AssertAccessTokenBelongsToPublisher implements Assertor. +func (m mockAlwayErrorAssertor) AssertAccessTokenBelongsToPublisher(ctx context.Context, client *ent.Client, publisherID string, tokenId uuid.UUID) error { + return errors.New("assertion failed") +} + +// AssertNodeBanned implements Assertor. +func (m mockAlwayErrorAssertor) AssertNodeBanned(ctx context.Context, client *ent.Client, nodeID string) error { + return errors.New("assertion failed") +} + +// AssertNodeBelongsToPublisher implements Assertor. +func (m mockAlwayErrorAssertor) AssertNodeBelongsToPublisher(ctx context.Context, client *ent.Client, publisherID string, nodeID string) error { + return errors.New("assertion failed") +} + +// AssertPublisherBanned implements Assertor. +func (m mockAlwayErrorAssertor) AssertPublisherBanned(ctx context.Context, client *ent.Client, publisherID string) error { + return errors.New("assertion failed") +} + +// AssertPublisherPermissions implements Assertor. +func (m mockAlwayErrorAssertor) AssertPublisherPermissions(ctx context.Context, client *ent.Client, publisherID string, userID string, permissions []schema.PublisherPermissionType) (err error) { + return errors.New("assertion failed") +} + +// IsPersonalAccessTokenValidForPublisher implements Assertor. +func (m mockAlwayErrorAssertor) IsPersonalAccessTokenValidForPublisher(ctx context.Context, client *ent.Client, publisherID string, accessToken string) (bool, error) { + return false, errors.New("assertion failed") +} + +func TestNoAuthz(t *testing.T) { + mw := NewAuthorizationManager(&ent.Client{}, mockAlwayErrorAssertor{}).AuthorizationMiddleware() + req, res := httptest.NewRequest(http.MethodGet, "/", nil), httptest.NewRecorder() + ctx := echo.New().NewContext(req, res) + + tests := []struct { + op string + pass bool + req interface{} + }{ + {op: "SomeOtherOperation", pass: true}, + + {op: "CreatePublisher", pass: false, req: drip.CreatePublisherRequestObject{}}, + {op: "UpdatePublisher", pass: false, req: drip.UpdatePublisherRequestObject{}}, + {op: "CreateNode", pass: false, req: drip.CreateNodeRequestObject{}}, + {op: "DeleteNode", pass: false, req: drip.DeleteNodeRequestObject{}}, + {op: "UpdateNode", pass: false, req: drip.UpdateNodeRequestObject{}}, + {op: "GetNode", pass: false, req: drip.GetNodeRequestObject{}}, + {op: "PublishNodeVersion", pass: false, req: drip.PublishNodeVersionRequestObject{}}, + {op: "UpdateNodeVersion", pass: false, req: drip.UpdateNodeVersionRequestObject{}}, + {op: "DeleteNodeVersion", pass: false, req: drip.DeleteNodeVersionRequestObject{}}, + {op: "GetNodeVersion", pass: false, req: drip.GetNodeVersionRequestObject{}}, + {op: "ListNodeVersions", pass: false, req: drip.ListNodeVersionsRequestObject{}}, + {op: "InstallNode", pass: false, req: drip.InstallNodeRequestObject{}}, + {op: "CreatePersonalAccessToken", pass: false, req: drip.CreatePersonalAccessTokenRequestObject{}}, + {op: "DeletePersonalAccessToken", pass: false, req: drip.DeletePersonalAccessTokenRequestObject{}}, + {op: "GetPermissionOnPublisherNodes", pass: false, req: drip.GetPermissionOnPublisherNodesRequestObject{}}, + {op: "GetPermissionOnPublisher", pass: false, req: drip.GetPermissionOnPublisherRequestObject{}}, + } + for _, test := range tests { + handled := false + h := func(ctx echo.Context, request interface{}) (interface{}, error) { + handled = true + return nil, nil + } + mw(h, test.op)(ctx, test.req) + assert.Equal(t, test.pass, handled) + } +} diff --git a/server/middleware/metric/metric.go b/server/middleware/metric/metric.go new file mode 100644 index 0000000..ad3a900 --- /dev/null +++ b/server/middleware/metric/metric.go @@ -0,0 +1,98 @@ +package drip_metric + +import ( + "context" + "sort" + "strings" + "sync" + "sync/atomic" + "time" + + "cloud.google.com/go/monitoring/apiv3/v2/monitoringpb" + metricpb "google.golang.org/genproto/googleapis/api/metric" + "google.golang.org/protobuf/types/known/timestamppb" +) + +func (c *CounterMetric) Increment(key any, i int64) int64 { + v, _ := c.LoadOrStore(key, new(atomic.Int64)) + ai, ok := v.(*atomic.Int64) + if !ok { + ai = new(atomic.Int64) + } + ai.Add(i) // Initialize and increment atomically + return ai.Load() +} + +type customCounterKey struct { + t string + l string +} + +type CustomCounterIncrement struct { + Type string + Labels map[string]string + Val int64 +} + +func (c CustomCounterIncrement) key() customCounterKey { + keys := make([]string, 0, len(c.Labels)) + for k, v := range c.Labels { + keys = append(keys, k+":"+v) + } + sort.Strings(keys) + return customCounterKey{ + t: c.Type, + l: strings.Join(keys, ","), + } +} + +var customCounterCtxKey = struct{}{} + +func AttachCustomCounterMetric(ctx context.Context) context.Context { + ctx = context.WithValue(ctx, customCounterCtxKey, &([]CustomCounterIncrement{})) + return ctx +} + +func IncrementCustomCounterMetric(ctx context.Context, inc CustomCounterIncrement) { + v := ctx.Value(customCounterCtxKey) + cc, ok := v.(*[]CustomCounterIncrement) + if !ok || cc == nil { + return + } + *cc = append(*cc, inc) +} + +var customCounterMetric = CounterMetric{Map: sync.Map{}} + +func CreateCustomCounterMetrics(ctx context.Context) (ts []*monitoringpb.TimeSeries) { + v := ctx.Value(customCounterCtxKey) + cc, ok := v.(*[]CustomCounterIncrement) + if !ok || cc == nil { + return + } + + for _, c := range *cc { + val := customCounterMetric.Increment(c.key(), c.Val) + ts = append(ts, &monitoringpb.TimeSeries{ + Metric: &metricpb.Metric{ + Type: MetricTypePrefix + "/" + c.Type, + Labels: c.Labels, + }, + MetricKind: metricpb.MetricDescriptor_CUMULATIVE, + Points: []*monitoringpb.Point{ + { + Interval: &monitoringpb.TimeInterval{ + StartTime: timestamppb.New(time.Now().Add(-time.Second)), + EndTime: timestamppb.New(time.Now()), + }, + Value: &monitoringpb.TypedValue{ + Value: &monitoringpb.TypedValue_Int64Value{ + Int64Value: val, + }, + }, + }, + }, + }) + } + return +} diff --git a/server/middleware/metric_middleware.go b/server/middleware/metric/metric_middleware.go similarity index 79% rename from server/middleware/metric_middleware.go rename to server/middleware/metric/metric_middleware.go index c7e63de..cf53b4f 100644 --- a/server/middleware/metric_middleware.go +++ b/server/middleware/metric/metric_middleware.go @@ -1,4 +1,4 @@ -package drip_middleware +package drip_metric import ( "context" @@ -18,9 +18,19 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" ) -const MetricTypePrefix = "custom.googleapis.com/comfy_api_frontend" +const ( + MetricTypePrefix = "custom.googleapis.com/comfy_api_frontend" + batchInterval = 5 * time.Minute // Batch interval for sending metrics +) + +var ( + environment = os.Getenv("DRIP_ENV") + metricsCh = make(chan *monitoringpb.TimeSeries, 1000) +) -var environment = os.Getenv("DRIP_ENV") +func init() { + go processMetricsBatch() +} // MetricsMiddleware creates a middleware to capture and send metrics for HTTP requests. func MetricsMiddleware(client *monitoring.MetricClient, config *config.Config) echo.MiddlewareFunc { @@ -32,7 +42,7 @@ func MetricsMiddleware(client *monitoring.MetricClient, config *config.Config) e // Generate metrics for the request duration, count, and errors. if config.DripEnv != "localdev" { - sendMetrics(c.Request().Context(), client, config, + enqueueMetrics( createDurationMetric(c, startTime, endTime), createRequestMetric(c), createErrorMetric(c, err), @@ -79,23 +89,57 @@ func (e EndpointMetricKey) toLabels() map[string]string { } } -// sendMetrics sends a batch of time series data to Cloud Monitoring. -func sendMetrics( - ctx context.Context, - client *monitoring.MetricClient, - config *config.Config, - series ...*monitoringpb.TimeSeries, -) { - req := &monitoringpb.CreateTimeSeriesRequest{ - Name: "projects/" + config.ProjectID, - TimeSeries: make([]*monitoringpb.TimeSeries, 0, len(series)), - } - +func enqueueMetrics(series ...*monitoringpb.TimeSeries) { for _, s := range series { if s != nil { - req.TimeSeries = append(req.TimeSeries, s) + metricsCh <- s + } + } +} + +func processMetricsBatch() { + ticker := time.NewTicker(batchInterval) + for range ticker.C { + sendBatchedMetrics() + } +} + +func sendBatchedMetrics() { + var series []*monitoringpb.TimeSeries + for { + select { + case s := <-metricsCh: + series = append(series, s) + if len(series) >= 1000 { + sendMetrics(series) + series = nil + } + default: + if len(series) > 0 { + sendMetrics(series) + } + return } } +} + +func sendMetrics(series []*monitoringpb.TimeSeries) { + if len(series) == 0 { + return + } + + ctx := context.Background() + client, err := monitoring.NewMetricClient(ctx) + if err != nil { + log.Ctx(ctx).Error().Err(err).Msg("Failed to create metric client") + return + } + defer client.Close() + + req := &monitoringpb.CreateTimeSeriesRequest{ + Name: "projects/" + os.Getenv("PROJECT_ID"), + TimeSeries: series, + } if err := client.CreateTimeSeries(ctx, req); err != nil { log.Ctx(ctx).Error().Err(err).Msg("Failed to create time series") diff --git a/server/middleware/service_account_auth.go b/server/middleware/service_account_auth.go deleted file mode 100644 index cd35d94..0000000 --- a/server/middleware/service_account_auth.go +++ /dev/null @@ -1,60 +0,0 @@ -package drip_middleware - -import ( - "net/http" - "strings" - - "github.com/labstack/echo/v4" - "github.com/rs/zerolog/log" - "google.golang.org/api/idtoken" -) - -func ServiceAccountAuthMiddleware() echo.MiddlewareFunc { - // Handlers in here should be checked by this middleware. - var checklist = map[string][]string{ - "/users/sessions": {"DELETE"}, - } - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(ctx echo.Context) error { - path := ctx.Request().URL.Path - method := ctx.Request().Method - - // Check if the request path and method are in the checklist - if methods, ok := checklist[path]; ok { - for _, allowMethod := range methods { - if method == allowMethod { - authHeader := ctx.Request().Header.Get("Authorization") - token := "" - if strings.HasPrefix(authHeader, "Bearer ") { - token = authHeader[7:] // Skip the "Bearer " part - } - - if token == "" { - return echo.NewHTTPError(http.StatusUnauthorized, "Missing token") - } - - log.Ctx(ctx.Request().Context()).Info().Msgf("Validating google id token %s for path %s and method %s", token, path, method) - - payload, err := idtoken.Validate(ctx.Request().Context(), token, "https://api.comfy.org") - - if err == nil { - if email, ok := payload.Claims["email"].(string); ok { - log.Ctx(ctx.Request().Context()).Info().Msgf("Service Account Email: %s", email) - // TODO(robinhuang): Make service account an environment variable. - if email == "stop-vm-sa@dreamboothy.iam.gserviceaccount.com" { - return next(ctx) - } - } - } - - log.Ctx(ctx.Request().Context()).Error().Err(err).Msg("Invalid token") - return ctx.JSON(http.StatusUnauthorized, "Invalid token") - } - } - } - - // Proceed with the next middleware or handler - return next(ctx) - } - } -} diff --git a/server/server.go b/server/server.go index d804f74..10e1a01 100644 --- a/server/server.go +++ b/server/server.go @@ -6,11 +6,16 @@ import ( "registry-backend/config" generated "registry-backend/drip" "registry-backend/ent" + "registry-backend/gateways/algolia" + "registry-backend/gateways/discord" gateway "registry-backend/gateways/slack" "registry-backend/gateways/storage" handler "registry-backend/server/handlers" "registry-backend/server/implementation" drip_middleware "registry-backend/server/middleware" + drip_authentication "registry-backend/server/middleware/authentication" + drip_authorization "registry-backend/server/middleware/authorization" + drip_metric "registry-backend/server/middleware/metric" "strings" monitoring "cloud.google.com/go/monitoring/apiv3/v2" @@ -47,6 +52,7 @@ func (s *Server) Start() error { ".comfyci.org", // Any subdomain of comfyci.org os.Getenv("CORS_ORIGIN"), // Environment-specific allowed origin ".comfyregistry.org", + ".comfy.org", } for _, allowed := range allowedOrigins { @@ -82,7 +88,12 @@ func (s *Server) Start() error { return err } - slackService := gateway.NewSlackService() + slackService := gateway.NewSlackService(s.Config) + algoliaService, err := algolia.NewFromEnvOrNoop() + if err != nil { + return err + } + discordService := discord.NewDiscordService(s.Config) mon, err := monitoring.NewMetricClient(context.Background()) if err != nil { @@ -90,10 +101,15 @@ func (s *Server) Start() error { } // Attach implementation of generated oapi strict server. - impl := implementation.NewStrictServerImplementation(s.Client, s.Config, storageService, slackService) + impl := implementation.NewStrictServerImplementation(s.Client, s.Config, storageService, slackService, discordService, algoliaService) - var middlewares []generated.StrictMiddlewareFunc + // Define middlewares in the order of operations + authorizationManager := drip_authorization.NewAuthorizationManager(s.Client, impl.RegistryService) + middlewares := []generated.StrictMiddlewareFunc{ + authorizationManager.AuthorizationMiddleware(), + } wrapped := generated.NewStrictHandler(impl, middlewares) + generated.RegisterHandlers(e, wrapped) e.GET("/openapi", handler.SwaggerHandler) @@ -102,9 +118,10 @@ func (s *Server) Start() error { }) // Global Middlewares - e.Use(drip_middleware.MetricsMiddleware(mon, s.Config)) - e.Use(drip_middleware.FirebaseMiddleware(s.Client)) - e.Use(drip_middleware.ServiceAccountAuthMiddleware()) + e.Use(drip_metric.MetricsMiddleware(mon, s.Config)) + e.Use(drip_authentication.FirebaseAuthMiddleware(s.Client)) + e.Use(drip_authentication.ServiceAccountAuthMiddleware()) + e.Use(drip_authentication.JWTAdminAuthMiddleware(s.Client, s.Config.JWTSecret)) e.Use(drip_middleware.ErrorLoggingMiddleware()) e.Logger.Fatal(e.Start(":8080")) diff --git a/services/comfy_ci/comfy_ci_svc.go b/services/comfy_ci/comfy_ci_svc.go index f1d6b69..3b17466 100644 --- a/services/comfy_ci/comfy_ci_svc.go +++ b/services/comfy_ci/comfy_ci_svc.go @@ -8,6 +8,8 @@ import ( "registry-backend/ent" "registry-backend/ent/ciworkflowresult" "registry-backend/ent/gitcommit" + "registry-backend/mapper" + drip_metric "registry-backend/server/middleware/metric" "strings" "time" @@ -37,9 +39,19 @@ func (s *ComfyCIService) ProcessCIRequest(ctx context.Context, client *ent.Clien existingCommit, err := client.GitCommit.Query().Where(gitcommit.CommitHashEQ(req.Body.CommitHash)).Where(gitcommit.RepoNameEQ(req.Body.Repo)).Only(ctx) if ent.IsNotSingular(err) { log.Ctx(ctx).Error().Err(err).Msgf("Failed to query git commit %s", req.Body.CommitHash) + drip_metric.IncrementCustomCounterMetric(ctx, drip_metric.CustomCounterIncrement{ + Type: "ci-git-commit-query-error", + Val: 1, + Labels: map[string]string{}, + }) } if existingCommit != nil { - _, err := client.CIWorkflowResult.Delete().Where(ciworkflowresult.HasGitcommitWith(gitcommit.IDEQ(existingCommit.ID))).Exec(ctx) + log.Ctx(ctx).Info().Msgf("Deleting existing run results for git commit %s, operating system %s, and workflow name %s", req.Body.CommitHash, req.Body.Os, req.Body.WorkflowName) + _, err := client.CIWorkflowResult.Delete().Where( + ciworkflowresult.HasGitcommitWith(gitcommit.IDEQ(existingCommit.ID)), + ciworkflowresult.WorkflowName(req.Body.WorkflowName), + ciworkflowresult.OperatingSystem(req.Body.Os), + ).Exec(ctx) if err != nil { log.Ctx(ctx).Error().Err(err).Msgf("Failed to delete existing run results for git commit %s", req.Body.CommitHash) return err @@ -47,11 +59,38 @@ func (s *ComfyCIService) ProcessCIRequest(ctx context.Context, client *ent.Clien } return db.WithTx(ctx, client, func(tx *ent.Tx) error { - id, err := s.UpsertCommit(ctx, tx.Client(), req.Body.CommitHash, req.Body.BranchName, req.Body.Repo, req.Body.CommitTime, req.Body.CommitMessage) + id, err := s.UpsertCommit(ctx, tx.Client(), req.Body.CommitHash, req.Body.BranchName, req.Body.Repo, req.Body.CommitTime, req.Body.CommitMessage, req.Body.PrNumber, req.Body.Author) if err != nil { return err } gitcommit := tx.Client().GitCommit.GetX(ctx, id) + + // Create the CI Workflow Result first. Then add files to it (if there are any). + cudaVersion := "" + if req.Body.CudaVersion != nil { + cudaVersion = *req.Body.CudaVersion + } + avgVram := 0 + if req.Body.AvgVram != nil { + avgVram = *req.Body.AvgVram + } + peakVram := 0 + if req.Body.PeakVram != nil { + peakVram = *req.Body.PeakVram + } + pytorchVersion := "" + if req.Body.PytorchVersion != nil { + pytorchVersion = *req.Body.PytorchVersion + } + comfyRunFlags := "" + if req.Body.ComfyRunFlags != nil { + comfyRunFlags = *req.Body.ComfyRunFlags + } + workflowResultId, err := s.UpsertRunResult(ctx, tx.Client(), gitcommit, req.Body.Os, cudaVersion, req.Body.WorkflowName, req.Body.RunId, req.Body.JobId, req.Body.StartTime, req.Body.EndTime, avgVram, peakVram, req.Body.PythonVersion, pytorchVersion, req.Body.JobTriggerUser, comfyRunFlags, req.Body.Status, req.Body.MachineStats) + if err != nil { + return err + } + if req.Body.OutputFilesGcsPaths != nil && req.Body.BucketName != nil { files, err := GetPublicUrlForOutputFiles(ctx, *req.Body.BucketName, *req.Body.OutputFilesGcsPaths) if err != nil { @@ -64,18 +103,16 @@ func (s *ComfyCIService) ProcessCIRequest(ctx context.Context, client *ent.Clien if err != nil { log.Ctx(ctx).Error().Err(err).Msg("Failed to upsert storage file") + drip_metric.IncrementCustomCounterMetric(ctx, drip_metric.CustomCounterIncrement{ + Type: "ci-upsert-storage-error", + Val: 1, + Labels: map[string]string{ + "bucket-name": file.BucketName, + }, + }) continue } - - cudaVersion := "" - if req.Body.CudaVersion != nil { - cudaVersion = *req.Body.CudaVersion - } - - _, err = s.UpsertRunResult(ctx, tx.Client(), file, gitcommit, req.Body.Os, cudaVersion, req.Body.WorkflowName, req.Body.RunId, req.Body.StartTime, req.Body.EndTime) - if err != nil { - return err - } + tx.Client().CIWorkflowResult.UpdateOneID(workflowResultId).AddStorageFile(file).Exec(ctx) } } return nil @@ -83,7 +120,7 @@ func (s *ComfyCIService) ProcessCIRequest(ctx context.Context, client *ent.Clien } // UpsertCommit creates or updates a GitCommit entity. -func (s *ComfyCIService) UpsertCommit(ctx context.Context, client *ent.Client, hash, branchName, repoName string, commitIsoTime string, commitMessage string) (uuid.UUID, error) { +func (s *ComfyCIService) UpsertCommit(ctx context.Context, client *ent.Client, hash, branchName, repoName, commitIsoTime, commitMessage, prNumber, author string) (uuid.UUID, error) { log.Ctx(ctx).Info().Msgf("Upserting commit %s", hash) commitTime, err := time.Parse(time.RFC3339, commitIsoTime) if err != nil { @@ -97,6 +134,8 @@ func (s *ComfyCIService) UpsertCommit(ctx context.Context, client *ent.Client, h SetRepoName(strings.ToLower(repoName)). // TODO(robinhuang): Write test for this. SetCommitTimestamp(commitTime). SetCommitMessage(commitMessage). + SetPrNumber(prNumber). + SetAuthor(author). OnConflict( // Careful, the order matters here. sql.ConflictColumns(gitcommit.FieldRepoName, gitcommit.FieldCommitHash), @@ -111,17 +150,30 @@ func (s *ComfyCIService) UpsertCommit(ctx context.Context, client *ent.Client, h } // UpsertRunResult creates or updates a ActionRunResult entity. -func (s *ComfyCIService) UpsertRunResult(ctx context.Context, client *ent.Client, file *ent.StorageFile, gitcommit *ent.GitCommit, os, gpuType, workflowName, runId string, startTime, endTime int64) (uuid.UUID, error) { +func (s *ComfyCIService) UpsertRunResult(ctx context.Context, client *ent.Client, gitcommit *ent.GitCommit, os, cudaVersion, workflowName, runId, jobId string, startTime, endTime int64, avgVram, peakVram int, pythonVersion, pytorchVersion, jobTriggerUser, comfyRunFlags string, status drip.WorkflowRunStatus, machineStats *drip.MachineStats) (uuid.UUID, error) { log.Ctx(ctx).Info().Msgf("Upserting workflow result for commit %s", gitcommit.CommitHash) + dbWorkflowRunStatus, err := mapper.ApiWorkflowRunStatusToDb(status) + if err != nil { + return uuid.Nil, err + } return client.CIWorkflowResult. Create(). SetGitcommit(gitcommit). - SetStorageFile(file). SetOperatingSystem(os). SetWorkflowName(workflowName). SetRunID(runId). + SetJobID(jobId). SetStartTime(startTime). SetEndTime(endTime). + SetPythonVersion(pythonVersion). + SetPytorchVersion(pytorchVersion). + SetCudaVersion(cudaVersion). + SetComfyRunFlags(comfyRunFlags). + SetAvgVram(avgVram). + SetPeakVram(peakVram). + SetStatus(dbWorkflowRunStatus). + SetJobTriggerUser(jobTriggerUser). + SetMetadata(mapper.MachineStatsToMap(machineStats)). OnConflict( sql.ConflictColumns(ciworkflowresult.FieldID), ). @@ -129,6 +181,24 @@ func (s *ComfyCIService) UpsertRunResult(ctx context.Context, client *ent.Client ID(ctx) } +func (s *ComfyCIService) UpdateWorkflowResult(ctx context.Context, client *ent.Client, id uuid.UUID, status drip.WorkflowRunStatus, files []*drip.StorageFile) error { + dbWorkflowRunStatus, err := mapper.ApiWorkflowRunStatusToDb(status) + if err != nil { + return err + } + + fileIds := make([]uuid.UUID, 0, len(files)) + for _, file := range files { + fileIds = append(fileIds, *file.Id) + } + + return client.CIWorkflowResult. + UpdateOneID(id). + AddStorageFileIDs(fileIds...). + SetStatus(dbWorkflowRunStatus). + Exec(ctx) +} + // UpsertStorageFile creates or updates a RunFile entity. func (s *ComfyCIService) UpsertStorageFile(ctx context.Context, client *ent.Client, publicUrl, bucketName, filePath, fileType string) (*ent.StorageFile, error) { log.Ctx(ctx).Info().Msgf("Upserting storage file for URL %s", publicUrl) diff --git a/services/registry/registry_svc.go b/services/registry/registry_svc.go index 0eff8c9..68244bc 100644 --- a/services/registry/registry_svc.go +++ b/services/registry/registry_svc.go @@ -1,22 +1,35 @@ -package drip_services +package dripservices import ( + "bytes" "context" + "encoding/json" "errors" "fmt" + "io" + "net/http" + "registry-backend/config" "registry-backend/db" "registry-backend/drip" "registry-backend/ent" "registry-backend/ent/node" "registry-backend/ent/nodeversion" "registry-backend/ent/personalaccesstoken" + "registry-backend/ent/predicate" "registry-backend/ent/publisher" "registry-backend/ent/publisherpermission" "registry-backend/ent/schema" + "registry-backend/ent/user" + "registry-backend/gateways/algolia" + "registry-backend/gateways/discord" gateway "registry-backend/gateways/slack" "registry-backend/gateways/storage" "registry-backend/mapper" + drip_metric "registry-backend/server/middleware/metric" + "strings" + "time" + "entgo.io/ent/dialect/sql" "github.com/Masterminds/semver/v3" "google.golang.org/protobuf/proto" @@ -27,12 +40,18 @@ import ( type RegistryService struct { storageService storage.StorageService slackService gateway.SlackService + algolia algolia.AlgoliaService + discordService discord.DiscordService + config *config.Config } -func NewRegistryService(storageSvc storage.StorageService, slackSvc gateway.SlackService) *RegistryService { +func NewRegistryService(storageSvc storage.StorageService, slackSvc gateway.SlackService, discordSvc discord.DiscordService, algoliaSvc algolia.AlgoliaService, config *config.Config) *RegistryService { return &RegistryService{ storageService: storageSvc, slackService: slackSvc, + discordService: discordSvc, + algolia: algoliaSvc, + config: config, } } @@ -42,15 +61,23 @@ type PublisherFilter struct { // NodeFilter holds optional parameters for filtering node results type NodeFilter struct { - PublisherID string - // Add more filter fields here + PublisherID string + Search string + IncludeBanned bool +} + +type NodeVersionFilter struct { + NodeId string + Status []schema.NodeVersionStatus + MinAge time.Duration + PageSize int + Page int } type NodeData struct { ID string `json:"id"` Name string `json:"name"` PublisherID string `json:"publisherId"` - // Add other fields as necessary } // ListNodesResult is the structure that holds the paginated result of nodes @@ -62,8 +89,17 @@ type ListNodesResult struct { TotalPages int `json:"totalPages"` } +type ListNodeVersionsResult struct { + Total int `json:"total"` + NodeVersions []*ent.NodeVersion `json:"nodes"` + Page int `json:"page"` + Limit int `json:"limit"` + TotalPages int `json:"totalPages"` +} + // ListNodes retrieves a paginated list of nodes with optional filtering. func (s *RegistryService) ListNodes(ctx context.Context, client *ent.Client, page, limit int, filter *NodeFilter) (*ListNodesResult, error) { + // Ensure valid pagination parameters if page < 1 { page = 1 } @@ -71,23 +107,63 @@ func (s *RegistryService) ListNodes(ctx context.Context, client *ent.Client, pag limit = 10 } - query := client.Node.Query().WithPublisher().WithVersions( - func(q *ent.NodeVersionQuery) { - q.Order(ent.Desc(nodeversion.FieldCreateTime)) - }, - ) + // Initialize the query with relationships + query := client.Node.Query().WithPublisher() + + // Apply filters if provided if filter != nil { + var predicates []predicate.Node + + // Filter by PublisherID if filter.PublisherID != "" { - query.Where(node.PublisherID(filter.PublisherID)) + predicates = append(predicates, node.PublisherID(filter.PublisherID)) + } + + // Filter by search term across multiple fields + if filter.Search != "" { + predicates = append(predicates, node.Or( + node.IDContainsFold(filter.Search), + node.NameContainsFold(filter.Search), + node.DescriptionContainsFold(filter.Search), + node.AuthorContainsFold(filter.Search), + )) + } + + // Exclude banned nodes if not requested + if !filter.IncludeBanned { + predicates = append(predicates, node.StatusNEQ(schema.NodeStatusBanned)) + } + + // Apply predicates to the query + if len(predicates) > 1 { + query.Where(node.And(predicates...)) + } else if len(predicates) == 1 { + query.Where(predicates[0]) } } + + // Calculate pagination offset offset := (page - 1) * limit + + // Count total nodes total, err := query.Count(ctx) if err != nil { return nil, fmt.Errorf("failed to count nodes: %w", err) } + // Fetch nodes with pagination nodes, err := query. + WithVersions(func(q *ent.NodeVersionQuery) { + q.Modify(func(s *sql.Selector) { + s.Where(sql.ExprP( + `(node_id, create_time) IN ( + SELECT node_id, MAX(create_time) + FROM node_versions + GROUP BY node_id + )`, + )) + }) + }). Offset(offset). Limit(limit). All(ctx) @@ -95,11 +171,13 @@ func (s *RegistryService) ListNodes(ctx context.Context, client *ent.Client, pag return nil, fmt.Errorf("failed to list nodes: %w", err) } + // Calculate total pages totalPages := total / limit if total%limit != 0 { - totalPages += 1 + totalPages++ } + // Return the result return &ListNodesResult{ Total: total, Nodes: nodes, @@ -229,29 +307,49 @@ func (s *RegistryService) CreateNode(ctx context.Context, client *ent.Client, pu return nil, fmt.Errorf("invalid node: %w", validNode) } - createNode, err := mapper.ApiCreateNodeToDb(publisherId, node, client) - log.Ctx(ctx).Info().Msgf("creating node with fields: %v", createNode.Mutation().Fields()) - if err != nil { - return nil, fmt.Errorf("failed to map node: %w", err) - } + var createdNode *ent.Node + err := db.WithTx(ctx, client, func(tx *ent.Tx) (err error) { + createNode, err := mapper.ApiCreateNodeToDb(publisherId, node, tx.Client()) + log.Ctx(ctx).Info().Msgf("creating node with fields: %v", createNode.Mutation().Fields()) + if err != nil { + return fmt.Errorf("failed to map node: %w", err) + } - createdNode, err := createNode.Save(ctx) - if err != nil { - return nil, fmt.Errorf("failed to create node: %w", err) - } + createdNode, err = createNode.Save(ctx) + if err != nil { + return fmt.Errorf("failed to create node: %w", err) + } + + err = s.algolia.IndexNodes(ctx, createdNode) + if err != nil { + return fmt.Errorf("failed to index node: %w", err) + } - return createdNode, nil + return + }) + + return createdNode, err } -func (s *RegistryService) UpdateNode(ctx context.Context, client *ent.Client, update *ent.NodeUpdateOne) (*ent.Node, error) { - log.Ctx(ctx).Info().Msgf("updating node fields: %v", update.Mutation().Fields()) - node, err := update. - Save(ctx) +func (s *RegistryService) UpdateNode(ctx context.Context, client *ent.Client, updateFunc func(client *ent.Client) *ent.NodeUpdateOne) (*ent.Node, error) { + var node *ent.Node + err := db.WithTx(ctx, client, func(tx *ent.Tx) (err error) { + update := updateFunc(tx.Client()) + log.Ctx(ctx).Info().Msgf("updating node fields: %v", update.Mutation().Fields()) - if err != nil { - return nil, fmt.Errorf("failed to update node: %w", err) - } - return node, nil + node, err = update.Save(ctx) + if err != nil { + return fmt.Errorf("failed to update node: %w", err) + } + + err = s.algolia.IndexNodes(ctx, node) + if err != nil { + return fmt.Errorf("failed to index node: %w", err) + } + + return err + }) + return node, err } func (s *RegistryService) GetNode(ctx context.Context, client *ent.Client, nodeID string) (*ent.Node, error) { @@ -303,9 +401,16 @@ func (s *RegistryService) CreateNodeVersion( return nil, fmt.Errorf("failed to create node version: %w", err) } - slackErr := s.slackService.SendRegistryMessageToSlack(fmt.Sprintf("Version %s of node %s was published successfully. Publisher: %s. https://comfyregistry.org/nodes/%s", createdNodeVersion.Version, createdNodeVersion.NodeID, publisherID, nodeID)) + message := fmt.Sprintf("Version %s of node %s was published successfully. Publisher: %s. https://registry.comfy.org/nodes/%s", createdNodeVersion.Version, createdNodeVersion.NodeID, publisherID, nodeID) + slackErr := s.slackService.SendRegistryMessageToSlack(message) + s.discordService.SendSecurityCouncilMessage(message) if slackErr != nil { log.Ctx(ctx).Error().Msgf("Failed to send message to Slack w/ err: %v", slackErr) + drip_metric.IncrementCustomCounterMetric(ctx, drip_metric.CustomCounterIncrement{ + Type: "slack-send-error", + Val: 1, + Labels: map[string]string{}, + }) } return &NodeVersionCreation{ @@ -320,21 +425,96 @@ type NodeVersionCreation struct { SignedUrl string } -func (s *RegistryService) ListNodeVersions(ctx context.Context, client *ent.Client, nodeID string) ([]*ent.NodeVersion, error) { - log.Ctx(ctx).Info().Msgf("listing node versions: %v", nodeID) - versions, err := client.NodeVersion.Query(). - Where(nodeversion.NodeIDEQ(nodeID)). +func (s *RegistryService) ListNodeVersions(ctx context.Context, client *ent.Client, filter *NodeVersionFilter) (*ListNodeVersionsResult, error) { + query := client.NodeVersion.Query(). WithStorageFile(). - Order(ent.Desc(nodeversion.FieldCreateTime)). - All(ctx) + Order(ent.Desc(nodeversion.FieldCreateTime)) + + if filter.NodeId != "" { + log.Ctx(ctx).Info().Msgf("listing node versions: %v", filter.NodeId) + query.Where(nodeversion.NodeIDEQ(filter.NodeId)) + } + + if filter.Status != nil && len(filter.Status) > 0 { + log.Ctx(ctx).Info().Msgf("listing node versions with status: %v", filter.Status) + query.Where(nodeversion.StatusIn(filter.Status...)) + } + + if filter.MinAge > 0 { + query.Where(nodeversion.CreateTimeLT(time.Now().Add(-filter.MinAge))) + } + + if filter.Page > 0 && filter.PageSize > 0 { + query.Offset((filter.Page - 1) * filter.PageSize) + query.Limit(filter.PageSize) + } + total, err := query.Count(ctx) + if err != nil { + return nil, fmt.Errorf("failed to count node versions: %w", err) + } + versions, err := query.All(ctx) if err != nil { return nil, fmt.Errorf("failed to list node versions: %w", err) } - return versions, nil + + totalPages := 0 + if total > 0 && filter.PageSize > 0 { + totalPages = total / filter.PageSize + + if total%filter.PageSize != 0 { + totalPages += 1 + } + } + return &ListNodeVersionsResult{ + Total: total, + NodeVersions: versions, + Page: filter.Page, + Limit: filter.PageSize, + TotalPages: totalPages, + }, nil +} + +func (s *RegistryService) AddNodeReview(ctx context.Context, client *ent.Client, nodeId, userID string, star int) (nv *ent.Node, err error) { + log.Ctx(ctx).Info().Msgf("add review to node: %v ", nodeId) + + err = db.WithTx(ctx, client, func(tx *ent.Tx) error { + v, err := s.GetNode(ctx, tx.Client(), nodeId) + if err != nil { + return fmt.Errorf("fail to fetch node version") + } + + err = tx.NodeReview.Create(). + SetNode(v). + SetUserID(userID). + SetStar(star). + Exec(ctx) + if err != nil { + return fmt.Errorf("fail to add review to node ") + } + + err = v.Update().AddTotalReview(1).AddTotalStar(int64(star)).Exec(ctx) + if err != nil { + return fmt.Errorf("fail to add review: %w", err) + } + + nv, err = s.GetNode(ctx, tx.Client(), nodeId) + if err != nil { + return fmt.Errorf("fail to fetch node s") + } + + err = s.algolia.IndexNodes(ctx, nv) + if err != nil { + return fmt.Errorf("failed to index node: %w", err) + } + + return nil + }) + + return } -func (s *RegistryService) GetNodeVersion(ctx context.Context, client *ent.Client, nodeId, nodeVersion string) (*ent.NodeVersion, error) { - log.Ctx(ctx).Info().Msgf("getting node version: %v", nodeVersion) +func (s *RegistryService) GetNodeVersionByVersion(ctx context.Context, client *ent.Client, nodeId, nodeVersion string) (*ent.NodeVersion, error) { + log.Ctx(ctx).Info().Msgf("getting node version %v@%v", nodeId, nodeVersion) return client.NodeVersion. Query(). Where(nodeversion.VersionEQ(nodeVersion)). @@ -352,22 +532,43 @@ func (s *RegistryService) UpdateNodeVersion(ctx context.Context, client *ent.Cli return node, nil } +func (s *RegistryService) RecordNodeInstalation(ctx context.Context, client *ent.Client, node *ent.Node) (*ent.Node, error) { + var n *ent.Node + err := db.WithTx(ctx, client, func(tx *ent.Tx) (err error) { + node, err = tx.Node.UpdateOne(node).AddTotalInstall(1).Save(ctx) + if err != nil { + return err + } + err = s.algolia.IndexNodes(ctx, node) + if err != nil { + return fmt.Errorf("failed to index node: %w", err) + } + return + }) + return n, err +} + func (s *RegistryService) GetLatestNodeVersion(ctx context.Context, client *ent.Client, nodeId string) (*ent.NodeVersion, error) { - log.Ctx(ctx).Info().Msgf("getting latest version of node: %v", nodeId) + log.Ctx(ctx).Info().Msgf("Getting latest version of node: %v", nodeId) nodeVersion, err := client.NodeVersion. Query(). Where(nodeversion.NodeIDEQ(nodeId)). + //Where(nodeversion.StatusEQ(schema.NodeVersionStatusActive)). Order(ent.Desc(nodeversion.FieldCreateTime)). WithStorageFile(). First(ctx) if err != nil { if ent.IsNotFound(err) { - + log.Ctx(ctx).Info().Msgf("No versions found for node %v", nodeId) return nil, nil } + + log.Ctx(ctx).Error().Msgf("Error fetching latest version for node %v: %v", nodeId, err) return nil, err } + + log.Ctx(ctx).Info().Msgf("Found latest version for node %v: %v", nodeId, nodeVersion) return nodeVersion, nil } @@ -480,10 +681,17 @@ func (s *RegistryService) DeletePublisher(ctx context.Context, client *ent.Clien func (s *RegistryService) DeleteNode(ctx context.Context, client *ent.Client, nodeID string) error { log.Ctx(ctx).Info().Msgf("deleting node: %v", nodeID) - err := client.Node.DeleteOneID(nodeID).Exec(ctx) - if err != nil { - return fmt.Errorf("failed to delete node: %w", err) - } + db.WithTx(ctx, client, func(tx *ent.Tx) error { + err := tx.Client().Node.DeleteOneID(nodeID).Exec(ctx) + if err != nil { + return fmt.Errorf("failed to delete node: %w", err) + } + + if err = s.algolia.DeleteNode(ctx, &ent.Node{ID: nodeID}); err != nil { + return fmt.Errorf("fail to delete node from algolia: %w", err) + } + return nil + }) return nil } @@ -504,6 +712,209 @@ func IsPermissionError(err error) bool { if err == nil { return false } - var e *errorPermission + var e errorPermission return errors.As(err, &e) } + +func (s *RegistryService) BanPublisher(ctx context.Context, client *ent.Client, id string) error { + log.Ctx(ctx).Info().Msgf("banning publisher: %v", id) + pub, err := client.Publisher.Get(ctx, id) + if err != nil { + return fmt.Errorf("fail to find publisher: %w", err) + } + + err = db.WithTx(ctx, client, func(tx *ent.Tx) error { + err = pub.Update().SetStatus(schema.PublisherStatusTypeBanned).Exec(ctx) + if err != nil { + return fmt.Errorf("fail to update publisher: %w", err) + } + + err = tx.User.Update(). + Where(user.HasPublisherPermissionsWith(publisherpermission.HasPublisherWith(publisher.IDEQ(pub.ID)))). + SetStatus(schema.UserStatusTypeBanned). + Exec(ctx) + if err != nil { + return fmt.Errorf("fail to update users: %w", err) + } + + err = tx.Node.Update(). + Where(node.PublisherIDEQ(pub.ID)). + SetStatus(schema.NodeStatusBanned). + Exec(ctx) + if err != nil { + return fmt.Errorf("fail to update users: %w", err) + } + + nodes, err := tx.Node.Query().Where(node.PublisherID(id)).All(ctx) + if len(nodes) == 0 || ent.IsNotFound(err) { + return nil + } + if err != nil { + return fmt.Errorf("fail to update nodes: %w", err) + } + + err = s.algolia.IndexNodes(ctx, nodes...) + if err != nil { + return fmt.Errorf("failed to index node: %w", err) + } + + return nil + }) + + return err +} + +func (s *RegistryService) BanNode(ctx context.Context, client *ent.Client, publisherid, id string) error { + log.Ctx(ctx).Info().Msgf("banning publisher node: %v %v", publisherid, id) + + return db.WithTx(ctx, client, func(tx *ent.Tx) error { + n, err := tx.Node.Query().Where(node.And( + node.IDEQ(id), + node.PublisherIDEQ(publisherid), + )).Only(ctx) + if ent.IsNotFound(err) { + return nil + } + + n, err = n.Update(). + SetStatus(schema.NodeStatusBanned). + Save(ctx) + if ent.IsNotFound(err) { + return nil + } + if err != nil { + return fmt.Errorf("fail to ban node: %w", err) + } + + err = s.algolia.IndexNodes(ctx, n) + if err != nil { + return fmt.Errorf("failed to index node: %w", err) + } + + return err + }) + +} + +func (s *RegistryService) AssertNodeBanned(ctx context.Context, client *ent.Client, nodeID string) error { + node, err := client.Node.Get(ctx, nodeID) + if ent.IsNotFound(err) { + return nil + } + if err != nil { + return fmt.Errorf("failed to get node: %w", err) + } + if node.Status == schema.NodeStatusBanned { + return newErrorPermission("node '%s' is currently banned", nodeID) + } + return nil +} + +func (s *RegistryService) AssertPublisherBanned(ctx context.Context, client *ent.Client, publisherID string) error { + publisher, err := client.Publisher.Get(ctx, publisherID) + if ent.IsNotFound(err) { + return nil + } + if err != nil { + return fmt.Errorf("failed to get node: %w", err) + } + if publisher.Status == schema.PublisherStatusTypeBanned { + return newErrorPermission("node '%s' is currently banned", publisherID) + } + return nil +} + +func (s *RegistryService) ReindexAllNodes(ctx context.Context, client *ent.Client) error { + log.Ctx(ctx).Info().Msgf("reindexing nodes") + nodes, err := client.Node.Query().All(ctx) + if err != nil { + return fmt.Errorf("failed to fetch all nodes: %w", err) + } + + log.Ctx(ctx).Info().Msgf("reindexing %d number of nodes", len(nodes)) + err = s.algolia.IndexNodes(ctx, nodes...) + if err != nil { + return fmt.Errorf("failed to reindex all nodes: %w", err) + } + return nil +} + +func (s *RegistryService) PerformSecurityCheck(ctx context.Context, client *ent.Client, nodeVersion *ent.NodeVersion) error { + log.Ctx(ctx).Info().Msgf("scanning node %s@%s", nodeVersion.NodeID, nodeVersion.Version) + + if (nodeVersion.Edges.StorageFile == nil) || (nodeVersion.Edges.StorageFile.FileURL == "") { + return fmt.Errorf("node version %s@%s does not have a storage file", nodeVersion.NodeID, nodeVersion.Version) + } + + issues, err := sendScanRequest(s.config.SecretScannerURL, nodeVersion.Edges.StorageFile.FileURL) + if err != nil { + if strings.Contains(err.Error(), "404") { + log.Ctx(ctx).Info().Msgf("Node zip file doesn’t exist %s@%s. Updating to deleted.", nodeVersion.NodeID, nodeVersion.Version) + err := nodeVersion.Update().SetStatus(schema.NodeVersionStatusDeleted).SetStatusReason("Node zip file doesn’t exist").Exec(ctx) + if err != nil { + log.Ctx(ctx).Error().Err(err).Msgf("failed to update node version status to active") + } + } + return err + } + + if issues != "" { + log.Ctx(ctx).Info().Msgf("No security issues found in node %s@%s. Updating to active.", nodeVersion.NodeID, nodeVersion.Version) + err := nodeVersion.Update().SetStatus(schema.NodeVersionStatusActive).SetStatusReason("Passed automated checks").Exec(ctx) + if err != nil { + log.Ctx(ctx).Error().Err(err).Msgf("failed to update node version status to active") + } + err = s.discordService.SendSecurityCouncilMessage(fmt.Sprintf("Node %s@%s has passed automated scans. Changing status to active.", nodeVersion.NodeID, nodeVersion.Version)) + if err != nil { + log.Ctx(ctx).Error().Err(err).Msgf("failed to send message to discord") + } + } else { + log.Ctx(ctx).Info().Msgf("Security issues found in node %s@%s. Updating to flagged.", nodeVersion.NodeID, nodeVersion.Version) + err := nodeVersion.Update().SetStatus(schema.NodeVersionStatusFlagged).SetStatusReason(issues).Exec(ctx) + if err != nil { + log.Ctx(ctx).Error().Err(err).Msgf("failed to update node version status to security issue") + } + err = s.discordService.SendSecurityCouncilMessage(fmt.Sprintf("Security issues were found in node %s@%s. Status is flagged. Please check it here: https://registry.comfy.org/admin/nodes/%s/versions/%s", nodeVersion.NodeID, nodeVersion.Version, nodeVersion.NodeID, nodeVersion.Version)) + if err != nil { + log.Ctx(ctx).Error().Err(err).Msgf("failed to send message to discord") + } + } + return nil +} + +type ScanRequest struct { + URL string `json:"url"` +} + +func sendScanRequest(apiURL, fileURL string) (string, error) { + requestBody, err := json.Marshal(ScanRequest{URL: fileURL}) + if err != nil { + return "", fmt.Errorf("failed to marshal request body: %w", err) + } + + req, err := http.NewRequest("POST", apiURL, bytes.NewBuffer(requestBody)) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("failed to read response body: %w", err) + } + + fmt.Println("Response Status:", resp.Status) + if resp.StatusCode != 200 { + return "", fmt.Errorf("failed to scan file: %s", responseBody) + } + + return string(responseBody), nil +}