From d4eb5445d98587c05be46747ab3da99d3df67a5a Mon Sep 17 00:00:00 2001 From: Peng Wu Date: Tue, 10 May 2011 13:31:36 +0800 Subject: write gen k mixture model in progress --- utils/training/gen_k_mixture_model.cpp | 112 ++++++++++++++++++++++++++++++++- 1 file changed, 110 insertions(+), 2 deletions(-) diff --git a/utils/training/gen_k_mixture_model.cpp b/utils/training/gen_k_mixture_model.cpp index 7d4e3ed..ba0936f 100644 --- a/utils/training/gen_k_mixture_model.cpp +++ b/utils/training/gen_k_mixture_model.cpp @@ -23,13 +23,17 @@ #include #include "pinyin.h" +#include "k_mixture_model.h" typedef GHashTable * HashofWordPair; typedef GHashTable * HashofSecondWord; /* Hash token of Hash token of word count. */ -HashofWordPair g_hash_of_document = NULL; -PhraseLargeTable * g_phrases = NULL; +static HashofWordPair g_hash_of_document = NULL; +static PhraseLargeTable * g_phrases = NULL; +static KMixtureModelBigram * g_bigram = NULL; +static guint32 g_maximum_occurs = 20; +static parameter_t g_maximum_increase_rates = 3.; void print_help(){ printf("gen_k_mixture_model [--skip-pi-gram-training]\n"); @@ -99,6 +103,110 @@ bool convert_document_to_hash(FILE * document){ return true; } +static void train_word_pair(gpointer key, gpointer value, + gpointer user_data){ + phrase_token_t token = GPOINTER_TO_UINT(key); + guint32 count = GPOINTER_TO_UINT(value); + KMixtureModelSingleGram * single_gram = + (KMixtureModelSingleGram *)user_data; + KMixtureModelArrayItem array_item; + guint32 delta = 0; + + bool exists = single_gram->get_array_item(token, array_item); + if ( exists ) { + guint32 maximum_occurs_allowed = std_lite::max + (g_maximum_occurs, + (guint32)ceil(array_item.m_Mr * g_maximum_increase_rates)); + /* Exceeds the maximum occurs allowed of the word or phrase, + * in a single document. + */ + if ( count > maximum_occurs_allowed ) + return; + array_item.m_WC += count; + /* array_item.m_T += count; the same as m_WC. */ + array_item.m_N_n_0 ++; + if ( 1 == count ) + array_item.m_n_1 ++; + array_item.m_Mr = std_lite::max(array_item.m_Mr, count); + delta = count; + } else { /* item doesn't exist. */ + /* the same as above. */ + if ( count > g_maximum_occurs ) + return; + memset(&array_item, 0, sizeof(KMixtureModelArrayItem)); + array_item.m_WC = count; + /* array_item.m_T = count; the same as m_WC. */ + array_item.m_N_n_0 = 1; + if ( 1 == count ) + array_item.m_n_1 = 1; + array_item.m_Mr = count; + delta = count; + } + /* save delta in the array header. */ + KMixtureModelArrayHeader array_header; + single_gram->get_array_header(array_header); + array_header.m_WC += delta; + single_gram->set_array_header(array_header); +} + +bool train_single_gram(phrase_token_t token, + KMixtureModelSingleGram * single_gram, + guint32 & delta){ + assert(NULL != single_gram); + delta = 0; /* delta in WC of single_gram. */ + KMixtureModelArrayHeader array_header; + assert(single_gram->get_array_header(array_header)); + guint32 saved_array_header_WC = array_header.m_WC; + + HashofSecondWord hash_of_second_word = NULL; + gpointer value = NULL; + assert(g_hash_table_lookup_extended + (g_hash_of_document, GUINT_TO_POINTER(token), + NULL, &value)); + hash_of_second_word = (HashofSecondWord) value; + assert(NULL != hash_of_second_word); + + g_hash_table_foreach(hash_of_second_word, train_word_pair, single_gram); + + assert(single_gram->get_array_header(array_header)); + delta = array_header.m_WC - saved_array_header_WC; + return true; +} + +static void train_single_gram_wrapper(gpointer key, gpointer value, + gpointer user_data){ + phrase_token_t token = GPOINTER_TO_UINT(key); + guint32 delta = 0; + + KMixtureModelSingleGram * single_gram = NULL; + bool exists = g_bigram->load(token, single_gram); + if ( exists ){ + train_single_gram(token, single_gram, delta); + } else { /* item doesn't exist. */ + single_gram = new KMixtureModelSingleGram; + train_single_gram(token, single_gram, delta); + } + + KMixtureModelMagicHeader magic_header; + assert(g_bigram->get_magic_header(magic_header)); + if ( magic_header.m_WC + delta < magic_header.m_WC ){ + fprintf(stderr, "the m_WC integer in magic header overflows.\n"); + return; + } + magic_header.m_WC += delta; + magic_header.m_N ++; + assert(g_bigram->set_magic_header(magic_header)); + + /* save the single gram. */ + assert(g_bigram->store(token, single_gram)); + delete single_gram; +} + +bool train_document(){ + g_hash_table_foreach(g_hash_of_document, train_single_gram_wrapper, NULL); + return true; +} + int main(int argc, char * argv[]){ g_hash_of_document = g_hash_table_new_full (g_int_hash, g_int_equal, NULL, (GDestroyNotify)g_hash_table_unref); -- cgit