Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lib/classifier.rb
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,4 @@
require 'classifier/knn'
require 'classifier/tfidf'
require 'classifier/logistic_regression'
require 'classifier/config'
19 changes: 12 additions & 7 deletions lib/classifier/bayes.rb
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,17 @@ class Bayes # rubocop:disable Metrics/ClassLength
# @rbs @cached_vocab_size: Integer?
# @rbs @dirty: bool
# @rbs @storage: Storage::Base?
# @rbs @min_word_length: Integer

attr_accessor :storage

# The class can be created with one or more categories, each of which will be
# initialized and given a training method. E.g.,
# b = Classifier::Bayes.new 'Interesting', 'Uninteresting', 'Spam'
# b = Classifier::Bayes.new ['Interesting', 'Uninteresting', 'Spam']
# @rbs (*String | Symbol | Array[String | Symbol]) -> void
def initialize(*categories)
# b = Classifier::Bayes.new 'Spam', min_word_length: 1
# @rbs (*String | Symbol | Array[String | Symbol], ?min_word_length: Integer) -> void
def initialize(*categories, min_word_length: Classifier.config.min_word_length)
super()
@categories = {}
categories.flatten.each { |category| @categories[category.prepare_category_name] = {} }
Expand All @@ -39,6 +41,7 @@ def initialize(*categories)
@cached_vocab_size = nil
@dirty = false
@storage = nil
@min_word_length = min_word_length
end

# Trains the classifier with text for a category.
Expand Down Expand Up @@ -76,7 +79,7 @@ def untrain(category = nil, text = nil, **categories)
#
# @rbs (String) -> Hash[String, Float]
def classifications(text)
words = text.word_hash.keys
words = text.word_hash(@min_word_length).keys
synchronize do
training_count = cached_training_count
vocab_size = cached_vocab_size
Expand Down Expand Up @@ -117,7 +120,8 @@ def as_json(_options = nil)
categories: @categories.transform_keys(&:to_s).transform_values { |v| v.transform_keys(&:to_s) },
total_words: @total_words,
category_counts: @category_counts.transform_keys(&:to_s),
category_word_count: @category_word_count.transform_keys(&:to_s)
category_word_count: @category_word_count.transform_keys(&:to_s),
min_word_length: @min_word_length
}
end

Expand Down Expand Up @@ -409,7 +413,7 @@ def train_batch_internal(category, batch)
invalidate_caches
@dirty = true
batch.each do |text|
word_hash = text.word_hash
word_hash = text.word_hash(@min_word_length)
@category_counts[category] += 1
word_hash.each do |word, count|
@categories[category][word] ||= 0
Expand All @@ -425,7 +429,7 @@ def train_batch_internal(category, batch)
# @rbs (String | Symbol, String) -> void
def train_single(category, text)
category = category.prepare_category_name
word_hash = text.word_hash
word_hash = text.word_hash(@min_word_length)
synchronize do
invalidate_caches
@dirty = true
Expand All @@ -443,7 +447,7 @@ def train_single(category, text)
# @rbs (String | Symbol, String) -> void
def untrain_single(category, text)
category = category.prepare_category_name
word_hash = text.word_hash
word_hash = text.word_hash(@min_word_length)
synchronize do
invalidate_caches
@dirty = true
Expand Down Expand Up @@ -487,6 +491,7 @@ def restore_state(data)
@cached_vocab_size = nil
@dirty = false
@storage = nil
@min_word_length = data['min_word_length'] || Classifier.config.min_word_length

data['categories'].each do |cat_name, words|
@categories[cat_name.to_sym] = words.transform_keys(&:to_sym)
Expand Down
31 changes: 31 additions & 0 deletions lib/classifier/config.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# rbs_inline: enabled

module Classifier
# @rbs @config: Config?

# This lazy initialization is not thread-safe.
# In multi-threaded environments, ensure this method is called
# or configuration is set explicitly during startup before using classifiers.
# @rbs () -> Config
def config
@config ||= Config.new
Comment thread
Yegorov marked this conversation as resolved.
end

# @rbs () { (Config) -> void } -> void
def configure(&block)
block&.call(config)
end

