summary | refs | log | tree | commit | diff | stats
path: root/tryprune.py
diff options
context:
space:
mode:
author: Peng Wu <alexepico@gmail.com> 2011-07-25 11:54:05 +0800
committer: Peng Wu <alexepico@gmail.com> 2011-07-25 11:54:05 +0800
commit: 9cca4180400774231f64363c0f6acb87e5444f23 (patch)
tree: 02cd6c71e728157d87a54fa914678db249ebaa05 /tryprune.py
parent: 114342747eaf4e5719745c0cd113272e3c4c51ba (diff)
download: trainer-9cca4180400774231f64363c0f6acb87e5444f23.tar.gz
          trainer-9cca4180400774231f64363c0f6acb87e5444f23.tar.xz
          trainer-9cca4180400774231f64363c0f6acb87e5444f23.zip
write tryprune.py in progress
Diffstat (limited to 'tryprune.py')
-rw-r--r-- tryprune.py 58
1 files changed, 55 insertions, 3 deletions
diff --git a/tryprune.py b/tryprune.py
index ee0acbd..f86bd48 100644
--- a/tryprune.py
+++ b/tryprune.py
@@ -1,5 +1,7 @@
#!/usr/bin/python3
import os
+import os.path
+import shutil
import sys
from subprocess import Popen, PIPE
from argparse import ArgumentParser
@@ -45,7 +47,17 @@ def exportModel(modelfile, textmodel):
sys.exit('Corrupted model found when exporting:' + modelfile)
#end processing
-def mergeOneModel(mergedmodel, onemodel):
+def mergeOneModel(mergedmodel, onemodel, score):
+ #validate first
+ validateModel(onemodel)
+
+ onemodelstatuspath = onemodel + config.getStatusPostfix()
+ onemodelstatus = utils.load_status(onemodelstatuspath)
+ if not utils.check_epoch(onemodelstatus, 'Estimate'):
+ raise utils.Epoch('Please estimate first.\n')
+ if score != onemodelstatus['EstimateScore']:
+ raise AssertionError('estimate scores mis-match.\n')
+
#begin processing
cmdline = ['./merge_k_mixture_model', \
'--result-file', \
@@ -59,8 +71,28 @@ def mergeOneModel(mergedmodel, onemodel):
sys.exit('Corrupted model found when merging:' + onemodel)
#end processing
-def mergeSomeModels(indexfile, mergenum):
- pass
+def mergeSomeModels(tryname, mergedmodel, sortedindexname, mergenum):
+    """Merge the first `mergenum` models listed in the sorted index file.
+
+    Each index line is split on '#' into a sub-directory, a model name
+    and a score. Every model is handed to mergeOneModel() together with
+    its score, and the merged model is validated once all are merged.
+
+    tryname is accepted but not used inside this function.
+
+    Raises AssertionError when the index holds fewer than `mergenum`
+    lines, or when a score is greater than the one on the previous line.
+    """
+    #the first score is checked against 1.; every later score is
+    #checked against the score of the line before it
+    last_score = 1.
+    #begin processing
+    #'with' closes the index file even when a check below raises
+    with open(sortedindexname, 'r') as indexfile:
+        for _ in range(mergenum):
+            line = indexfile.readline()
+            if not line:
+                raise AssertionError('No more models.\n')
+            line = line.rstrip(os.linesep)
+            (subdir, modelname, score) = line.split('#', 2)
+            score = float(score)
+            if score > last_score:
+                raise AssertionError('score must be descending.\n')
+
+            onemodel = os.path.join(config.getModelDir(), subdir, modelname)
+            mergeOneModel(mergedmodel, onemodel, score)
+            last_score = score
+    #end processing
+
+    #validate merged model
+    validateModel(mergedmodel)
def pruneModel(modelfile, k, CDF):
#begin processing
@@ -98,3 +130,23 @@ if __name__ == '__main__':
args = parser.parse_args()
print(args)
+    tryname = 'try' + args.tryname
+    #all outputs of this try go under <final dir>/<tryname>
+    trydir = os.path.join(config.getFinalDir(), tryname)
+
+    #merge model candidates
+    mergedmodel = os.path.join(trydir, 'merged.db')
+    sortedindexname = os.path.join(args.modeldir, \
+                                       config.getSortedEstimateIndex())
+    mergeSomeModels(tryname, mergedmodel, sortedindexname, args.mergenumber)
+
+    #export textual format
+    exportfile = os.path.join(trydir, 'kmm_merged.text')
+    exportModel(mergedmodel, exportfile)
+
+    #prune merged model
+    prunedmodel = os.path.join(trydir, 'pruned.db')
+    #pruneModel is given a copy; merged.db itself is not passed to it
+    shutil.copyfile(mergedmodel, prunedmodel)
+    pruneModel(prunedmodel, args.k, args.CDF)
+
+    #export textual format
+    exportfile = os.path.join(trydir, 'kmm_pruned.text')
+    #second argument is the output path, not the exportModel function
+    exportModel(prunedmodel, exportfile)