Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 64 additions & 4 deletions uwr_related/embeddings_abx.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,23 @@
)
logger = logging.getLogger("zerospeech2021 abx")

class CPCModelNullspace(nn.Module):
    """Wrap a CPC model and project both of its output streams onto a nullspace.

    The projection is a bias-free linear layer whose weight is the transposed
    nullspace basis supplied at construction time.
    """

    def __init__(self,
                 cpc,
                 nullspace):
        """Store the wrapped model and build the projection layer.

        Args:
            cpc: the CPC model to wrap; called as ``cpc(batchData, label)``
                and expected to return ``(cFeature, encodedData, label)``.
            nullspace: 2-D tensor of shape (feature_dim, nullspace_dim)
                whose columns span the target nullspace.
        """
        super().__init__()
        self.cpc = cpc
        dim_in, dim_out = nullspace.shape[0], nullspace.shape[1]
        projection = nn.Linear(dim_in, dim_out, bias=False)
        # nn.Linear stores its weight as (out_features, in_features),
        # hence the transpose of the basis matrix.
        projection.weight = nn.Parameter(nullspace.T)
        self.nullspace = projection

    def forward(self, batchData, label):
        # Run the wrapped CPC model, then project both feature streams.
        cFeature, encodedData, label = self.cpc(batchData, label)
        return self.nullspace(cFeature), self.nullspace(encodedData), label

def parse_args():
# Run parameters
parser = argparse.ArgumentParser()
Expand All @@ -48,6 +65,13 @@ def parse_args():
help="Pre-trained model architecture ('wav2vec2' [default] or 'cpc').")
parser.add_argument("--path_cpc", type=str, default="/pio/scratch/1/i273233/cpc",
help="Path to the root of cpc repo.")
parser.add_argument("--path_speakers_factorized", type=str, default=None,
help="Path to the factorized matrices. We want to project embeddings on the nullspace")
parser.add_argument("--no_test", action="store_true",
help="Don't compute embeddings for test-* parts of dataset")
parser.add_argument('--gru_level', type=int, default=-1,
help='Hidden level of the LSTM autoregressive model to be taken'
'(default: -1, last layer).')
return parser.parse_args()

def main():
Expand All @@ -61,7 +85,34 @@ def main():
if args.model == "cpc":
sys.path.append(os.path.abspath(args.path_cpc))
from cpc.feature_loader import loadModel, FeatureModule
model = loadModel([args.path_checkpoint])[0]

if args.gru_level is not None and args.gru_level > 0:
updateConfig = argparse.Namespace(nLevelsGRU=args.gru_level)
else:
updateConfig = None

model = loadModel([args.path_checkpoint], updateConfig=updateConfig)[0]

if args.gru_level is not None and args.gru_level > 0:
# Keep hidden units at LSTM layers on sequential batches
model.gAR.keepHidden = True

if args.path_speakers_factorized is not None:
def my_nullspace(At, rcond=None):
    """Return an orthonormal basis of the null space of ``At``.

    Args:
        At: 2-D tensor whose null space is computed.
        rcond: relative cutoff for small singular values; defaults to
            ``eps * max(M, N)`` (the same default numpy uses for rank
            determination).

    Returns:
        Tensor of shape (At.shape[1], k) on CPU whose columns span the
        null space of ``At``.
    """
    # torch.linalg.svd returns V^H directly (torch.Tensor.svd is deprecated).
    ut, st, vht = torch.linalg.svd(At, full_matrices=True)
    Mt, Nt = ut.shape[0], vht.shape[1]
    if rcond is None:
        rcond = torch.finfo(st.dtype).eps * max(Mt, Nt)
    # BUGFIX: the original only defined the tolerance factor inside the
    # ``rcond is None`` branch, so passing an explicit rcond raised NameError.
    tolt = torch.max(st) * rcond
    # Rank = number of singular values above tolerance; the remaining
    # right-singular vectors span the null space.
    numt = torch.sum(st > tolt, dtype=int)
    nullspace = vht[numt:, :].T.cpu().conj()
    return nullspace

first_matrix = torch.load(args.path_speakers_factorized)["cpcCriterion"]["linearSpeakerClassifier.0.weight"]
nullspace = my_nullspace(first_matrix)
model = CPCModelNullspace(model, nullspace)
else:
from fairseq import checkpoint_utils

Expand Down Expand Up @@ -102,7 +153,10 @@ def hook(model, input, output):
if args.model == "cpc":
layer_name = os.path.basename(os.path.dirname(args.path_checkpoint))
layer_names.append(layer_name)
model.gAR.register_forward_hook(get_layer_output(layer_name))
if args.path_speakers_factorized is None:
model.gAR.register_forward_hook(get_layer_output(layer_name))
else:
model.nullspace.register_forward_hook(get_layer_output(layer_name))
else:
for i in range(len(model.encoder.layers)):
layer_name = "layer_{}".format(i)
Expand All @@ -120,6 +174,7 @@ def hook(model, input, output):
phonetic = "phonetic"
datasets_path = os.path.join(args.path_data, phonetic)
datasets = os.listdir(datasets_path)
datasets = [dataset for dataset in datasets if not args.no_test or not dataset.startswith("test")]
print(datasets)

