summaryrefslogtreecommitdiffstats
path: root/utils/training/gen_k_mixture_model.cpp
diff options
context:
space:
mode:
authorPeng Wu <alexepico@gmail.com>2011-05-10 13:31:36 +0800
committerPeng Wu <alexepico@gmail.com>2011-05-10 13:31:36 +0800
commitd4eb5445d98587c05be46747ab3da99d3df67a5a (patch)
tree937ddc4a3f5237e52c33a7f527cca62aac02f434 /utils/training/gen_k_mixture_model.cpp
parentc7c8dda99cdf334803b72da07d8389e19cf365d8 (diff)
downloadlibpinyin-d4eb5445d98587c05be46747ab3da99d3df67a5a.tar.gz
libpinyin-d4eb5445d98587c05be46747ab3da99d3df67a5a.tar.xz
libpinyin-d4eb5445d98587c05be46747ab3da99d3df67a5a.zip
write gen k mixture model in progress
Diffstat (limited to 'utils/training/gen_k_mixture_model.cpp')
-rw-r--r--utils/training/gen_k_mixture_model.cpp112
1 files changed, 110 insertions, 2 deletions
diff --git a/utils/training/gen_k_mixture_model.cpp b/utils/training/gen_k_mixture_model.cpp
index 7d4e3ed..ba0936f 100644
--- a/utils/training/gen_k_mixture_model.cpp
+++ b/utils/training/gen_k_mixture_model.cpp
@@ -23,13 +23,17 @@
#include <glib.h>
#include "pinyin.h"
+#include "k_mixture_model.h"
typedef GHashTable * HashofWordPair;
typedef GHashTable * HashofSecondWord;
/* Hash token of Hash token of word count. */
-HashofWordPair g_hash_of_document = NULL;
-PhraseLargeTable * g_phrases = NULL;
+static HashofWordPair g_hash_of_document = NULL;
+static PhraseLargeTable * g_phrases = NULL;
+static KMixtureModelBigram * g_bigram = NULL;
+static guint32 g_maximum_occurs = 20;
+static parameter_t g_maximum_increase_rates = 3.;
void print_help(){
printf("gen_k_mixture_model [--skip-pi-gram-training]\n");
@@ -99,6 +103,110 @@ bool convert_document_to_hash(FILE * document){
return true;
}
+static void train_word_pair(gpointer key, gpointer value,
+ gpointer user_data){
+ phrase_token_t token = GPOINTER_TO_UINT(key);
+ guint32 count = GPOINTER_TO_UINT(value);
+ KMixtureModelSingleGram * single_gram =
+ (KMixtureModelSingleGram *)user_data;
+ KMixtureModelArrayItem array_item;
+ guint32 delta = 0;
+
+ bool exists = single_gram->get_array_item(token, array_item);
+ if ( exists ) {
+ guint32 maximum_occurs_allowed = std_lite::max
+ (g_maximum_occurs,
+ (guint32)ceil(array_item.m_Mr * g_maximum_increase_rates));
+ /* Exceeds the maximum occurs allowed of the word or phrase,
+ * in a single document.
+ */
+ if ( count > maximum_occurs_allowed )
+ return;
+ array_item.m_WC += count;
+ /* array_item.m_T += count; the same as m_WC. */
+ array_item.m_N_n_0 ++;
+ if ( 1 == count )
+ array_item.m_n_1 ++;
+ array_item.m_Mr = std_lite::max(array_item.m_Mr, count);
+ delta = count;
+ } else { /* item doesn't exist. */
+ /* the same as above. */
+ if ( count > g_maximum_occurs )
+ return;
+ memset(&array_item, 0, sizeof(KMixtureModelArrayItem));
+ array_item.m_WC = count;
+ /* array_item.m_T = count; the same as m_WC. */
+ array_item.m_N_n_0 = 1;
+ if ( 1 == count )
+ array_item.m_n_1 = 1;
+ array_item.m_Mr = count;
+ delta = count;
+ }
+ /* save delta in the array header. */
+ KMixtureModelArrayHeader array_header;
+ single_gram->get_array_header(array_header);
+ array_header.m_WC += delta;
+ single_gram->set_array_header(array_header);
+}
+
+bool train_single_gram(phrase_token_t token,
+ KMixtureModelSingleGram * single_gram,
+ guint32 & delta){
+ assert(NULL != single_gram);
+ delta = 0; /* delta in WC of single_gram. */
+ KMixtureModelArrayHeader array_header;
+ assert(single_gram->get_array_header(array_header));
+ guint32 saved_array_header_WC = array_header.m_WC;
+
+ HashofSecondWord hash_of_second_word = NULL;
+ gpointer value = NULL;
+ assert(g_hash_table_lookup_extended
+ (g_hash_of_document, GUINT_TO_POINTER(token),
+ NULL, &value));
+ hash_of_second_word = (HashofSecondWord) value;
+ assert(NULL != hash_of_second_word);
+
+ g_hash_table_foreach(hash_of_second_word, train_word_pair, single_gram);
+
+ assert(single_gram->get_array_header(array_header));
+ delta = array_header.m_WC - saved_array_header_WC;
+ return true;
+}
+
+static void train_single_gram_wrapper(gpointer key, gpointer value,
+ gpointer user_data){
+ phrase_token_t token = GPOINTER_TO_UINT(key);
+ guint32 delta = 0;
+
+ KMixtureModelSingleGram * single_gram = NULL;
+ bool exists = g_bigram->load(token, single_gram);
+ if ( exists ){
+ train_single_gram(token, single_gram, delta);
+ } else { /* item doesn't exist. */
+ single_gram = new KMixtureModelSingleGram;
+ train_single_gram(token, single_gram, delta);
+ }
+
+ KMixtureModelMagicHeader magic_header;
+ assert(g_bigram->get_magic_header(magic_header));
+ if ( magic_header.m_WC + delta < magic_header.m_WC ){
+ fprintf(stderr, "the m_WC integer in magic header overflows.\n");
+ return;
+ }
+ magic_header.m_WC += delta;
+ magic_header.m_N ++;
+ assert(g_bigram->set_magic_header(magic_header));
+
+ /* save the single gram. */
+ assert(g_bigram->store(token, single_gram));
+ delete single_gram;
+}
+
+bool train_document(){
+ g_hash_table_foreach(g_hash_of_document, train_single_gram_wrapper, NULL);
+ return true;
+}
+
int main(int argc, char * argv[]){
g_hash_of_document = g_hash_table_new_full
(g_int_hash, g_int_equal, NULL, (GDestroyNotify)g_hash_table_unref);