module_function :config, :configure

class Config
# @rbs @min_word_length: Integer

attr_accessor :min_word_length #: Integer

# @rbs () -> void
def initialize
@min_word_length = 3
end
end
end
18 changes: 9 additions & 9 deletions lib/classifier/extensions/word_hash.rb
Original file line number Diff line number Diff line change
Expand Up @@ -20,27 +20,27 @@ def without_punctuation

# Return a Hash of strings => ints. Each word in the string is stemmed,
# interned, and indexes to its frequency in the document.
# @rbs () -> Hash[Symbol, Integer]
def word_hash
word_hash = clean_word_hash
# @rbs (?Integer) -> Hash[Symbol, Integer]
def word_hash(min_word_length = 3)
word_hash = clean_word_hash(min_word_length)
symbol_hash = word_hash_for_symbols(gsub(/\w/, ' ').split)
word_hash.merge(symbol_hash)
end

# Return a word hash without extra punctuation or short symbols, just stemmed words
# @rbs () -> Hash[Symbol, Integer]
def clean_word_hash
word_hash_for_words gsub(/[^\w\s]/, '').split
# @rbs (?Integer) -> Hash[Symbol, Integer]
def clean_word_hash(min_word_length = 3)
word_hash_for_words(gsub(/[^\w\s]/, '').split, min_word_length)
end

private

# @rbs (Array[String]) -> Hash[Symbol, Integer]
def word_hash_for_words(words)
# @rbs (Array[String], Integer) -> Hash[Symbol, Integer]
def word_hash_for_words(words, min_word_length)
d = Hash.new(0)
words.each do |word|
word.downcase!
d[word.stem.intern] += 1 if !CORPUS_SKIP_WORDS.include?(word) && word.length > 2
d[word.stem.intern] += 1 if !CORPUS_SKIP_WORDS.include?(word) && word.length >= min_word_length
end
d
end
Expand Down
28 changes: 18 additions & 10 deletions lib/classifier/logistic_regression.rb
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class LogisticRegression # rubocop:disable Metrics/ClassLength
# @rbs @fitted: bool
# @rbs @dirty: bool
# @rbs @storage: Storage::Base?
# @rbs @min_word_length: Integer

attr_accessor :storage

Expand All @@ -53,13 +54,16 @@ class LogisticRegression # rubocop:disable Metrics/ClassLength
# - regularization: L2 regularization strength (default: 0.01)
# - max_iterations: Maximum training iterations (default: 100)
# - tolerance: Convergence threshold (default: 1e-4)
# - min_word_length: Minimum word length filter in tokenization
#
# @rbs (*String | Symbol | Array[String | Symbol], ?learning_rate: Float, ?regularization: Float,
# ?max_iterations: Integer, ?tolerance: Float) -> void
# ?max_iterations: Integer, ?tolerance: Float, ?min_word_length: Integer) -> void
# rubocop:disable Metrics/ParameterLists
def initialize(*categories, learning_rate: DEFAULT_LEARNING_RATE,
regularization: DEFAULT_REGULARIZATION,
max_iterations: DEFAULT_MAX_ITERATIONS,
tolerance: DEFAULT_TOLERANCE)
tolerance: DEFAULT_TOLERANCE,
min_word_length: Classifier.config.min_word_length)
super()
categories = categories.flatten
@categories = categories.map { |c| c.to_s.prepare_category_name }
Expand All @@ -74,7 +78,9 @@ def initialize(*categories, learning_rate: DEFAULT_LEARNING_RATE,
@fitted = false
@dirty = false
@storage = nil
@min_word_length = min_word_length
end
# rubocop:enable Metrics/ParameterLists

# Trains the classifier with text for a category.
#
Expand Down Expand Up @@ -130,7 +136,7 @@ def classify(text)
def probabilities(text)
raise NotFittedError, 'Model not fitted. Call fit() after training.' unless @fitted

features = text.word_hash
features = text.word_hash(@min_word_length)
synchronize do
softmax(compute_scores(features))
end
Expand All @@ -143,7 +149,7 @@ def probabilities(text)
def classifications(text)
raise NotFittedError, 'Model not fitted. Call fit() after training.' unless @fitted

features = text.word_hash
features = text.word_hash(@min_word_length)
synchronize do
compute_scores(features).transform_keys(&:to_s)
end
Expand Down Expand Up @@ -239,7 +245,8 @@ def as_json(_options = nil)
regularization: @regularization,
max_iterations: @max_iterations,
tolerance: @tolerance,
fitted: @fitted
fitted: @fitted,
min_word_length: @min_word_length
}
end