with torch.no_grad():
Expand All @@ -134,8 +189,13 @@ def hook(model, input, output):
x = torch.tensor(x).float().reshape(1,-1).to(device)

if args.model == "cpc":
encodedData = model.gEncoder(x.unsqueeze(1)).permute(0, 2, 1)
output = model.gAR(encodedData)
if args.path_speakers_factorized is None:
encodedData = model.gEncoder(x.unsqueeze(1)).permute(0, 2, 1)
output = model.gAR(encodedData)
else:
encodedData = model.cpc.gEncoder(x.unsqueeze(1)).permute(0, 2, 1)
output = model.cpc.gAR(encodedData)
output = model.nullspace(output)
else:
output = model(x, features_only=True)["x"]

Expand Down
43 changes: 36 additions & 7 deletions uwr_related/eval_abx.sh
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
########## CHANGE THIS ##################
ZEROSPEECH_EVAL_ENV=zerospeech2021 # Where the zerospeech2021-evaluate is installed
CPC_ENV=cpc
FAIRSEQ_ENV=202010-fairseq
#CPC_ENV=cpc
#FAIRSEQ_ENV=202010-fairseq
CPC_ENV=202010-fairseq-c11
FAIRSEQ_ENV=202010-fairseq-c11
CONDA_PATH=/pio/scratch/2/i273233/miniconda3
FAIRSEQ_PATH=/pio/scratch/1/i273233/fairseq
CPC_PATH=/pio/scratch/1/i273233/cpc
#########################################

DATASET_PATH=$1
CHECKPOINT_PATH=$2
OUTPUT_DIR=$3
MODEL_KIND=$4 # Either "wav2vec2" or "cpc"
ORIGINAL_DATASET_PATH=$2
CHECKPOINT_PATH=$3
OUTPUT_DIR=$4
MODEL_KIND=$5 # Either "wav2vec2" or "cpc"

case $MODEL_KIND in
wav2vec2|cpc)
Expand All @@ -21,6 +24,11 @@ case $MODEL_KIND in
;;
esac

# Optional 6th positional argument: path to the factorized speaker matrices.
# When set, embeddings are projected onto the speaker nullspace downstream.
SPEAKERS_FACTORIZED_PATH=""
if [[ $# -ge 6 ]]; then
SPEAKERS_FACTORIZED_PATH=$6
fi

results=$OUTPUT_DIR/results
embeddings=$OUTPUT_DIR/embeddings
mkdir -p embeddings
Expand All @@ -36,8 +44,27 @@ fi
conda activate $ENV_TO_ACTIVATE

echo "$FAIRSEQ_PATH/uwr_related/embeddings_abx.py"
python $FAIRSEQ_PATH/uwr_related/embeddings_abx.py $CHECKPOINT_PATH $DATASET_PATH $embeddings --model $MODEL_KIND --path_cpc $CPC_PATH
# Extract embeddings; add the --path_speakers_factorized flag only when the
# optional factorized-speaker-matrix path was supplied.
if [[ $SPEAKERS_FACTORIZED_PATH == "" ]]
then
python $FAIRSEQ_PATH/uwr_related/embeddings_abx.py $CHECKPOINT_PATH $DATASET_PATH $embeddings --model $MODEL_KIND --path_cpc $CPC_PATH --gru_level 2
else
python $FAIRSEQ_PATH/uwr_related/embeddings_abx.py $CHECKPOINT_PATH $DATASET_PATH $embeddings --model $MODEL_KIND --path_cpc $CPC_PATH --path_speakers_factorized $SPEAKERS_FACTORIZED_PATH --gru_level 2
fi

# Delete any embedding file whose source .wav is absent from the original
# dataset, so the evaluation only sees items present in ORIGINAL_DATASET_PATH.
for i in `basename -a $(ls -d $embeddings/*/)`
do
for directory in dev-clean dev-other test-clean test-other
do
for file in `ls $embeddings/$i/phonetic/$directory`
do
# Strip the extension to compare against the corresponding .wav name.
filename_no_ext="${file%.*}"
if [[ ! -f "$ORIGINAL_DATASET_PATH/phonetic/$directory/${filename_no_ext}.wav" ]]
then
rm $embeddings/$i/phonetic/$directory/$file
fi
done
done
done

conda activate $ZEROSPEECH_EVAL_ENV

Expand Down Expand Up @@ -68,7 +95,9 @@ EOF
for i in `basename -a $(ls -d $embeddings/*/)`
do
cp $embeddings/$metric.yaml $embeddings/$i/meta.yaml
zerospeech2021-evaluate -j 12 -o $results/$metric/$i --no-lexical --no-syntactic --no-semantic $DATASET_PATH $embeddings/$i
#zerospeech2021-evaluate -j 12 -o $results/$metric/$i --no-lexical --no-syntactic --no-semantic $DATASET_PATH $embeddings/$i
#zerospeech2021-evaluate -j 20 -o $results/$metric/$i --force-cpu --no-lexical --no-syntactic --no-semantic $DATASET_PATH $embeddings/$i
zerospeech2021-evaluate -j 20 -o $results/$metric/$i --no-lexical --no-syntactic --no-semantic $ORIGINAL_DATASET_PATH $embeddings/$i
done
done

Expand Down