-
Notifications
You must be signed in to change notification settings - Fork 53
/
Copy pathTripletEmbedding.lua
37 lines (31 loc) · 1.43 KB
/
TripletEmbedding.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
--------------------------------------------------------------------------------
-- TripletEmbeddingCriterion
--------------------------------------------------------------------------------
-- Alfredo Canziani, Apr/May 15
--------------------------------------------------------------------------------
-- Register the class 'nn.TripletEmbeddingCriterion' with 'nn.Criterion' as its
-- parent. torch.class returns the new class table and the parent class table;
-- `parent` is used in __init to call parent.__init(self).
local TripletEmbeddingCriterion, parent = torch.class('nn.TripletEmbeddingCriterion', 'nn.Criterion')
--- Build a new criterion instance.
-- @param alpha  optional margin added to the distance difference in
--               updateOutput; 0.2 is used when it is omitted
function TripletEmbeddingCriterion:__init(alpha)
   parent.__init(self)

   -- One gradient slot per entry of the input table; filled by updateGradInput.
   self.gradInput = {}
   -- Per-sample loss values; reassigned on every updateOutput call.
   self.Li = torch.Tensor()

   if not alpha then
      alpha = 0.2
   end
   self.alpha = alpha
end
--- Compute the batch-averaged triplet loss.
-- For every row i:
--    Li = max(0, ||a_i - p_i||^2 - ||a_i - n_i||^2 + alpha)
-- The per-row values are stored in self.Li (one column, N rows) so that
-- updateGradInput can tell which rows are active.
-- @param input  table {anchor, positive, negative}; the three tensors are
--               subtracted element-wise and normed along dimension 2, so they
--               are expected to be 2D with one sample per row
-- @return the sum of Li divided by N = a:size(1); also stored in self.output
function TripletEmbeddingCriterion:updateOutput(input)
   local a = input[1] -- anchor
   local p = input[2] -- positive
   local n = input[3] -- negative
   local N = a:size(1)

   -- Squared Euclidean distances, one value per row (N x 1).
   local distPos = (a - p):norm(2, 2):pow(2)
   local distNeg = (a - n):norm(2, 2):pow(2)
   local margin = distPos - distNeg + self.alpha

   -- Zero column allocated directly with the same tensor type as `a`,
   -- instead of creating a default-type tensor and converting it.
   local zeros = a.new(N, 1):zero()

   -- Row-wise max over the two columns [0, margin]. torch.max with a
   -- dimension returns (values, indices); only the values are kept.
   self.Li = torch.max(torch.cat(zeros, margin, 2), 2)
   self.output = self.Li:sum() / N
   return self.output
end
--- Compute the gradient of the loss with respect to each triplet member.
-- Relies on self.Li as left by the most recent updateOutput call: rows whose
-- Li is not greater than 0 get a zero gradient.
-- For the remaining rows:
--    d/da = 2/N * (n - p),  d/dp = 2/N * (p - a),  d/dn = 2/N * (a - n)
-- @param input  table {anchor, positive, negative}, same layout as in
--               updateOutput
-- @return self.gradInput, a table holding the three gradient tensors in the
--         same order as `input`
function TripletEmbeddingCriterion:updateGradInput(input)
   local a = input[1] -- anchor
   local p = input[2] -- positive
   local n = input[3] -- negative
   local N = a:size(1)

   -- Active-row mask replicated across a:size(2) columns, converted to the
   -- input tensor type and scaled by 2/N. It is the same for all three
   -- gradients, so it is built once and only read afterwards.
   local scale = self.Li:gt(0):repeatTensor(1, a:size(2)):type(a:type()) * 2 / N

   -- Each difference is a fresh tensor, so the in-place cmul never touches
   -- the inputs or the shared scale.
   self.gradInput[1] = (n - p):cmul(scale)
   self.gradInput[2] = (p - a):cmul(scale)
   self.gradInput[3] = (a - n):cmul(scale)
   return self.gradInput
end