Expand Down Expand Up @@ -336,7 +343,7 @@ def reload!
def marshal_dump
fit unless @fitted
[@categories, @weights, @bias, @vocabulary, @learning_rate, @regularization,
@max_iterations, @tolerance, @fitted]
@max_iterations, @tolerance, @fitted, @min_word_length]
end

# Custom marshal deserialization to recreate mutex.
Expand All @@ -345,7 +352,7 @@ def marshal_dump
def marshal_load(data)
mu_initialize
@categories, @weights, @bias, @vocabulary, @learning_rate, @regularization,
@max_iterations, @tolerance, @fitted = data
@max_iterations, @tolerance, @fitted, @min_word_length = data
@training_data = []
@dirty = false
@storage = nil
Expand Down Expand Up @@ -395,7 +402,7 @@ def train_from_stream(category, io, batch_size: Streaming::DEFAULT_BATCH_SIZE)
reader.each_batch do |batch|
synchronize do
batch.each do |text|
features = text.word_hash
features = text.word_hash(@min_word_length)
features.each_key { |word| @vocabulary[word] = true }
@training_data << { category: category, features: features }
end
Expand Down Expand Up @@ -444,7 +451,7 @@ def train_batch_for_category(category, documents, batch_size: Streaming::DEFAULT
documents.each_slice(batch_size) do |batch|
synchronize do
batch.each do |text|
features = text.word_hash
features = text.word_hash(@min_word_length)
features.each_key { |word| @vocabulary[word] = true }
@training_data << { category: category, features: features }
end
Expand All @@ -463,7 +470,7 @@ def train_single(category, text)
category = category.to_s.prepare_category_name
raise StandardError, "No such category: #{category}" unless @categories.include?(category)

features = text.word_hash
features = text.word_hash(@min_word_length)
synchronize do
features.each_key { |word| @vocabulary[word] = true }
@training_data << { category: category, features: features }
Expand Down Expand Up @@ -570,6 +577,7 @@ def restore_state(data, categories)
@fitted = data.fetch('fitted', true)
@dirty = false
@storage = nil
@min_word_length = data['min_word_length'] || Classifier.config.min_word_length
end

def restore_weights_and_bias(data)
Expand Down
15 changes: 12 additions & 3 deletions lib/classifier/lsi.rb
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ class LSI
# @rbs @u_matrix: Matrix?
# @rbs @max_rank: Integer
# @rbs @initial_vocab_size: Integer?
# @rbs @min_word_length: Integer

attr_reader :word_list, :singular_values
attr_accessor :auto_rebuild, :storage
Expand Down Expand Up @@ -110,6 +111,7 @@ def initialize(options = {})
@max_rank = options[:max_rank] || DEFAULT_MAX_RANK
@u_matrix = nil
@initial_vocab_size = nil
@min_word_length = options[:min_word_length] || Classifier.config.min_word_length
end

# Returns true if the index needs to be rebuilt. The index needs
Expand Down Expand Up @@ -216,7 +218,13 @@ def add(**items)
#
# @rbs (String, *String | Symbol) ?{ (String) -> String } -> void
def add_item(item, *categories, &block)
clean_word_hash = block ? block.call(item).clean_word_hash : item.to_s.clean_word_hash
clean_word_hash =
if block
block.call(item).clean_word_hash(@min_word_length)
else
item.to_s.clean_word_hash(@min_word_length)
end

node = nil

synchronize do
Expand Down Expand Up @@ -480,14 +488,15 @@ def highest_ranked_stems(doc, count = 3)
# Custom marshal serialization to exclude mutex state
# @rbs () -> Array[untyped]
def marshal_dump
[@auto_rebuild, @word_list, @items, @version, @built_at_version, @dirty]
[@auto_rebuild, @word_list, @items, @version, @built_at_version, @dirty, @min_word_length]
end

# Custom marshal deserialization to recreate mutex
# @rbs (Array[untyped]) -> void
def marshal_load(data)
mu_initialize
@auto_rebuild, @word_list, @items, @version, @built_at_version, @dirty = data
@auto_rebuild, @word_list, @items, @version, @built_at_version, @dirty,
@min_word_length = data
@storage = nil
end

Expand Down
22 changes: 15 additions & 7 deletions lib/classifier/tfidf.rb
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class TFIDF
# @rbs @fitted: bool
# @rbs @dirty: bool
# @rbs @storage: Storage::Base?
# @rbs @min_word_length: Integer

attr_reader :vocabulary, :idf, :num_documents
attr_accessor :storage
Expand All @@ -36,10 +37,12 @@ class TFIDF
# - min_df/max_df: filter terms by document frequency (Integer for count, Float for proportion)
# - ngram_range: [1,1] for unigrams, [1,2] for unigrams+bigrams
# - sublinear_tf: use 1 + log(tf) instead of raw term frequency
# - min_word_length: minimum word length filter in tokenization
#
# @rbs (?min_df: Integer | Float, ?max_df: Integer | Float,
# ?ngram_range: Array[Integer], ?sublinear_tf: bool) -> void
def initialize(min_df: 1, max_df: 1.0, ngram_range: [1, 1], sublinear_tf: false)
# ?ngram_range: Array[Integer], ?sublinear_tf: bool, ?min_word_length: Integer) -> void
def initialize(min_df: 1, max_df: 1.0, ngram_range: [1, 1], sublinear_tf: false,
min_word_length: Classifier.config.min_word_length)
validate_df!(min_df, 'min_df')
validate_df!(max_df, 'max_df')
validate_ngram_range!(ngram_range)
Expand All @@ -54,6 +57,7 @@ def initialize(min_df: 1, max_df: 1.0, ngram_range: [1, 1], sublinear_tf: false)
@fitted = false
@dirty = false
@storage = nil
@min_word_length = min_word_length
end

# Learns vocabulary and IDF weights from the corpus.
Expand Down Expand Up @@ -204,7 +208,8 @@ def as_json(_options = nil)
vocabulary: @vocabulary,
idf: @idf,
num_documents: @num_documents,
fitted: @fitted
fitted: @fitted,
min_word_length: @min_word_length
}
end

