Skip to content

Commit

Permalink
support Numo::NArray
Browse files Browse the repository at this point in the history
`narray` is no longer under development. So added a option for a newer version  `numo-narray` can be used.
  • Loading branch information
yagince committed Jan 10, 2023
1 parent ebe8431 commit c4438a9
Show file tree
Hide file tree
Showing 7 changed files with 32 additions and 9 deletions.
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ env:
- MATRIX_LIBRARY=narray
- MATRIX_LIBRARY=nmatrix
- MATRIX_LIBRARY=matrix
- MATRIX_LIBRARY=numo
addons:
apt:
packages:
Expand Down
1 change: 1 addition & 0 deletions Gemfile
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ source 'https://rubygems.org'
gem 'rb-gsl', '~> 1.16.0.2' if ENV['MATRIX_LIBRARY'] == 'gsl'
gem 'narray', '~> 0.6.0.0' if ENV['MATRIX_LIBRARY'] == 'narray'
gem 'nmatrix', '~> 0.2' if ENV['MATRIX_LIBRARY'] == 'nmatrix'
gem 'numo-narray', '~> 0.9.2.1' if ENV['MATRIX_LIBRARY'] == 'numo'

# Specify your gem's dependencies in the gemspec
gemspec
24 changes: 18 additions & 6 deletions lib/tf-idf-similarity/matrix_methods.rb
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ def normalize
norm = NMath.sqrt((@matrix ** 2).sum(1).reshape(@matrix.shape[0], 1))
norm[norm.where2[1]] = 1.0 # avoid division by zero
NMatrix.refer(@matrix / norm) # must be NMatrix for matrix multiplication
when :numo
norm = Numo::NMath.sqrt((@matrix ** 2).sum(0).reshape(1, @matrix.shape[0]))
norm[(norm.eq 0).where] = 1.0 # avoid division by zero
(@matrix / norm)
when :nmatrix # @see https://github.com/SciRuby/nmatrix/issues/38
normal = NMatrix.new(:dense, @matrix.shape, 0, :float64)
(0...@matrix.shape[1]).each do |j|
Expand Down Expand Up @@ -44,7 +48,7 @@ def normalize
# @param [Integer] column index
def get(i, j)
case @library
when :narray
when :narray, :numo
@matrix[j, i]
else
@matrix[i, j]
Expand All @@ -57,6 +61,8 @@ def row(index)
case @library
when :narray
@matrix[true, index]
when :numo
@matrix[index, true]
else
@matrix.row(index)
end
Expand All @@ -66,7 +72,7 @@ def row(index)
# @return [GSL::Vector::View,NArray,NMatrix,Vector] a column
def column(index)
case @library
when :narray
when :narray, :numo
@matrix[index, true]
else
@matrix.column(index)
Expand All @@ -78,7 +84,7 @@ def row_size
case @library
when :gsl, :nmatrix
@matrix.shape[0]
when :narray
when :narray, :numo
@matrix.shape[1]
else
@matrix.row_size
Expand All @@ -90,7 +96,7 @@ def column_size
case @library
when :gsl, :nmatrix
@matrix.shape[1]
when :narray
when :narray, :numo
@matrix.shape[0]
else
@matrix.column_size
Expand All @@ -110,7 +116,7 @@ def values
# @return [Float] the sum of all values in the matrix
def sum
case @library
when :narray
when :narray, :numo
@matrix.sum
else
values.reduce(0, :+)
Expand All @@ -125,6 +131,8 @@ def initialize_matrix(array)
GSL::Matrix[*array]
when :narray
NArray[*array]
when :numo
Numo::DFloat[*array]
when :nmatrix # @see https://github.com/SciRuby/nmatrix/issues/91#issuecomment-18870619
NMatrix.new(:dense, [array.size, array.empty? ? 0 : array[0].size], array.flatten, :float64)
else
Expand All @@ -136,7 +144,7 @@ def initialize_matrix(array)
# @return [GSL::Matrix,NArray,NMatrix,Matrix] the product
def multiply_self(matrix)
case @library
when :nmatrix
when :nmatrix, :numo
matrix.transpose.dot(matrix)
else
matrix.transpose * matrix
Expand All @@ -149,6 +157,8 @@ def log(number)
GSL::Sf::log(number)
when :narray
NMath.log(number)
when :numo
Numo::NMath.log(number)
else
Math.log(number)
end
Expand All @@ -158,6 +168,8 @@ def sqrt(number)
case @library
when :narray
NMath.sqrt(number)
when :numo
Numo::NMath.sqrt(number)
else
Math.sqrt(number)
end
Expand Down
2 changes: 1 addition & 1 deletion lib/tf-idf-similarity/model.rb
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def initialize(documents, opts = {})
array = Array.new(terms.size) do |i|
idf = inverse_document_frequency(terms[i])
Array.new(documents.size) do |j|
term_frequency(documents[j], terms[i]) * idf
(term_frequency(documents[j], terms[i]) * idf).to_f
end
end

Expand Down
4 changes: 3 additions & 1 deletion lib/tf-idf-similarity/term_count_model.rb
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ def document_count(term)
case @library
when :gsl, :narray
row(index).where.size
when :numo
(row(index).ne 0).where.size
when :nmatrix
row(index).each.count(&:nonzero?)
else
Expand All @@ -57,7 +59,7 @@ def term_count(term)
index = terms.index(term)
if index
case @library
when :gsl, :narray
when :gsl, :narray, :numo
row(index).sum
when :nmatrix
row(index).each.reduce(0, :+) # NMatrix's `sum` method is slower
Expand Down
7 changes: 6 additions & 1 deletion spec/bm25_model_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,12 @@ def similarity_matrix_values(model)

describe '#term_frequency_inverse_document_frequency' do
it 'should return negative infinity' do
model.tfidf(document, 'foo').should be_nan
case MATRIX_LIBRARY
when :numo
model.tfidf(document, 'foo').isnan.should eq 1
else
model.tfidf(document, 'foo').should be_nan
end
end
end

Expand Down
2 changes: 2 additions & 0 deletions spec/spec_helper.rb
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
require 'gsl'
when :narray
require 'narray'
when :numo
require 'numo/narray'
when :nmatrix
require 'nmatrix'
else
Expand Down

0 comments on commit c4438a9

Please sign in to comment.