summaryrefslogtreecommitdiffstats
path: root/utils/training/gen_k_mixture_model.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'utils/training/gen_k_mixture_model.cpp')
-rw-r--r--utils/training/gen_k_mixture_model.cpp411
1 files changed, 411 insertions, 0 deletions
diff --git a/utils/training/gen_k_mixture_model.cpp b/utils/training/gen_k_mixture_model.cpp
new file mode 100644
index 0000000..2dfb3d1
--- /dev/null
+++ b/utils/training/gen_k_mixture_model.cpp
@@ -0,0 +1,411 @@
+/*
+ * libpinyin
+ * Library to deal with pinyin.
+ *
+ * Copyright (C) 2011 Peng Wu <alexepico@gmail.com>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 2 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
+ */
+
+
+#include <glib.h>
+#include <locale.h>
+#include "pinyin_internal.h"
+#include "utils_helper.h"
+#include "k_mixture_model.h"
+
+/* Hash token of Hash token of word count. */
+typedef GHashTable * HashofDocument;
+typedef GHashTable * HashofSecondWord;
+
+typedef GHashTable * HashofUnigram;
+
+
+void print_help(){
+ printf("Usage: gen_k_mixture_model [--skip-pi-gram-training]\n"
+ " [--maximum-occurs-allowed <INT>]\n"
+ " [--maximum-increase-rates-allowed <FLOAT>]\n"
+ " [--k-mixture-model-file <FILENAME>]\n"
+ " {<FILENAME>}+\n");
+}
+
+
+static gint g_maximum_occurs = 20;
+static parameter_t g_maximum_increase_rates = 3.;
+static gboolean g_train_pi_gram = TRUE;
+static const gchar * g_k_mixture_model_filename = NULL;
+
+static GOptionEntry entries[] =
+{
+ {"skip-pi-gram-training", 0, G_OPTION_FLAG_REVERSE, G_OPTION_ARG_NONE, &g_train_pi_gram, "skip pi-gram training", NULL},
+ {"maximum-occurs-allowed", 0, 0, G_OPTION_ARG_INT, &g_maximum_occurs, "maximum occurs allowed", NULL},
+ {"maximum-increase-rates-allowed", 0, 0, G_OPTION_ARG_DOUBLE, &g_maximum_increase_rates, "maximum increase rates allowed", NULL},
+ {"k-mixture-model-file", 0, 0, G_OPTION_ARG_FILENAME, &g_k_mixture_model_filename, "k mixture model file", NULL},
+ {NULL}
+};
+
+
+bool read_document(PhraseLargeTable2 * phrase_table,
+ FacadePhraseIndex * phrase_index,
+ FILE * document,
+ HashofDocument hash_of_document,
+ HashofUnigram hash_of_unigram){
+
+ char * linebuf = NULL;size_t size = 0;
+ phrase_token_t last_token, cur_token = last_token = 0;
+
+ while ( getline(&linebuf, &size, document) ){
+ if ( feof(document) )
+ break;
+
+ if ( '\n' == linebuf[strlen(linebuf) - 1] ) {
+ linebuf[strlen(linebuf) - 1] = '\0';
+ }
+
+ TAGLIB_PARSE_SEGMENTED_LINE(phrase_index, token, linebuf);
+
+ last_token = cur_token;
+ cur_token = token;
+
+ /* skip null_token in second word. */
+ if ( null_token == cur_token )
+ continue;
+
+ gpointer value = NULL;
+ gboolean lookup_result = g_hash_table_lookup_extended
+ (hash_of_unigram, GUINT_TO_POINTER(cur_token),
+ NULL, &value);
+ if ( !lookup_result ){
+ g_hash_table_insert(hash_of_unigram, GUINT_TO_POINTER(cur_token),
+ GUINT_TO_POINTER(1));
+ } else {
+ guint32 freq = GPOINTER_TO_UINT(value);
+ freq ++;
+ g_hash_table_insert(hash_of_unigram, GUINT_TO_POINTER(cur_token),
+ GUINT_TO_POINTER(freq));
+ }
+
+ /* skip pi-gram training. */
+ if ( null_token == last_token ){
+ if ( !g_train_pi_gram )
+ continue;
+ last_token = sentence_start;
+ }
+
+ /* remember the (last_token, cur_token) word pair. */
+ HashofSecondWord hash_of_second_word = NULL;
+ lookup_result = g_hash_table_lookup_extended
+ (hash_of_document, GUINT_TO_POINTER(last_token),
+ NULL, &value);
+ if ( !lookup_result ){
+ hash_of_second_word = g_hash_table_new
+ (g_direct_hash, g_direct_equal);
+ } else {
+ hash_of_second_word = (HashofSecondWord) value;
+ }
+
+ value = NULL;
+ lookup_result = g_hash_table_lookup_extended
+ (hash_of_second_word, GUINT_TO_POINTER(cur_token),
+ NULL, &value);
+ guint32 count = 0;
+ if ( lookup_result ) {
+ count = GPOINTER_TO_UINT(value);
+ }
+ count ++;
+ g_hash_table_insert(hash_of_second_word,
+ GUINT_TO_POINTER(cur_token),
+ GUINT_TO_POINTER(count));
+ g_hash_table_insert(hash_of_document,
+ GUINT_TO_POINTER(last_token),
+ hash_of_second_word);
+ }
+
+ free(linebuf);
+
+ return true;
+}
+
+static void train_word_pair(HashofUnigram hash_of_unigram,
+ KMixtureModelSingleGram * single_gram,
+ phrase_token_t token2, guint32 count){
+ KMixtureModelArrayItem array_item;
+
+ bool exists = single_gram->get_array_item(token2, array_item);
+ if ( exists ) {
+ guint32 maximum_occurs_allowed = std_lite::max
+ ((guint32)g_maximum_occurs,
+ (guint32)ceil(array_item.m_Mr * g_maximum_increase_rates));
+ /* Exceeds the maximum occurs allowed of the word or phrase,
+ * in a single document.
+ */
+ if ( count > maximum_occurs_allowed ){
+ gpointer value = NULL;
+ assert( g_hash_table_lookup_extended
+ (hash_of_unigram, GUINT_TO_POINTER(token2),
+ NULL, &value) );
+ guint32 freq = GPOINTER_TO_UINT(value);
+ freq -= count;
+ if ( freq > 0 ) {
+ g_hash_table_insert(hash_of_unigram, GUINT_TO_POINTER(token2),
+ GUINT_TO_POINTER(freq));
+ } else if ( freq == 0 ) {
+ assert(g_hash_table_steal(hash_of_unigram,
+ GUINT_TO_POINTER(token2)));
+ } else {
+ assert(false);
+ }
+ return;
+ }
+ array_item.m_WC += count;
+ /* array_item.m_T += count; the same as m_WC. */
+ array_item.m_N_n_0 ++;
+ if ( 1 == count )
+ array_item.m_n_1 ++;
+ array_item.m_Mr = std_lite::max(array_item.m_Mr, count);
+ assert(single_gram->set_array_item(token2, array_item));
+ } else { /* item doesn't exist. */
+ /* the same as above. */
+ if ( count > g_maximum_occurs ){
+ gpointer value = NULL;
+ assert( g_hash_table_lookup_extended
+ (hash_of_unigram, GUINT_TO_POINTER(token2),
+ NULL, &value) );
+ guint32 freq = GPOINTER_TO_UINT(value);
+ freq -= count;
+ if ( freq > 0 ) {
+ g_hash_table_insert(hash_of_unigram, GUINT_TO_POINTER(token2),
+ GUINT_TO_POINTER(freq));
+ } else if ( freq == 0 ) {
+ assert(g_hash_table_steal(hash_of_unigram,
+ GUINT_TO_POINTER(token2)));
+ } else {
+ assert(false);
+ }
+ return;
+ }
+ memset(&array_item, 0, sizeof(KMixtureModelArrayItem));
+ array_item.m_WC = count;
+ /* array_item.m_T = count; the same as m_WC. */
+ array_item.m_N_n_0 = 1;
+ if ( 1 == count )
+ array_item.m_n_1 = 1;
+ array_item.m_Mr = count;
+ assert(single_gram->insert_array_item(token2, array_item));
+ }
+
+ /* save delta in the array header. */
+ KMixtureModelArrayHeader array_header;
+ single_gram->get_array_header(array_header);
+ array_header.m_WC += count;
+ single_gram->set_array_header(array_header);
+}
+
+bool train_single_gram(HashofUnigram hash_of_unigram,
+ HashofDocument hash_of_document,
+ KMixtureModelSingleGram * single_gram,
+ phrase_token_t token1,
+ guint32 & delta){
+ assert(NULL != single_gram);
+ delta = 0; /* delta in WC of single_gram. */
+ KMixtureModelArrayHeader array_header;
+ assert(single_gram->get_array_header(array_header));
+ guint32 saved_array_header_WC = array_header.m_WC;
+
+ HashofSecondWord hash_of_second_word = NULL;
+ gpointer key, value = NULL;
+ assert(g_hash_table_lookup_extended
+ (hash_of_document, GUINT_TO_POINTER(token1),
+ NULL, &value));
+ hash_of_second_word = (HashofSecondWord) value;
+ assert(NULL != hash_of_second_word);
+
+ /* train word pair */
+ GHashTableIter iter;
+ g_hash_table_iter_init(&iter, hash_of_second_word);
+ while (g_hash_table_iter_next(&iter, &key, &value)) {
+ phrase_token_t token2 = GPOINTER_TO_UINT(key);
+ guint32 count = GPOINTER_TO_UINT(value);
+ train_word_pair(hash_of_unigram, single_gram, token2, count);
+ }
+
+ assert(single_gram->get_array_header(array_header));
+ delta = array_header.m_WC - saved_array_header_WC;
+ return true;
+}
+
+static bool train_second_word(HashofUnigram hash_of_unigram,
+ KMixtureModelBigram * bigram,
+ HashofDocument hash_of_document,
+ phrase_token_t token1){
+ guint32 delta = 0;
+
+ KMixtureModelSingleGram * single_gram = NULL;
+ bool exists = bigram->load(token1, single_gram);
+ if ( !exists )
+ single_gram = new KMixtureModelSingleGram;
+ train_single_gram(hash_of_unigram, hash_of_document,
+ single_gram, token1, delta);
+
+ if ( 0 == delta ){ /* Please consider maximum occurs allowed. */
+ delete single_gram;
+ return false;
+ }
+
+ /* save the single gram. */
+ assert(bigram->store(token1, single_gram));
+ delete single_gram;
+
+ KMixtureModelMagicHeader magic_header;
+ if (!bigram->get_magic_header(magic_header)){
+ /* the first time to access the new k mixture model file. */
+ memset(&magic_header, 0, sizeof(KMixtureModelMagicHeader));
+ }
+
+ if ( magic_header.m_WC + delta < magic_header.m_WC ){
+ fprintf(stderr, "the m_WC integer in magic header overflows.\n");
+ return false;
+ }
+ magic_header.m_WC += delta;
+ assert(bigram->set_magic_header(magic_header));
+
+ return true;
+}
+
+/* Note: this method is a post-processing method, run this last. */
+static bool post_processing_unigram(KMixtureModelBigram * bigram,
+ HashofUnigram hash_of_unigram){
+ GHashTableIter iter;
+ gpointer key, value;
+ guint32 total_freq = 0;
+
+ g_hash_table_iter_init(&iter, hash_of_unigram);
+ while (g_hash_table_iter_next(&iter, &key, &value)){
+ guint32 token = GPOINTER_TO_UINT(key);
+ guint32 freq = GPOINTER_TO_UINT(value);
+ KMixtureModelArrayHeader array_header;
+ bool result = bigram->get_array_header(token, array_header);
+ array_header.m_freq += freq;
+ total_freq += freq;
+ bigram->set_array_header(token, array_header);
+ }
+
+ KMixtureModelMagicHeader magic_header;
+ assert(bigram->get_magic_header(magic_header));
+ if ( magic_header.m_total_freq + total_freq < magic_header.m_total_freq ){
+ fprintf(stderr, "the m_total_freq in magic header overflows.\n");
+ return false;
+ }
+ magic_header.m_total_freq += total_freq;
+ assert(bigram->set_magic_header(magic_header));
+
+ return true;
+}
+
+int main(int argc, char * argv[]){
+ int i = 1;
+
+ setlocale(LC_ALL, "");
+
+ GError * error = NULL;
+ GOptionContext * context;
+
+ context = g_option_context_new("- generate k mixture model");
+ g_option_context_add_main_entries(context, entries, NULL);
+ if (!g_option_context_parse(context, &argc, &argv, &error)) {
+ g_print("option parsing failed:%s\n", error->message);
+ exit(EINVAL);
+ }
+
+ SystemTableInfo system_table_info;
+
+ bool retval = system_table_info.load(SYSTEM_TABLE_INFO);
+ if (!retval) {
+ fprintf(stderr, "load table.conf failed.\n");
+ exit(ENOENT);
+ }
+
+ PhraseLargeTable2 phrase_table;
+ MemoryChunk * chunk = new MemoryChunk;
+ chunk->load(SYSTEM_PHRASE_INDEX);
+ phrase_table.load(chunk);
+
+ FacadePhraseIndex phrase_index;
+
+ const pinyin_table_info_t * phrase_files =
+ system_table_info.get_table_info();
+
+ if (!load_phrase_index(phrase_files, &phrase_index))
+ exit(ENOENT);
+
+ KMixtureModelBigram bigram(K_MIXTURE_MODEL_MAGIC_NUMBER);
+ bigram.attach(g_k_mixture_model_filename, ATTACH_READWRITE|ATTACH_CREATE);
+
+ while ( i < argc ){
+ const char * filename = argv[i];
+ FILE * document = fopen(filename, "r");
+ if ( NULL == document ){
+ int err_saved = errno;
+ fprintf(stderr, "can't open file: %s.\n", filename);
+ fprintf(stderr, "error:%s.\n", strerror(err_saved));
+ exit(err_saved);
+ }
+
+ HashofDocument hash_of_document = g_hash_table_new
+ (g_direct_hash, g_direct_equal);
+ HashofUnigram hash_of_unigram = g_hash_table_new
+ (g_direct_hash, g_direct_equal);
+
+ assert(read_document(&phrase_table, &phrase_index, document,
+ hash_of_document, hash_of_unigram));
+ fclose(document);
+ document = NULL;
+
+ GHashTableIter iter;
+ gpointer key, value;
+
+ /* train the document, and convert it to k mixture model. */
+ g_hash_table_iter_init(&iter, hash_of_document);
+ while (g_hash_table_iter_next(&iter, &key, &value)) {
+ phrase_token_t token1 = GPOINTER_TO_UINT(key);
+ train_second_word(hash_of_unigram, &bigram,
+ hash_of_document, token1);
+ }
+
+ KMixtureModelMagicHeader magic_header;
+ assert(bigram.get_magic_header(magic_header));
+ magic_header.m_N ++;
+ assert(bigram.set_magic_header(magic_header));
+
+ post_processing_unigram(&bigram, hash_of_unigram);
+
+ /* free resources of g_hash_of_document */
+ g_hash_table_iter_init(&iter, hash_of_document);
+ while (g_hash_table_iter_next(&iter, &key, &value)) {
+ HashofSecondWord second_word = (HashofSecondWord) value;
+ g_hash_table_iter_steal(&iter);
+ g_hash_table_unref(second_word);
+ }
+ g_hash_table_unref(hash_of_document);
+ hash_of_document = NULL;
+
+ g_hash_table_unref(hash_of_unigram);
+ hash_of_unigram = NULL;
+
+ ++i;
+ }
+
+ return 0;
+}