Expand All @@ -223,7 +228,8 @@ def self.from_json(json)
min_df: data['min_df'],
max_df: data['max_df'],
ngram_range: data['ngram_range'],
sublinear_tf: data['sublinear_tf']
sublinear_tf: data['sublinear_tf'],
min_word_length: data['min_word_length'] || Classifier.config.min_word_length
)

instance.instance_variable_set(:@vocabulary, symbolize_keys(data['vocabulary']))
Expand All @@ -238,12 +244,14 @@ def self.from_json(json)

# @rbs () -> Array[untyped]
def marshal_dump
[@min_df, @max_df, @ngram_range, @sublinear_tf, @vocabulary, @idf, @num_documents, @fitted]
[@min_df, @max_df, @ngram_range, @sublinear_tf, @vocabulary, @idf, @num_documents, @fitted,
@min_word_length]
end

# @rbs (Array[untyped]) -> void
def marshal_load(data)
@min_df, @max_df, @ngram_range, @sublinear_tf, @vocabulary, @idf, @num_documents, @fitted = data
@min_df, @max_df, @ngram_range, @sublinear_tf, @vocabulary, @idf, @num_documents, @fitted,
@min_word_length = data
@dirty = false
@storage = nil
end
Expand Down Expand Up @@ -334,7 +342,7 @@ def extract_terms(document)
result = Hash.new(0)

if @ngram_range[0] <= 1
word_hash = document.clean_word_hash
word_hash = document.clean_word_hash(@min_word_length)
word_hash.each { |term, count| result[term] += count }
end

Expand Down
Loading
Loading