From e3278fc085ddcae1aafaae2db94cc1cec5da0dd3 Mon Sep 17 00:00:00 2001 From: Thanawan Atchariyachanvanit Date: Thu, 7 Sep 2023 22:32:59 +0000 Subject: [PATCH] Test output function Signed-off-by: Thanawan Atchariyachanvanit --- .../fix_hkunlp_instructor-large.ipynb | 655 ++++++++++++++++-- 1 file changed, 602 insertions(+), 53 deletions(-) diff --git a/docs/source/examples/fix_hkunlp_instructor-large.ipynb b/docs/source/examples/fix_hkunlp_instructor-large.ipynb index c92a5264..09ec21f2 100644 --- a/docs/source/examples/fix_hkunlp_instructor-large.ipynb +++ b/docs/source/examples/fix_hkunlp_instructor-large.ipynb @@ -2,10 +2,19 @@ "cells": [ { "cell_type": "code", - "execution_count": 24, + "execution_count": 1, "id": "63385cb0", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/linuxbrew/.linuxbrew/opt/python@3.8/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], "source": [ "import json\n", "import os\n", @@ -19,7 +28,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 2, "id": "55b82d29", "metadata": {}, "outputs": [], @@ -32,7 +41,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 4, "id": "e2edef77", "metadata": {}, "outputs": [ @@ -51,7 +60,7 @@ "'sentence-transformer-torchscript/instructor-large.zip'" ] }, - "execution_count": 12, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -132,7 +141,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 5, "id": "11cb24c5", "metadata": {}, "outputs": [ @@ -149,7 +158,7 @@ "'sentence-transformer-torchscript/ml-commons_model_config.json'" ] }, - "execution_count": 25, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -320,7 +329,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "id": "eb7bd14e", "metadata": {}, "outputs": [], @@ -346,31 +355,10 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 7, "id": "45cc5394", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "model file is saved to sentence-transformer-torchscript/instructor-large.pt\n", - "zip file is saved to sentence-transformer-torchscript/instructor-large.zip \n", - "\n", - "ml-commons_model_config.json file is saved at : sentence-transformer-torchscript/ml-commons_model_config.json\n" - ] - }, - { - "data": { - "text/plain": [ - "'sentence-transformer-torchscript/ml-commons_model_config.json'" - ] - }, - "execution_count": 32, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "# import json\n", "# import os\n", @@ -638,7 +626,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 10, "id": "87f9b88d", "metadata": {}, "outputs": [ @@ -696,7 +684,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 12, "id": "b7e6d0f2", "metadata": {}, "outputs": [ @@ -705,8 +693,8 @@ "output_type": "stream", "text": [ "Total number of chunks 135\n", - "Sha1 value of the model file: 41f576f7235aebd1b9ead05366b3fcb4e6c642ef225c9fbcb1d465356a24c53b\n", - "Model meta data was created successfully. Model Id: OJ-mbIoBx1PaKKd27EH_\n", + "Sha1 value of the model file: eb3be3c144a3564f666b8c19bc5c4e841b4c352422e2150e439e2076a6c5806b\n", + "Model meta data was created successfully. Model Id: _cu7cYoBSA5PdoWsvUMh\n", "uploading chunk 1 of 135\n", "Model id: {'status': 'Uploaded'}\n", "uploading chunk 2 of 135\n", @@ -984,17 +972,17 @@ "name": "stdout", "output_type": "stream", "text": [ - "Task ID: OZ-nbIoBx1PaKKd2qkFw\n", + "Task ID: _su8cYoBSA5PdoWsfUOb\n", "Model deployed successfully\n" ] }, { "data": { "text/plain": [ - "'OJ-mbIoBx1PaKKd27EH_'" + "'_cu7cYoBSA5PdoWsvUMh'" ] }, - "execution_count": 29, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -1005,46 +993,607 @@ }, { "cell_type": "code", - "execution_count": 30, - "id": "d4444e3a", + "execution_count": null, + "id": "704983a7", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0965859c", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "190ef986", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "c305ddb6", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", - "input_sentences = [\"first sentence\", \"second sentence\"]\n", "\n", - "# Generated embedding from torchScript\n", - "\n", - "embedding_output_torch = ml_client.generate_embedding(\"OJ-mbIoBx1PaKKd27EH_\", input_sentences)" + "sentence = \"3D ActionSLAM: wearable person tracking in multi-floor environments\"\n", + "instruction = \"Represent the Science title:\"\n", + "input_sentences = [[sentence, instruction]]\n" ] }, { "cell_type": "code", - "execution_count": 31, - "id": "85d0f4b0", + "execution_count": 42, + "id": "d4444e3a", + "metadata": {}, + "outputs": [], + "source": [ + "embedding_output_torch = ml_client.generate_embedding(\"ZE6tcYoBmvm98tEoUR17\", input_sentences)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "73b75166", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "768" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(embedding_output_torch['inference_results'][0]['output'][0]['data'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "07e0bab8", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1a06b4c8", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8bdf236a", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cd461081", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3f8e1490", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "4f4741d0", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "0\n", - "None\n", - "1\n", - "None\n" + "load INSTRUCTOR_Transformer\n", + "max_seq_length 512\n", + "[[-6.15552627e-02 1.04199909e-02 5.88440849e-03 1.93768777e-02\n", + " 5.71417809e-02 2.57655680e-02 -4.01848811e-05 -2.80044544e-02\n", + " -2.92965453e-02 4.91884835e-02 6.78200126e-02 2.18692459e-02\n", + " 4.54528630e-02 1.50187174e-02 -4.84451912e-02 -3.25259753e-02\n", + " -3.56492661e-02 1.19935432e-02 -6.83915243e-03 3.03126313e-02\n", + " 5.17491624e-02 3.48140560e-02 4.91032703e-03 6.68928474e-02\n", + " 1.52824204e-02 3.54217105e-02 1.07743740e-02 6.89828917e-02\n", + " 4.44019511e-02 -3.23419496e-02 1.24267889e-02 -2.15528179e-02\n", + " -1.62690915e-02 -4.15058285e-02 -2.42290483e-03 -3.07159079e-03\n", + " 4.27047350e-02 1.56428684e-02 2.57813130e-02 5.92843294e-02\n", + " -1.99174136e-02 1.32361799e-02 1.08407950e-02 -4.00610566e-02\n", + " -1.36212725e-03 -1.57032814e-02 -2.53812242e-02 -1.31972805e-02\n", + " -7.83779379e-03 -1.14009008e-02 -4.82025407e-02 -2.58416273e-02\n", + " -4.98771109e-03 4.98239547e-02 1.19490065e-02 -5.55060543e-02\n", + " -2.82120351e-02 -3.32208872e-02 2.46765111e-02 -5.66114485e-02\n", + " -5.12201386e-03 1.95142869e-02 -2.12629829e-02 1.92354042e-02\n", + " 2.46065073e-02 -4.58347723e-02 3.27664278e-02 -3.99055742e-02\n", + " 5.31269349e-02 9.05527559e-04 4.53844778e-02 -2.51501352e-02\n", + " 1.74823881e-03 -9.64769274e-02 -9.51786060e-03 -6.47392124e-03\n", + " 3.51561382e-02 3.58432494e-02 -5.11278324e-02 4.30903099e-02\n", + " 4.58191633e-02 1.91871580e-02 2.38421671e-02 -1.71816293e-02\n", + " -1.52623244e-02 5.40182367e-02 -5.58874011e-02 4.29563001e-02\n", + " 8.48113280e-03 7.83620495e-03 -3.27342823e-02 -1.08465450e-02\n", + " -7.19641568e-03 -4.37382981e-02 -1.88113526e-02 5.16907349e-02\n", + " 4.62869145e-02 -2.63639893e-02 3.73640880e-02 1.84657965e-02\n", + " 5.99115565e-02 1.80141302e-04 -2.35873796e-02 5.71749285e-02\n", + " 1.20532736e-02 -3.81674580e-02 -3.55241075e-02 2.34813849e-03\n", + " -4.45777997e-02 9.34025273e-03 5.85195236e-03 -3.56189236e-02\n", + " -2.23838575e-02 -1.38210715e-03 8.74637067e-03 2.08802372e-02\n", + " 7.03728944e-02 -4.39637005e-02 -4.53046709e-02 -4.76960503e-02\n", + " 4.33718599e-02 -1.97182014e-03 -5.65527752e-03 -2.16748025e-02\n", + " -7.46926218e-02 1.90407708e-02 -2.33457312e-02 -5.68974577e-02\n", + " -9.49267950e-03 4.25820984e-03 3.14501813e-03 1.90789737e-02\n", + " -1.00614019e-02 -6.33771420e-02 4.90878969e-02 2.97248526e-03\n", + " -7.01222867e-02 1.71163045e-02 1.05466843e-02 8.59851614e-02\n", + " -5.78762367e-02 -3.88501137e-02 4.20247996e-03 -1.92795359e-02\n", + " -4.11053002e-02 7.98566174e-03 4.75644283e-02 -4.87977602e-02\n", + " -3.62160131e-02 -2.10572612e-02 4.02226932e-02 -4.74730358e-02\n", + " -2.78858747e-02 8.39250907e-02 -9.76029597e-03 2.62570437e-02\n", + " -5.60530759e-02 1.52837224e-02 1.54583401e-03 2.02960498e-03\n", + " -3.28001268e-02 5.76916039e-02 -7.33235553e-02 -4.00819927e-02\n", + " -3.98107953e-02 -3.84523645e-02 -8.67155753e-03 1.05411708e-01\n", + " -2.86331237e-03 -1.91161316e-02 -5.60036413e-02 9.67338309e-03\n", + " 5.51291034e-02 2.56364211e-03 -2.94723455e-02 5.84518462e-02\n", + " 5.15934229e-02 -1.61305186e-03 -2.19461825e-02 5.65167554e-02\n", + " 4.74953279e-02 -2.44090706e-02 -2.66008992e-02 -5.86746773e-03\n", + " 2.24451218e-02 -2.23603705e-03 4.56711045e-03 3.27842422e-02\n", + " 5.26623288e-03 -2.01674551e-02 -2.33967975e-02 4.43987399e-02\n", + " -1.51708275e-02 7.38917291e-03 2.71087196e-02 -2.46057920e-02\n", + " -1.87857188e-02 -5.61464461e-04 -3.28655392e-02 -1.21782236e-02\n", + " 1.79727422e-03 -1.50850788e-02 2.52194256e-02 1.25257755e-02\n", + " -2.65359791e-04 1.23138353e-02 -6.45002862e-03 1.02272674e-01\n", + " -2.98037715e-02 5.94182312e-02 -2.78096017e-03 -3.49573679e-02\n", + " 3.06671727e-02 5.42211048e-02 5.95246293e-02 4.14741263e-02\n", + " -4.06689895e-03 -3.94712463e-02 1.96131431e-02 5.96131235e-02\n", + " 4.44265865e-02 4.40843925e-02 -5.12231402e-02 -3.00020408e-02\n", + " 3.01150158e-02 2.40173973e-02 -3.39305885e-02 -1.70434006e-02\n", + " 8.32551345e-03 2.66083386e-02 7.67713366e-03 1.76458433e-02\n", + " -2.06325063e-03 1.77012943e-02 -6.08421750e-02 -7.96776712e-02\n", + " 4.99934442e-02 2.96638533e-02 -4.47009411e-03 1.65794324e-02\n", + " -2.35370398e-02 -3.23977438e-03 2.61382628e-02 -1.34953307e-02\n", + " -1.60201844e-02 -1.08793685e-02 -1.77004971e-02 -6.53111422e-03\n", + " 6.91719949e-02 -4.63659726e-02 4.15586568e-02 1.24583598e-02\n", + " -1.88725520e-04 2.47693639e-02 -3.62277292e-02 5.47523722e-02\n", + " 1.54009983e-01 6.00456586e-03 -2.70665511e-02 4.70894761e-02\n", + " 4.09195684e-02 4.31693867e-02 6.22434355e-02 -2.51828600e-02\n", + " 6.71826899e-02 1.89108830e-02 3.67507823e-02 7.62735754e-02\n", + " 5.01050672e-04 -7.33284000e-03 1.95556283e-02 8.43793601e-02\n", + " 1.24929100e-02 -2.75657373e-03 4.97817136e-02 -1.73069611e-02\n", + " 2.77005229e-02 -2.63486300e-02 -2.21686810e-02 3.95561103e-03\n", + " -9.68612079e-03 3.96470763e-02 -8.72506853e-03 -1.07546160e-02\n", + " -2.70988829e-02 -1.17305303e-02 -1.16984015e-02 4.52318490e-02\n", + " -9.12858080e-03 -1.14591718e-02 8.29536747e-03 -3.94435227e-02\n", + " 8.80732387e-03 -3.67274545e-02 -4.45834659e-02 -2.38478389e-02\n", + " 1.73519570e-02 2.46788189e-02 -9.24503356e-02 3.40846553e-03\n", + " -8.58144388e-02 -1.69283785e-02 8.74705799e-03 -2.66722660e-03\n", + " -3.10086529e-03 -6.62742853e-02 1.74709838e-02 -6.20296970e-02\n", + " -7.71831945e-02 -4.30789776e-02 -6.97872937e-02 -2.76594125e-02\n", + " -7.36039579e-02 2.61303931e-02 4.94785458e-02 1.88994482e-02\n", + " 2.05077250e-02 5.93992881e-03 -2.71200631e-02 -4.64439504e-02\n", + " 2.66322866e-02 2.63824426e-02 3.03617190e-03 -4.70094867e-02\n", + " -8.68524797e-03 -1.94981520e-03 -1.47214429e-02 -3.10322680e-02\n", + " -3.54933180e-02 7.64071718e-02 9.24097970e-02 1.11720497e-02\n", + " 6.86150556e-03 2.67613679e-02 -4.66881432e-02 -4.80801016e-02\n", + " -1.76523291e-02 -5.05446680e-02 -2.54300330e-02 -2.59506684e-02\n", + " -2.86576096e-02 -3.34676728e-02 -3.07256356e-02 6.79465756e-03\n", + " -5.43393195e-02 -2.23254762e-03 1.03654973e-02 3.52348499e-02\n", + " 2.40201559e-02 -2.09923671e-03 -8.59064758e-02 -4.86475974e-02\n", + " 3.41627039e-02 9.51631833e-03 2.42883037e-03 -6.15581088e-02\n", + " -2.23672725e-02 1.49234310e-02 -6.16901880e-03 -2.94565558e-02\n", + " -8.48871283e-03 3.98517437e-02 3.54111418e-02 -1.31471222e-02\n", + " -2.31655948e-02 -2.86290571e-02 1.44813117e-02 -3.19011463e-03\n", + " 5.59895253e-03 -6.02383539e-02 -5.41782789e-02 6.31063944e-03\n", + " -3.27197160e-03 6.00864477e-02 -4.93385196e-02 -1.23744970e-02\n", + " -2.60731447e-02 3.88635322e-02 3.19503173e-02 2.37053372e-02\n", + " -2.05829702e-02 1.42387496e-02 -3.58667374e-02 -3.98508646e-02\n", + " 8.55066627e-03 -2.32857782e-02 1.41011244e-02 5.81302680e-02\n", + " -4.17654496e-03 7.78259803e-03 8.50560889e-02 2.76554506e-02\n", + " 4.23116311e-02 2.45192852e-02 -2.62568891e-02 3.76733541e-02\n", + " -1.03408853e-02 2.60650143e-02 6.19982975e-03 -1.73710976e-02\n", + " -5.55875227e-02 -1.02811225e-01 -8.27026740e-03 -6.74285553e-03\n", + " -5.95137514e-02 1.29434923e-02 4.41101007e-02 -7.02069700e-03\n", + " -3.03075369e-02 -9.03239567e-03 2.10526474e-02 2.01296899e-02\n", + " -3.11781210e-03 4.94987555e-02 -2.36510094e-02 2.80551519e-02\n", + " -2.48605236e-02 5.25815785e-03 -5.47549576e-02 -1.80020817e-02\n", + " -6.72237203e-03 7.68097118e-02 2.41172332e-02 6.28411770e-02\n", + " 4.77913357e-02 -1.15464935e-02 -4.14417088e-02 2.10504960e-02\n", + " 6.09488599e-02 -2.36858018e-02 -3.18970941e-02 2.34901183e-03\n", + " -2.75846943e-03 1.48618768e-03 -4.22429573e-03 5.57198701e-03\n", + " 2.00943667e-02 5.29720783e-02 -3.99871208e-02 -1.41997430e-02\n", + " 3.94999571e-02 -1.47230728e-02 -4.10684012e-03 -6.41633645e-02\n", + " -2.31138412e-02 1.63525681e-03 6.87347678e-03 5.51297814e-02\n", + " 1.13907279e-02 3.55854705e-02 5.87924458e-02 2.42435951e-02\n", + " -3.97643782e-02 -7.16551989e-02 4.69529703e-02 -3.05532152e-03\n", + " -4.91016470e-02 -9.50928256e-02 -1.41104050e-02 2.90550943e-02\n", + " 2.07553767e-02 -2.56224279e-03 -2.63764765e-02 -5.93052991e-03\n", + " 6.81198090e-02 -2.53772512e-02 6.08022697e-02 4.24165688e-02\n", + " 4.66698669e-02 3.79461348e-02 -1.22388816e-02 6.11324497e-02\n", + " -1.82264987e-02 -8.81061703e-03 2.42136922e-02 2.62034535e-02\n", + " -1.55038983e-02 -2.20747292e-02 -5.16002886e-02 2.53373291e-02\n", + " 3.05230375e-02 1.20210228e-02 8.25989991e-02 -2.68187318e-02\n", + " -3.36164124e-02 -3.96278538e-02 2.64574755e-02 -4.73223440e-02\n", + " 5.45928292e-02 4.71893214e-02 5.40369861e-02 -3.63412164e-02\n", + " -4.38812077e-02 -9.25776083e-03 -1.49381906e-02 1.94572601e-02\n", + " -4.68942858e-02 -2.96848789e-02 -6.92514926e-02 2.51878574e-02\n", + " -1.31793972e-02 -3.26385312e-02 -8.38335454e-02 1.62501279e-02\n", + " 5.05868159e-03 -3.85647714e-02 4.18353155e-02 4.50653471e-02\n", + " 4.53344621e-02 3.85494605e-02 5.27763069e-02 9.01090167e-03\n", + " -2.32415143e-02 4.14123312e-02 -3.90885249e-02 -1.84995271e-02\n", + " -2.91617382e-02 -6.02056943e-02 -3.62730958e-02 4.92623029e-03\n", + " -1.51348123e-02 -1.77912544e-02 -6.56068511e-03 3.74852866e-02\n", + " -4.98751830e-03 3.45563330e-02 8.38176627e-03 1.23971188e-02\n", + " 1.30274482e-02 -5.76015376e-02 -1.41846295e-02 -3.29240113e-02\n", + " -6.02640659e-02 -4.08707075e-02 6.09732494e-02 -5.65142743e-03\n", + " -2.64281519e-02 1.45490095e-02 1.39951501e-02 2.01470554e-02\n", + " 1.63883958e-02 -4.30175960e-02 8.81800242e-03 9.79693327e-03\n", + " -4.37083468e-02 -1.07098445e-02 -2.09241901e-02 -1.68447625e-02\n", + " 2.54024155e-02 -4.39964309e-02 2.77971942e-02 2.39688512e-02\n", + " 4.46380768e-03 -4.09839563e-02 1.39753716e-02 -1.02954321e-02\n", + " -4.48161736e-02 1.04085468e-02 -2.32339464e-02 8.22257251e-03\n", + " 1.08463941e-02 -7.12043280e-03 -2.48803925e-02 1.47036817e-02\n", + " -1.03130294e-02 5.29496707e-02 2.34216079e-02 -3.16518173e-02\n", + " 2.24910602e-02 -1.01565402e-02 2.24805549e-02 -6.64025024e-02\n", + " 2.63604522e-02 -2.33393200e-02 2.29447111e-02 -1.88058559e-02\n", + " -2.10313057e-03 -4.88403216e-02 4.41654511e-02 -2.42530257e-02\n", + " -3.33837792e-02 6.30348036e-03 1.08948885e-03 1.65918248e-03\n", + " 1.43814711e-02 -6.16017962e-03 2.33820472e-02 -6.41303658e-02\n", + " 2.14748587e-02 1.68789178e-02 -1.88098792e-02 -1.45088257e-02\n", + " 4.35655639e-02 -3.56806517e-02 -1.71170831e-02 4.00119135e-03\n", + " -1.24642029e-02 3.74952182e-02 3.54862548e-02 2.71979184e-03\n", + " 4.88897450e-02 -1.42481411e-02 -2.37889662e-02 1.45645356e-02\n", + " -5.29264621e-02 -3.16047631e-02 -2.55868081e-02 6.24947716e-04\n", + " 1.23044737e-02 1.52396616e-02 5.92730334e-03 -6.96792677e-02\n", + " -4.38257121e-02 3.32457758e-02 4.29933332e-02 3.41573469e-02\n", + " 5.74664306e-03 6.92842621e-03 2.19891723e-02 5.40519953e-02\n", + " -3.47650945e-02 -6.38604071e-03 -1.06168566e-02 5.59756253e-03\n", + " 2.51517501e-02 1.97777734e-03 -9.76092927e-03 1.29118180e-02\n", + " -5.10915816e-02 -4.22592573e-02 6.32157177e-02 6.68454096e-02\n", + " 3.72742787e-02 -1.31203542e-02 -3.29280011e-02 3.23108919e-02\n", + " 2.64140982e-02 -4.51177172e-02 6.29258081e-02 -3.71046620e-03\n", + " -3.95429395e-02 5.86496964e-02 -5.98639622e-03 1.91448089e-02\n", + " 3.20215262e-02 5.33388630e-02 -2.62014270e-02 2.29458027e-02\n", + " -1.26508866e-02 6.65134517e-03 6.07207976e-02 -3.29924598e-02\n", + " -2.04965305e-02 -5.14310375e-02 6.54849783e-02 -4.22492288e-02\n", + " 9.26519334e-02 1.99730266e-02 -1.83647852e-02 1.88232283e-03\n", + " -4.16838452e-02 -6.10365681e-02 2.76884940e-02 -3.12236119e-02\n", + " 5.57781495e-02 -3.40828001e-02 -5.36403507e-02 3.83231528e-02\n", + " 1.50124645e-02 -6.74923360e-02 6.30307719e-02 1.23501029e-02\n", + " 5.98304309e-02 2.07509398e-02 3.15652266e-02 -3.75223532e-02\n", + " 3.68293971e-02 -6.04589507e-02 9.98122990e-03 -3.74645405e-02\n", + " 4.14430210e-03 2.01168507e-02 -1.38435932e-02 -6.81752991e-03\n", + " 3.87795828e-03 -8.58292729e-03 -1.53929193e-03 4.07204069e-02\n", + " -5.60819432e-02 -6.66264743e-02 2.24500354e-02 1.80751961e-02\n", + " -1.88874244e-03 1.47162657e-02 -2.16904916e-02 8.98237061e-03\n", + " 3.33474614e-02 -1.17769474e-02 -2.53784601e-02 7.49977864e-03\n", + " 1.59923248e-02 -5.08552305e-02 -2.55495775e-02 3.99981141e-02\n", + " 2.41953484e-03 -1.39974179e-02 1.12951621e-02 1.01209211e-03\n", + " -1.35118491e-04 1.50756817e-02 3.66248237e-03 3.56586389e-02\n", + " -2.39739642e-02 -5.98192215e-03 -1.14465142e-02 7.17898458e-03\n", + " -7.58271478e-03 -1.56441070e-02 2.08631977e-02 4.67960909e-02\n", + " 9.20427777e-03 9.97737236e-03 -1.69361867e-02 -3.79866734e-02\n", + " 2.12064050e-02 -3.93209793e-02 2.39590071e-02 4.52188635e-03\n", + " -4.90421988e-02 -2.53686830e-02 -5.13280444e-02 -3.11175715e-02\n", + " 3.69731188e-02 -3.32236588e-02 2.64319666e-02 3.13280262e-02\n", + " 5.02048135e-02 -3.51830311e-02 -9.47780237e-02 -3.81823629e-02\n", + " -2.52813045e-02 8.34161229e-03 1.06830755e-02 -2.85213180e-02\n", + " 1.14974445e-02 -1.63908321e-02 -5.35691381e-02 1.44921457e-02\n", + " 1.44734280e-02 1.46977436e-02 3.46375965e-02 4.89880778e-02\n", + " -2.99605671e-02 7.35144131e-03 1.49103850e-02 -2.81755980e-02\n", + " 4.02826704e-02 1.23249050e-02 2.03392170e-02 4.75965329e-02\n", + " 4.34238613e-02 8.02049506e-03 -1.76124519e-03 -6.28188923e-02\n", + " -4.67858873e-02 -3.76614109e-02 1.02270264e-02 4.39473838e-02]]\n" ] } ], + "source": [ + "from InstructorEmbedding import INSTRUCTOR\n", + "model = INSTRUCTOR('hkunlp/instructor-large')\n", + "sentence = \"3D ActionSLAM: wearable person tracking in multi-floor environments\"\n", + "instruction = \"Represent the Science title:\"\n", + "embeddings = model.encode([[instruction,sentence]])\n", + "print(embeddings)" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "d6f090d3", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(1, 768)" + ] + }, + "execution_count": 38, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "embeddings.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "0ca9bc08", + "metadata": {}, + "outputs": [], "source": [ "original_pre_trained_model = model\n", "original_embedding_data = list(\n", " original_pre_trained_model.encode(input_sentences, convert_to_numpy=True)\n", - ")\n", - " \n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "edebcd6c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0\n" + ] + }, + { + "ename": "AssertionError", + "evalue": "\nNot equal to tolerance rtol=0.001, atol=1e-05\n\nMismatched elements: 766 / 768 (99.7%)\nMax absolute difference: 0.0982903\nMax relative difference: 6898.21135732\n x: array([-4.475518e-02, -3.724144e-04, 6.277785e-03, 2.341405e-02,\n 4.438463e-02, 1.630010e-02, -1.171885e-02, -5.577739e-03,\n -3.750417e-02, 2.511203e-02, 4.287542e-02, 6.856018e-03,...\n y: array([-5.717582e-02, 2.493422e-03, -3.901535e-03, 1.850889e-02,\n 5.135028e-02, 2.320276e-02, 6.400365e-03, -1.590035e-02,\n -2.774814e-02, 4.663043e-02, 6.030675e-02, 2.010986e-02,...", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[27], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mlen\u001b[39m(input_sentences)):\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28mprint\u001b[39m(i)\n\u001b[0;32m----> 3\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtesting\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43massert_allclose\u001b[49m\u001b[43m(\u001b[49m\u001b[43membeddings\u001b[49m\u001b[43m[\u001b[49m\u001b[43mi\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43membedding_output_torch\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43minference_results\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m[\u001b[49m\u001b[43mi\u001b[49m\u001b[43m]\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43moutput\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mdata\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrtol\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1e-03\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43matol\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1e-05\u001b[39;49m\u001b[43m)\u001b[49m)\n", + " \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n", + "File \u001b[0;32m/home/linuxbrew/.linuxbrew/opt/python@3.8/lib/python3.8/contextlib.py:75\u001b[0m, in \u001b[0;36mContextDecorator.__call__..inner\u001b[0;34m(*args, **kwds)\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[38;5;129m@wraps\u001b[39m(func)\n\u001b[1;32m 73\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21minner\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwds):\n\u001b[1;32m 74\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_recreate_cm():\n\u001b[0;32m---> 75\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwds\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/home/linuxbrew/.linuxbrew/opt/python@3.8/lib/python3.8/site-packages/numpy/testing/_private/utils.py:862\u001b[0m, in \u001b[0;36massert_array_compare\u001b[0;34m(comparison, x, y, err_msg, verbose, header, precision, equal_nan, equal_inf, strict)\u001b[0m\n\u001b[1;32m 858\u001b[0m err_msg \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m'\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(remarks)\n\u001b[1;32m 859\u001b[0m msg \u001b[38;5;241m=\u001b[39m build_err_msg([ox, oy], err_msg,\n\u001b[1;32m 860\u001b[0m verbose\u001b[38;5;241m=\u001b[39mverbose, header\u001b[38;5;241m=\u001b[39mheader,\n\u001b[1;32m 861\u001b[0m names\u001b[38;5;241m=\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mx\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124my\u001b[39m\u001b[38;5;124m'\u001b[39m), precision\u001b[38;5;241m=\u001b[39mprecision)\n\u001b[0;32m--> 862\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAssertionError\u001b[39;00m(msg)\n\u001b[1;32m 863\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m:\n\u001b[1;32m 864\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtraceback\u001b[39;00m\n", + "\u001b[0;31mAssertionError\u001b[0m: \nNot equal to tolerance rtol=0.001, atol=1e-05\n\nMismatched elements: 766 / 768 (99.7%)\nMax absolute difference: 0.0982903\nMax relative difference: 6898.21135732\n x: array([-4.475518e-02, -3.724144e-04, 6.277785e-03, 2.341405e-02,\n 4.438463e-02, 1.630010e-02, -1.171885e-02, -5.577739e-03,\n -3.750417e-02, 2.511203e-02, 4.287542e-02, 6.856018e-03,...\n y: array([-5.717582e-02, 2.493422e-03, -3.901535e-03, 1.850889e-02,\n 5.135028e-02, 2.320276e-02, 6.400365e-03, -1.590035e-02,\n -2.774814e-02, 4.663043e-02, 6.030675e-02, 2.010986e-02,..." + ] + } + ], + "source": [ + "for i in range(len(input_sentences)):\n", + " print(i)\n", + " print(np.testing.assert_allclose(embeddings[i], embedding_output_torch['inference_results'][i]['output'][0]['data'], rtol=1e-03, atol=1e-05))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "386cabd4", + "metadata": {}, + "outputs": [], + "source": [ + "loaded_model = torch.jit.load(\"traced_sentence_transformer.pt\")\n", + "loaded_model.eval()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b75aa69f", + "metadata": {}, + "outputs": [], + "source": [ + "POST /_plugins/_ml/_predict/text_embedding/{model_id}\n", + "{\n", + " \"text_docs\": [[question A, answer A], [question B, answer B], [question C, answer C]],\n", + " \"return_number\": true,\n", + " \"target_response\": [ \"sentence_embedding\" ]\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "eb01e136", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ml_client._client" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "91a53b42", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[['3D ActionSLAM: wearable person tracking in multi-floor environments',\n", + " 'Represent the Science title:']]" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "input_sentences" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "a2bc7da0", + "metadata": {}, + "outputs": [], + "source": [ + "API_BODY = {\"text_docs\": input_sentences}" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "71aa90cf", + "metadata": {}, + "outputs": [ + { + "ename": "TransportError", + "evalue": "TransportError(500, 'null_pointer_exception', 'Cannot invoke \"org.opensearch.ml.common.input.MLInput.setAlgorithm(org.opensearch.ml.common.FunctionName)\" because \"mlInput\" is null')", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTransportError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[24], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mml_client\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_client\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtransport\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mperform_request\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2\u001b[0m \u001b[43m \u001b[49m\u001b[43mmethod\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mPOST\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3\u001b[0m \u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43mf\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m/_plugins/_ml/_predict/text_embedding/_cu7cYoBSA5PdoWsvUMh\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4\u001b[0m \u001b[43m \u001b[49m\u001b[43mbody\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mAPI_BODY\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 5\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/home/linuxbrew/.linuxbrew/opt/python@3.8/lib/python3.8/site-packages/opensearchpy/transport.py:409\u001b[0m, in \u001b[0;36mTransport.perform_request\u001b[0;34m(self, method, url, headers, params, body)\u001b[0m\n\u001b[1;32m 407\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m e\n\u001b[1;32m 408\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 409\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m e\n\u001b[1;32m 411\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 412\u001b[0m \u001b[38;5;66;03m# connection didn't fail, confirm it's live status\u001b[39;00m\n\u001b[1;32m 413\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconnection_pool\u001b[38;5;241m.\u001b[39mmark_live(connection)\n", + "File \u001b[0;32m/home/linuxbrew/.linuxbrew/opt/python@3.8/lib/python3.8/site-packages/opensearchpy/transport.py:370\u001b[0m, in \u001b[0;36mTransport.perform_request\u001b[0;34m(self, method, url, headers, params, body)\u001b[0m\n\u001b[1;32m 367\u001b[0m connection \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_connection()\n\u001b[1;32m 369\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 370\u001b[0m status, headers_response, data \u001b[38;5;241m=\u001b[39m \u001b[43mconnection\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mperform_request\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 371\u001b[0m \u001b[43m \u001b[49m\u001b[43mmethod\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 372\u001b[0m \u001b[43m \u001b[49m\u001b[43murl\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 373\u001b[0m \u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 374\u001b[0m \u001b[43m \u001b[49m\u001b[43mbody\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 375\u001b[0m \u001b[43m \u001b[49m\u001b[43mheaders\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mheaders\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 376\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 377\u001b[0m \u001b[43m \u001b[49m\u001b[43mtimeout\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtimeout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 378\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 380\u001b[0m \u001b[38;5;66;03m# Lowercase all the header names for consistency in accessing them.\u001b[39;00m\n\u001b[1;32m 381\u001b[0m headers_response \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 382\u001b[0m header\u001b[38;5;241m.\u001b[39mlower(): value \u001b[38;5;28;01mfor\u001b[39;00m header, value \u001b[38;5;129;01min\u001b[39;00m headers_response\u001b[38;5;241m.\u001b[39mitems()\n\u001b[1;32m 383\u001b[0m }\n", + "File \u001b[0;32m/home/linuxbrew/.linuxbrew/opt/python@3.8/lib/python3.8/site-packages/opensearchpy/connection/http_urllib3.py:266\u001b[0m, in \u001b[0;36mUrllib3HttpConnection.perform_request\u001b[0;34m(self, method, url, params, body, timeout, ignore, headers)\u001b[0m\n\u001b[1;32m 262\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;241m200\u001b[39m \u001b[38;5;241m<\u001b[39m\u001b[38;5;241m=\u001b[39m response\u001b[38;5;241m.\u001b[39mstatus \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m300\u001b[39m) \u001b[38;5;129;01mand\u001b[39;00m response\u001b[38;5;241m.\u001b[39mstatus \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m ignore:\n\u001b[1;32m 263\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlog_request_fail(\n\u001b[1;32m 264\u001b[0m method, full_url, url, orig_body, duration, response\u001b[38;5;241m.\u001b[39mstatus, raw_data\n\u001b[1;32m 265\u001b[0m )\n\u001b[0;32m--> 266\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_raise_error\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 267\u001b[0m \u001b[43m \u001b[49m\u001b[43mresponse\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstatus\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 268\u001b[0m \u001b[43m \u001b[49m\u001b[43mraw_data\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 269\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_response_headers\u001b[49m\u001b[43m(\u001b[49m\u001b[43mresponse\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mcontent-type\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 270\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 272\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlog_request_success(\n\u001b[1;32m 273\u001b[0m method, full_url, url, orig_body, response\u001b[38;5;241m.\u001b[39mstatus, raw_data, duration\n\u001b[1;32m 274\u001b[0m )\n\u001b[1;32m 276\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m response\u001b[38;5;241m.\u001b[39mstatus, response\u001b[38;5;241m.\u001b[39mheaders, raw_data\n", + "File \u001b[0;32m/home/linuxbrew/.linuxbrew/opt/python@3.8/lib/python3.8/site-packages/opensearchpy/connection/base.py:301\u001b[0m, in \u001b[0;36mConnection._raise_error\u001b[0;34m(self, status_code, raw_data, content_type)\u001b[0m\n\u001b[1;32m 298\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (\u001b[38;5;167;01mValueError\u001b[39;00m, \u001b[38;5;167;01mTypeError\u001b[39;00m) \u001b[38;5;28;01mas\u001b[39;00m err:\n\u001b[1;32m 299\u001b[0m logger\u001b[38;5;241m.\u001b[39mwarning(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUndecodable raw error response from server: \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m\"\u001b[39m, err)\n\u001b[0;32m--> 301\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m HTTP_EXCEPTIONS\u001b[38;5;241m.\u001b[39mget(status_code, TransportError)(\n\u001b[1;32m 302\u001b[0m status_code, error_message, additional_info\n\u001b[1;32m 303\u001b[0m )\n", + "\u001b[0;31mTransportError\u001b[0m: TransportError(500, 'null_pointer_exception', 'Cannot invoke \"org.opensearch.ml.common.input.MLInput.setAlgorithm(org.opensearch.ml.common.FunctionName)\" because \"mlInput\" is null')" + ] + } + ], + "source": [ + "ml_client._client.transport.perform_request(\n", + " method=\"POST\",\n", + " url=f\"/_plugins/_ml/_predict/text_embedding/_cu7cYoBSA5PdoWsvUMh\",\n", + " body=API_BODY,\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "a0deb719", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "SentenceTransformer(\n", + " (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: T5EncoderModel \n", + " (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})\n", + " (2): Dense({'in_features': 1024, 'out_features': 768, 'bias': False, 'activation_function': 'torch.nn.modules.linear.Identity'})\n", + " (3): Normalize()\n", + ")" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "0282832e", + "metadata": {}, + "outputs": [], + "source": [ + "hg_model = SentenceTransformer(model_id, cache_folder=\"cache_folder\")" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "id": "58e8d133", + "metadata": {}, + "outputs": [], + "source": [ + "hg_embed = hg_model.encode(input_sentences)" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "7e5cebea", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[['3D ActionSLAM: wearable person tracking in multi-floor environments',\n", + " 'Represent the Science title:']]" + ] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "input_sentences" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "id": "519c7e5c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0\n" + ] + }, + { + "ename": "AssertionError", + "evalue": "\nNot equal to tolerance rtol=0.001, atol=1e-05\n\nMismatched elements: 765 / 768 (99.6%)\nMax absolute difference: 0.02885431\nMax relative difference: 285.67206\n x: array([[-6.155526e-02, 1.041999e-02, 5.884408e-03, 1.937688e-02,\n 5.714178e-02, 2.576557e-02, -4.018488e-05, -2.800445e-02,\n -2.929655e-02, 4.918848e-02, 6.782001e-02, 2.186925e-02,...\n y: array([[-6.028877e-02, 2.991482e-03, -4.056946e-04, 2.204840e-02,\n 5.370862e-02, 1.982622e-02, -3.396538e-04, -1.533841e-02,\n -3.740350e-02, 4.192707e-02, 6.052603e-02, 1.625427e-02,...", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[44], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mlen\u001b[39m(input_sentences)):\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28mprint\u001b[39m(i)\n\u001b[0;32m----> 3\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtesting\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43massert_allclose\u001b[49m\u001b[43m(\u001b[49m\u001b[43membeddings\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhg_embed\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrtol\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1e-03\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43matol\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1e-05\u001b[39;49m\u001b[43m)\u001b[49m)\n", + " \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n", + "File \u001b[0;32m/home/linuxbrew/.linuxbrew/opt/python@3.8/lib/python3.8/contextlib.py:75\u001b[0m, in \u001b[0;36mContextDecorator.__call__..inner\u001b[0;34m(*args, **kwds)\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[38;5;129m@wraps\u001b[39m(func)\n\u001b[1;32m 73\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21minner\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwds):\n\u001b[1;32m 74\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_recreate_cm():\n\u001b[0;32m---> 75\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwds\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m/home/linuxbrew/.linuxbrew/opt/python@3.8/lib/python3.8/site-packages/numpy/testing/_private/utils.py:862\u001b[0m, in \u001b[0;36massert_array_compare\u001b[0;34m(comparison, x, y, err_msg, verbose, header, precision, equal_nan, equal_inf, strict)\u001b[0m\n\u001b[1;32m 858\u001b[0m err_msg \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m'\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(remarks)\n\u001b[1;32m 859\u001b[0m msg \u001b[38;5;241m=\u001b[39m build_err_msg([ox, oy], err_msg,\n\u001b[1;32m 860\u001b[0m verbose\u001b[38;5;241m=\u001b[39mverbose, header\u001b[38;5;241m=\u001b[39mheader,\n\u001b[1;32m 861\u001b[0m names\u001b[38;5;241m=\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mx\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124my\u001b[39m\u001b[38;5;124m'\u001b[39m), precision\u001b[38;5;241m=\u001b[39mprecision)\n\u001b[0;32m--> 862\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAssertionError\u001b[39;00m(msg)\n\u001b[1;32m 863\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m:\n\u001b[1;32m 864\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtraceback\u001b[39;00m\n", + "\u001b[0;31mAssertionError\u001b[0m: \nNot equal to tolerance rtol=0.001, atol=1e-05\n\nMismatched elements: 765 / 768 (99.6%)\nMax absolute difference: 0.02885431\nMax relative difference: 285.67206\n x: array([[-6.155526e-02, 1.041999e-02, 5.884408e-03, 1.937688e-02,\n 5.714178e-02, 2.576557e-02, -4.018488e-05, -2.800445e-02,\n -2.929655e-02, 4.918848e-02, 6.782001e-02, 2.186925e-02,...\n y: array([[-6.028877e-02, 2.991482e-03, -4.056946e-04, 2.204840e-02,\n 5.370862e-02, 1.982622e-02, -3.396538e-04, -1.533841e-02,\n -3.740350e-02, 4.192707e-02, 6.052603e-02, 1.625427e-02,..." + ] + } + ], + "source": [ "for i in range(len(input_sentences)):\n", " print(i)\n", - " print(np.testing.assert_allclose(original_embedding_data[i], embedding_output_torch['inference_results'][i]['output'][0]['data'], rtol=1e-03, atol=1e-05))" + " print(np.testing.assert_allclose(embeddings, hg_embed, rtol=1e-03, atol=1e-05))" ] } ],