summaryrefslogtreecommitdiffstats
path: root/tryprune.py
diff options
context:
space:
mode:
authorPeng Wu <alexepico@gmail.com>2011-07-28 11:11:39 +0800
committerPeng Wu <alexepico@gmail.com>2011-07-28 11:11:39 +0800
commit9907d860802ea03319ca283f0e7f4737765c5513 (patch)
tree396292afa2aae9b523f9a3a1a5bcf773a1912eb0 /tryprune.py
parent499b18a799dd14b3ce584be0848d648400727776 (diff)
downloadtrainer-9907d860802ea03319ca283f0e7f4737765c5513.tar.gz
trainer-9907d860802ea03319ca283f0e7f4737765c5513.tar.xz
trainer-9907d860802ea03319ca283f0e7f4737765c5513.zip
add more outputs to tryprune.py
Diffstat (limited to 'tryprune.py')
-rwxr-xr-xtryprune.py20
1 files changed, 13 insertions, 7 deletions
diff --git a/tryprune.py b/tryprune.py
index 9a820e9..805f4dd 100755
--- a/tryprune.py
+++ b/tryprune.py
@@ -68,8 +68,6 @@ def convertModel(kmm_model, inter_model):
def mergeOneModel(mergedmodel, onemodel, score):
- #validate first
- validateModel(onemodel)
onemodelstatuspath = onemodel + config.getStatusPostfix()
onemodelstatus = utils.load_status(onemodelstatuspath)
@@ -107,14 +105,16 @@ def mergeSomeModels(mergedmodel, sortedindexname, mergenum):
raise AssertionError('scores must be descending.\n')
onemodel = os.path.join(config.getModelDir(), subdir, modelname)
+
+ #validate first
+ print('validating')
+ validateModel(onemodel)
+
mergeOneModel(mergedmodel, onemodel, score)
last_score = score
indexfile.close()
#end processing
- #validate merged model
- validateModel(mergedmodel)
-
def pruneModel(prunedmodel, k, CDF):
#begin processing
@@ -129,8 +129,6 @@ def pruneModel(prunedmodel, k, CDF):
sys.exit('Corrupted model found when pruning:' + modelfile)
#end processing
- #validate pruned model
- validateModel(prunedmodel)
if __name__ == '__main__':
parser = ArgumentParser(description='Try prune models.')
@@ -178,6 +176,10 @@ if __name__ == '__main__':
config.getSortedEstimateIndex())
mergeSomeModels(mergedmodel, sortedindexname, args.merge)
+ #validate merged model
+ print('validating')
+ validateModel(mergedmodel)
+
#export textual format
print('exporting')
exportfile = os.path.join(trydir, 'kmm_merged.text')
@@ -190,6 +192,10 @@ if __name__ == '__main__':
shutil.copyfile(mergedmodel, prunedmodel)
pruneModel(prunedmodel, args.k, args.CDF)
+ #validate pruned model
+ print('validating')
+ validateModel(prunedmodel)
+
#export textual format
print('exporting')
exportfile = os.path.join(trydir, 'kmm_pruned.text')