diff --git a/lib/classifier.rb b/lib/classifier.rb index 369dc646..26a5488b 100644 --- a/lib/classifier.rb +++ b/lib/classifier.rb @@ -36,3 +36,4 @@ require 'classifier/knn' require 'classifier/tfidf' require 'classifier/logistic_regression' +require 'classifier/config' diff --git a/lib/classifier/bayes.rb b/lib/classifier/bayes.rb index 394880c4..09a9a628 100644 --- a/lib/classifier/bayes.rb +++ b/lib/classifier/bayes.rb @@ -20,6 +20,7 @@ class Bayes # rubocop:disable Metrics/ClassLength # @rbs @cached_vocab_size: Integer? # @rbs @dirty: bool # @rbs @storage: Storage::Base? + # @rbs @min_word_length: Integer attr_accessor :storage @@ -27,8 +28,9 @@ class Bayes # rubocop:disable Metrics/ClassLength # initialized and given a training method. E.g., # b = Classifier::Bayes.new 'Interesting', 'Uninteresting', 'Spam' # b = Classifier::Bayes.new ['Interesting', 'Uninteresting', 'Spam'] - # @rbs (*String | Symbol | Array[String | Symbol]) -> void - def initialize(*categories) + # b = Classifier::Bayes.new 'Spam', min_word_length: 1 + # @rbs (*String | Symbol | Array[String | Symbol], ?min_word_length: Integer) -> void + def initialize(*categories, min_word_length: Classifier.config.min_word_length) super() @categories = {} categories.flatten.each { |category| @categories[category.prepare_category_name] = {} } @@ -39,6 +41,7 @@ def initialize(*categories) @cached_vocab_size = nil @dirty = false @storage = nil + @min_word_length = min_word_length end # Trains the classifier with text for a category. 
@@ -76,7 +79,7 @@ def untrain(category = nil, text = nil, **categories) # # @rbs (String) -> Hash[String, Float] def classifications(text) - words = text.word_hash.keys + words = text.word_hash(@min_word_length).keys synchronize do training_count = cached_training_count vocab_size = cached_vocab_size @@ -117,7 +120,8 @@ def as_json(_options = nil) categories: @categories.transform_keys(&:to_s).transform_values { |v| v.transform_keys(&:to_s) }, total_words: @total_words, category_counts: @category_counts.transform_keys(&:to_s), - category_word_count: @category_word_count.transform_keys(&:to_s) + category_word_count: @category_word_count.transform_keys(&:to_s), + min_word_length: @min_word_length } end @@ -409,7 +413,7 @@ def train_batch_internal(category, batch) invalidate_caches @dirty = true batch.each do |text| - word_hash = text.word_hash + word_hash = text.word_hash(@min_word_length) @category_counts[category] += 1 word_hash.each do |word, count| @categories[category][word] ||= 0 @@ -425,7 +429,7 @@ def train_batch_internal(category, batch) # @rbs (String | Symbol, String) -> void def train_single(category, text) category = category.prepare_category_name - word_hash = text.word_hash + word_hash = text.word_hash(@min_word_length) synchronize do invalidate_caches @dirty = true @@ -443,7 +447,7 @@ def train_single(category, text) # @rbs (String | Symbol, String) -> void def untrain_single(category, text) category = category.prepare_category_name - word_hash = text.word_hash + word_hash = text.word_hash(@min_word_length) synchronize do invalidate_caches @dirty = true @@ -487,6 +491,7 @@ def restore_state(data) @cached_vocab_size = nil @dirty = false @storage = nil + @min_word_length = data['min_word_length'] || Classifier.config.min_word_length data['categories'].each do |cat_name, words| @categories[cat_name.to_sym] = words.transform_keys(&:to_sym) diff --git a/lib/classifier/config.rb b/lib/classifier/config.rb new file mode 100644 index 00000000..9cb070bb --- 
/dev/null +++ b/lib/classifier/config.rb @@ -0,0 +1,31 @@ +# rbs_inline: enabled + +module Classifier + # @rbs @config: Config? + + # This lazy initialization is not thread-safe. + # In multi-threaded environments, ensure this method is called + # or configuration is set explicitly during startup before using classifiers. + # @rbs () -> Config + def config + @config ||= Config.new + end + + # @rbs () { (Config) -> void } -> void + def configure(&block) + block&.call(config) + end + + module_function :config, :configure + + class Config + # @rbs @min_word_length: Integer + + attr_accessor :min_word_length #: Integer + + # @rbs () -> void + def initialize + @min_word_length = 3 + end + end +end diff --git a/lib/classifier/extensions/word_hash.rb b/lib/classifier/extensions/word_hash.rb index b6462da2..c51a8d8e 100644 --- a/lib/classifier/extensions/word_hash.rb +++ b/lib/classifier/extensions/word_hash.rb @@ -20,27 +20,27 @@ def without_punctuation # Return a Hash of strings => ints. Each word in the string is stemmed, # interned, and indexes to its frequency in the document. - # @rbs () -> Hash[Symbol, Integer] - def word_hash - word_hash = clean_word_hash + # @rbs (?Integer) -> Hash[Symbol, Integer] + def word_hash(min_word_length = 3) + word_hash = clean_word_hash(min_word_length) symbol_hash = word_hash_for_symbols(gsub(/\w/, ' ').split) word_hash.merge(symbol_hash) end # Return a word hash without extra punctuation or short symbols, just stemmed words - # @rbs () -> Hash[Symbol, Integer] - def clean_word_hash - word_hash_for_words gsub(/[^\w\s]/, '').split + # @rbs (?Integer) -> Hash[Symbol, Integer] + def clean_word_hash(min_word_length = 3) + word_hash_for_words(gsub(/[^\w\s]/, '').split, min_word_length) end private - # @rbs (Array[String]) -> Hash[Symbol, Integer] - def word_hash_for_words(words) + # @rbs (Array[String], Integer) -> Hash[Symbol, Integer] + def word_hash_for_words(words, min_word_length) d = Hash.new(0) words.each do |word| word.downcase! 
- d[word.stem.intern] += 1 if !CORPUS_SKIP_WORDS.include?(word) && word.length > 2 + d[word.stem.intern] += 1 if !CORPUS_SKIP_WORDS.include?(word) && word.length >= min_word_length end d end diff --git a/lib/classifier/logistic_regression.rb b/lib/classifier/logistic_regression.rb index a169f85e..72de3176 100644 --- a/lib/classifier/logistic_regression.rb +++ b/lib/classifier/logistic_regression.rb @@ -34,6 +34,7 @@ class LogisticRegression # rubocop:disable Metrics/ClassLength # @rbs @fitted: bool # @rbs @dirty: bool # @rbs @storage: Storage::Base? + # @rbs @min_word_length: Integer attr_accessor :storage @@ -53,13 +54,16 @@ class LogisticRegression # rubocop:disable Metrics/ClassLength # - regularization: L2 regularization strength (default: 0.01) # - max_iterations: Maximum training iterations (default: 100) # - tolerance: Convergence threshold (default: 1e-4) + # - min_word_length: Minimum word length filter in tokenization # # @rbs (*String | Symbol | Array[String | Symbol], ?learning_rate: Float, ?regularization: Float, - # ?max_iterations: Integer, ?tolerance: Float) -> void + # ?max_iterations: Integer, ?tolerance: Float, ?min_word_length: Integer) -> void + # rubocop:disable Metrics/ParameterLists def initialize(*categories, learning_rate: DEFAULT_LEARNING_RATE, regularization: DEFAULT_REGULARIZATION, max_iterations: DEFAULT_MAX_ITERATIONS, - tolerance: DEFAULT_TOLERANCE) + tolerance: DEFAULT_TOLERANCE, + min_word_length: Classifier.config.min_word_length) super() categories = categories.flatten @categories = categories.map { |c| c.to_s.prepare_category_name } @@ -74,7 +78,9 @@ def initialize(*categories, learning_rate: DEFAULT_LEARNING_RATE, @fitted = false @dirty = false @storage = nil + @min_word_length = min_word_length end + # rubocop:enable Metrics/ParameterLists # Trains the classifier with text for a category. # @@ -130,7 +136,7 @@ def classify(text) def probabilities(text) raise NotFittedError, 'Model not fitted. Call fit() after training.' 
unless @fitted - features = text.word_hash + features = text.word_hash(@min_word_length) synchronize do softmax(compute_scores(features)) end @@ -143,7 +149,7 @@ def probabilities(text) def classifications(text) raise NotFittedError, 'Model not fitted. Call fit() after training.' unless @fitted - features = text.word_hash + features = text.word_hash(@min_word_length) synchronize do compute_scores(features).transform_keys(&:to_s) end @@ -239,7 +245,8 @@ def as_json(_options = nil) regularization: @regularization, max_iterations: @max_iterations, tolerance: @tolerance, - fitted: @fitted + fitted: @fitted, + min_word_length: @min_word_length } end @@ -336,7 +343,7 @@ def reload! def marshal_dump fit unless @fitted [@categories, @weights, @bias, @vocabulary, @learning_rate, @regularization, - @max_iterations, @tolerance, @fitted] + @max_iterations, @tolerance, @fitted, @min_word_length] end # Custom marshal deserialization to recreate mutex. @@ -345,7 +352,9 @@ def marshal_dump def marshal_load(data) mu_initialize @categories, @weights, @bias, @vocabulary, @learning_rate, @regularization, - @max_iterations, @tolerance, @fitted = data + @max_iterations, @tolerance, @fitted, @min_word_length = data + # Dumps written before min_word_length existed carry no slot for it (nil here); use the configured default. + @min_word_length ||= Classifier.config.min_word_length @training_data = [] @dirty = false @storage = nil @@ -395,7 +404,7 @@ def train_from_stream(category, io, batch_size: Streaming::DEFAULT_BATCH_SIZE) reader.each_batch do |batch| synchronize do batch.each do |text| - features = text.word_hash + features = text.word_hash(@min_word_length) features.each_key { |word| @vocabulary[word] = true } @training_data << { category: category, features: features } end @@ -444,7 +453,7 @@ def train_batch_for_category(category, documents, batch_size: Streaming::DEFAULT documents.each_slice(batch_size) do |batch| synchronize do batch.each do |text| - features = text.word_hash + features = text.word_hash(@min_word_length) features.each_key { |word| @vocabulary[word] = true } @training_data << { category: category, features: features } end @@ 
-463,7 +472,7 @@ def train_single(category, text) category = category.to_s.prepare_category_name raise StandardError, "No such category: #{category}" unless @categories.include?(category) - features = text.word_hash + features = text.word_hash(@min_word_length) synchronize do features.each_key { |word| @vocabulary[word] = true } @training_data << { category: category, features: features } @@ -570,6 +579,7 @@ def restore_state(data, categories) @fitted = data.fetch('fitted', true) @dirty = false @storage = nil + @min_word_length = data['min_word_length'] || Classifier.config.min_word_length end def restore_weights_and_bias(data) diff --git a/lib/classifier/lsi.rb b/lib/classifier/lsi.rb index 6e1a6620..5c277150 100644 --- a/lib/classifier/lsi.rb +++ b/lib/classifier/lsi.rb @@ -80,6 +80,7 @@ class LSI # @rbs @u_matrix: Matrix? # @rbs @max_rank: Integer # @rbs @initial_vocab_size: Integer? + # @rbs @min_word_length: Integer attr_reader :word_list, :singular_values attr_accessor :auto_rebuild, :storage @@ -110,6 +111,7 @@ def initialize(options = {}) @max_rank = options[:max_rank] || DEFAULT_MAX_RANK @u_matrix = nil @initial_vocab_size = nil + @min_word_length = options[:min_word_length] || Classifier.config.min_word_length end # Returns true if the index needs to be rebuilt. The index needs @@ -216,7 +218,13 @@ def add(**items) # # @rbs (String, *String | Symbol) ?{ (String) -> String } -> void def add_item(item, *categories, &block) - clean_word_hash = block ? 
block.call(item).clean_word_hash : item.to_s.clean_word_hash + clean_word_hash = + if block + block.call(item).clean_word_hash(@min_word_length) + else + item.to_s.clean_word_hash(@min_word_length) + end + node = nil synchronize do @@ -480,14 +488,17 @@ def highest_ranked_stems(doc, count = 3) # Custom marshal serialization to exclude mutex state # @rbs () -> Array[untyped] def marshal_dump - [@auto_rebuild, @word_list, @items, @version, @built_at_version, @dirty] + [@auto_rebuild, @word_list, @items, @version, @built_at_version, @dirty, @min_word_length] end # Custom marshal deserialization to recreate mutex # @rbs (Array[untyped]) -> void def marshal_load(data) mu_initialize - @auto_rebuild, @word_list, @items, @version, @built_at_version, @dirty = data + @auto_rebuild, @word_list, @items, @version, @built_at_version, @dirty, + @min_word_length = data + # Dumps written before min_word_length existed carry no slot for it (nil here); use the configured default. + @min_word_length ||= Classifier.config.min_word_length @storage = nil end diff --git a/lib/classifier/tfidf.rb b/lib/classifier/tfidf.rb index 56c52b15..20a65e11 100644 --- a/lib/classifier/tfidf.rb +++ b/lib/classifier/tfidf.rb @@ -28,6 +28,7 @@ class TFIDF # @rbs @fitted: bool # @rbs @dirty: bool # @rbs @storage: Storage::Base? 
+ # @rbs @min_word_length: Integer attr_reader :vocabulary, :idf, :num_documents attr_accessor :storage @@ -36,10 +37,12 @@ class TFIDF # - min_df/max_df: filter terms by document frequency (Integer for count, Float for proportion) # - ngram_range: [1,1] for unigrams, [1,2] for unigrams+bigrams # - sublinear_tf: use 1 + log(tf) instead of raw term frequency + # - min_word_length: minimum word length filter in tokenization # # @rbs (?min_df: Integer | Float, ?max_df: Integer | Float, - # ?ngram_range: Array[Integer], ?sublinear_tf: bool) -> void - def initialize(min_df: 1, max_df: 1.0, ngram_range: [1, 1], sublinear_tf: false) + # ?ngram_range: Array[Integer], ?sublinear_tf: bool, ?min_word_length: Integer) -> void + def initialize(min_df: 1, max_df: 1.0, ngram_range: [1, 1], sublinear_tf: false, + min_word_length: Classifier.config.min_word_length) validate_df!(min_df, 'min_df') validate_df!(max_df, 'max_df') validate_ngram_range!(ngram_range) @@ -54,6 +57,7 @@ def initialize(min_df: 1, max_df: 1.0, ngram_range: [1, 1], sublinear_tf: false) @fitted = false @dirty = false @storage = nil + @min_word_length = min_word_length end # Learns vocabulary and IDF weights from the corpus. 
@@ -204,7 +208,8 @@ def as_json(_options = nil) vocabulary: @vocabulary, idf: @idf, num_documents: @num_documents, - fitted: @fitted + fitted: @fitted, + min_word_length: @min_word_length } end @@ -223,7 +228,8 @@ def self.from_json(json) min_df: data['min_df'], max_df: data['max_df'], ngram_range: data['ngram_range'], - sublinear_tf: data['sublinear_tf'] + sublinear_tf: data['sublinear_tf'], + min_word_length: data['min_word_length'] || Classifier.config.min_word_length ) instance.instance_variable_set(:@vocabulary, symbolize_keys(data['vocabulary'])) @@ -238,12 +244,16 @@ def self.from_json(json) # @rbs () -> Array[untyped] def marshal_dump - [@min_df, @max_df, @ngram_range, @sublinear_tf, @vocabulary, @idf, @num_documents, @fitted] + [@min_df, @max_df, @ngram_range, @sublinear_tf, @vocabulary, @idf, @num_documents, @fitted, + @min_word_length] end # @rbs (Array[untyped]) -> void def marshal_load(data) - @min_df, @max_df, @ngram_range, @sublinear_tf, @vocabulary, @idf, @num_documents, @fitted = data + @min_df, @max_df, @ngram_range, @sublinear_tf, @vocabulary, @idf, @num_documents, @fitted, + @min_word_length = data + # Dumps written before min_word_length existed carry no slot for it (nil here); use the configured default. + @min_word_length ||= Classifier.config.min_word_length @dirty = false @storage = nil end @@ -334,7 +344,7 @@ def extract_terms(document) result = Hash.new(0) if @ngram_range[0] <= 1 - word_hash = document.clean_word_hash + word_hash = document.clean_word_hash(@min_word_length) word_hash.each { |term, count| result[term] += count } end diff --git a/test/bayes/bayesian_test.rb b/test/bayes/bayesian_test.rb index 43838a7d..f59d1f68 100644 --- a/test/bayes/bayesian_test.rb +++ b/test/bayes/bayesian_test.rb @@ -28,6 +28,17 @@ def test_array_initialization assert_equal 'Spam', classifier.classify('this is spam') end + def test_initialization_with_min_word_length + classifier = Classifier::Bayes.new(%w[Spam Ham], min_word_length: 5) + + classifier.train_spam 'bad nasty spam email' + classifier.train_ham 'good legitimate email' + + assert_equal 'Spam', classifier.classify('nasty text') + assert_equal 'Ham', 
classifier.classify('legitimate text') + assert_equal 'Spam', classifier.classify('good text') + end + def test_add_category @classifier.add_category 'Test' diff --git a/test/config/config_test.rb b/test/config/config_test.rb new file mode 100644 index 00000000..556d0942 --- /dev/null +++ b/test/config/config_test.rb @@ -0,0 +1,22 @@ +require_relative '../test_helper' +require 'classifier/config' + +class ConfigTest < Minitest::Test + def teardown + Classifier.config.min_word_length = 3 + end + + def test_configure + Classifier.configure do |config| + config.min_word_length = 1 + end + + assert_equal(1, Classifier.config.min_word_length) + end + + def test_default + config = Classifier::Config.new + + assert_equal(3, config.min_word_length) + end +end diff --git a/test/logistic_regression/logistic_regression_test.rb b/test/logistic_regression/logistic_regression_test.rb index 5b3c0893..73e0bf41 100644 --- a/test/logistic_regression/logistic_regression_test.rb +++ b/test/logistic_regression/logistic_regression_test.rb @@ -56,6 +56,18 @@ def test_custom_hyperparameters assert_instance_of Classifier::LogisticRegression, classifier end + def test_initialization_with_min_word_length + classifier = Classifier::LogisticRegression.new(%w[Spam Ham], min_word_length: 5) + + classifier.train_spam 'bad nasty spam email' + classifier.train_ham 'good legitimate email' + classifier.fit + + assert_equal 'Spam', classifier.classify('nasty text') + assert_equal 'Ham', classifier.classify('legitimate text') + assert_in_delta 0.5, classifier.probabilities('good text')['Spam'], 0.02 + end + def test_categories assert_equal %w[Spam Ham].sort, @classifier.categories.sort end diff --git a/test/lsi/lsi_test.rb b/test/lsi/lsi_test.rb index dbb0712f..9c8f92a9 100644 --- a/test/lsi/lsi_test.rb +++ b/test/lsi/lsi_test.rb @@ -54,6 +54,16 @@ def test_add_batch_operations assert_equal ['Cat'], lsi.categories_for('Cats are independent') end + def test_custom_min_word_length + lsi = 
Classifier::LSI.new(min_word_length: 5) + lsi.add( + 'Dog' => ['Dogs are loyal', 'Puppies are cute'], + 'Cat' => ['Cats are independent', 'Kittens are playful'] + ) + + assert_equal(['Dog'], lsi.categories_for('Puppies are cute')) + end + def test_add_classification_works lsi = Classifier::LSI.new lsi.add( diff --git a/test/tfidf/tfidf_test.rb b/test/tfidf/tfidf_test.rb index fde00e2f..23f35f45 100644 --- a/test/tfidf/tfidf_test.rb +++ b/test/tfidf/tfidf_test.rb @@ -93,6 +93,16 @@ def test_invalid_ngram_range_raises assert_raises(ArgumentError) { Classifier::TFIDF.new(ngram_range: 'invalid') } end + def test_custom_min_word_length + tfidf = Classifier::TFIDF.new(min_word_length: 5) + tfidf.fit(@corpus) + + v = tfidf.transform('Dogs are loyal') + + assert_equal(1, v.count) + assert_in_delta(1.0, v[:loyal]) + end + # Fit tests def test_fit_builds_vocabulary