summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorPeng Wu <pwu@redhat.com>2024-09-02 11:42:17 +0800
committerPeng Wu <pwu@redhat.com>2024-09-02 14:24:51 +0800
commit429f9fd8dd2966db9038a4194e5970ba22e8e5c0 (patch)
tree6367d7d4b7b735bb98d2be0446de854b52fce99f
parentd789eb1cd9185eac7f3f82f230b1ee4414aed68c (diff)
downloadtrainer-429f9fd8dd2966db9038a4194e5970ba22e8e5c0.tar.gz
trainer-429f9fd8dd2966db9038a4194e5970ba22e8e5c0.tar.xz
trainer-429f9fd8dd2966db9038a4194e5970ba22e8e5c0.zip
write genpunct.py
-rwxr-xr-xgenpunct.py89
1 files changed, 82 insertions, 7 deletions
diff --git a/genpunct.py b/genpunct.py
index 22dc5aa..f66b35d 100755
--- a/genpunct.py
+++ b/genpunct.py
@@ -2,6 +2,7 @@
import os
import os.path
from argparse import ArgumentParser
+from operator import itemgetter
import utils
from myconfig import MyConfig
from dirwalk import walkIndex
@@ -17,6 +18,8 @@ os.chdir(puncts_dir)
# The order is important
Punct_Search = ['……', '…', ',', '。', ';', '?', '!', ':', '“', '”', '、']
+all_punct_pairs = {}
+
############################################################
# Handle File #
############################################################
@@ -77,11 +80,10 @@ def handleOneText(infile, punct_pairs):
docfile.close()
-def prunePunctPairFromOneIndex(workdir, threshold):
+def prunePunctPair(workdir, threshold, infilename, outfilename):
punct_pairs = {}
#load the punct pairs from text files
- punctfile = os.path.join(workdir, \
- config.getPunctuationPerIndexFileName())
+ punctfile = os.path.join(workdir, infilename)
with open(punctfile, 'r') as f:
punct_pairs = eval(f.read())
@@ -98,8 +100,7 @@ def prunePunctPairFromOneIndex(workdir, threshold):
newpunctpairs[key] = newpuncts
#save the punct pairs to text files
- punctfile = os.path.join(workdir, \
- config.getPunctuationPruneIndexFileName())
+ punctfile = os.path.join(workdir, outfilename)
with open(punctfile, 'w') as f:
f.write(repr(newpunctpairs))
@@ -157,13 +158,72 @@ def handleOneIndex(indexpath, subdir, indexname):
loadPunctPairFromOneIndex(indexpath, workdir)
# Prune the pair in the current index
- prunePunctPairFromOneIndex(workdir, \
- config.getPunctuationPerIndexPruneThreshold())
+ prunePunctPair(workdir, \
+ config.getPunctuationPerIndexPruneThreshold(), \
+ config.getPunctuationPerIndexFileName(), \
+ config.getPunctuationPruneIndexFileName())
#sign epoch
utils.sign_epoch(indexstatus, 'Punctuation')
utils.store_status(indexstatuspath, indexstatus)
+def loadOnePrune(indexpath, subdir, indexname):
+ global all_punct_pairs
+ print(indexpath, subdir, indexname)
+
+ workdir = config.getGeneratePunctuationDir() + os.sep + \
+ subdir + os.sep + indexname
+ print(workdir)
+
+ # Load the word and punctuation pair
+ punct_pairs = {}
+ #load the punct pairs from text files
+ punctfile = os.path.join(workdir, \
+ config.getPunctuationPruneIndexFileName())
+ with open(punctfile, 'r') as f:
+ punct_pairs = eval(f.read())
+
+ #merge into all punct pairs
+ for key, puncts in punct_pairs.items():
+ if key not in all_punct_pairs:
+ all_punct_pairs[key] = puncts
+ continue
+
+ #combine the puncts
+ newpuncts = []
+ oldpuncts = all_punct_pairs[key]
+ keys = set()
+ for punct, freq in oldpuncts + puncts:
+ keys.add(punct)
+ for punctkey in keys:
+ #old freq
+ oldfreq = [freq for punct, freq in oldpuncts if punct == punctkey]
+ #print("old freq", oldfreq)
+ freq = sum(oldfreq)
+ #new freq
+ newfreq = [freq for punct, freq in puncts if punct == punctkey]
+ freq += sum(newfreq)
+ newpuncts.append([punctkey, freq])
+
+ all_punct_pairs[key] = newpuncts
+
+
+def exportAllPunctPairs(workdir, infilename, outfilename):
+ # Load the word and punctuation pair
+ punct_pairs = {}
+ #load the punct pairs from text files
+ punctfile = os.path.join(workdir, infilename)
+ with open(punctfile, 'r') as f:
+ punct_pairs = eval(f.read())
+
+ tablefile = open(outfilename, 'w')
+ for key, puncts in punct_pairs.items():
+ (token, word) = key
+ puncts.sort(key=itemgetter(1), reverse=True)
+ for punct, freq in puncts:
+ line = "{0} {1} {2} {3}".format(token, word, punct, freq)
+ tablefile.writelines([line, os.linesep])
+ tablefile.close()
if __name__ == '__main__':
parser = ArgumentParser(description='Generate punctuation.')
@@ -176,9 +236,24 @@ if __name__ == '__main__':
print(args)
walkIndex(handleOneIndex, args.indexdir)
# Merge the word and punctuation pairs in all the index
+ walkIndex(loadOnePrune, args.indexdir)
+
+ #save the punct pairs to text files
+ punctfile = os.path.join(config.getGeneratePunctuationDir(), \
+ config.getPunctuationAllIndexFileName())
+ with open(punctfile, 'w') as f:
+ f.write(repr(all_punct_pairs))
# Prune the pairs in all the index
+ prunePunctPair(config.getGeneratePunctuationDir(), \
+ config.getPunctuationAllIndexPruneThreshold(), \
+ config.getPunctuationAllIndexFileName(), \
+ config.getPunctuationPruneAllIndexFileName())
+
# Export all the remaining pairs
+ exportAllPunctPairs(config.getGeneratePunctuationDir(), \
+ config.getPunctuationPruneAllIndexFileName(), \
+ config.getPunctuationTextFileName())
print('done')