21
21
stderr = sys .stderr
22
22
stdout = sys .stdout
23
23
24
+ from tempfile import mktemp
24
25
import timblapi
25
26
import io
26
27
import os
@@ -59,15 +60,19 @@ def u(s, encoding = 'utf-8', errors='strict'):
59
60
60
61
61
62
class TimblClassifier (object ):
62
- def __init__ (self , fileprefix , timbloptions , format = "Tabbed" , dist = True , encoding = 'utf-8' , overwrite = True , flushthreshold = 10000 , threading = False , normalize = True , debug = False ):
63
+ def __init__ (self , fileprefix , timbloptions , format = "Tabbed" , dist = True , encoding = 'utf-8' , overwrite = True , flushthreshold = 10000 , threading = False , normalize = True , debug = False , sklearn = False , flushdir = None ):
63
64
if format .lower () == "tabbed" :
64
65
self .format = "Tabbed"
65
66
self .delimiter = "\t "
66
67
elif format .lower () == "columns" :
67
68
self .format = "Columns"
68
69
self .delimiter = " "
70
+ elif format .lower () == 'sparse' : # for sparse arrays, e.g. scipy.sparse.csr
71
+ self .format = "Sparse"
72
+ self .delimiter = ""
69
73
else :
70
- raise ValueError ("Only Tabbed and Columns are supported input format for the python wrapper, not " + format )
74
+ raise ValueError ("Only Tabbed, Columns, and Sparse are supported input format for the python wrapper, not " + format )
75
+
71
76
self .timbloptions = timbloptions
72
77
self .fileprefix = fileprefix
73
78
@@ -80,11 +85,17 @@ def __init__(self, fileprefix, timbloptions, format = "Tabbed", dist=True, encod
80
85
self .instances = []
81
86
self .api = None
82
87
self .debug = debug
88
+ self .sklearn = sklearn
83
89
84
- if os .path .exists (self .fileprefix + ".train" ) and overwrite :
90
+ if sklearn :
91
+ import scipy as sp
92
+ self .flushfile = mktemp (prefix = self .fileprefix , dir = flushdir )
85
93
self .flushed = 0
86
94
else :
87
- self .flushed = 1
95
+ if os .path .exists (self .fileprefix + ".train" ) and overwrite :
96
+ self .flushed = 0
97
+ else :
98
+ self .flushed = 1
88
99
89
100
self .threading = threading
90
101
@@ -94,8 +105,10 @@ def validatefeatures(self,features):
94
105
for feature in features :
95
106
if isinstance (feature , int ) or isinstance (feature , float ):
96
107
validatedfeatures .append ( str (feature ) )
97
- elif self .delimiter in feature :
108
+ elif self .delimiter in feature and not self . sklearn :
98
109
raise ValueError ("Feature contains delimiter: " + feature )
110
+ elif self .sklearn and isinstance (feature , str ): #then is sparse added together
111
+ validatedfeatures .append (feature )
99
112
else :
100
113
validatedfeatures .append (feature )
101
114
return validatedfeatures
@@ -106,21 +119,24 @@ def append(self, features, classlabel):
106
119
107
120
features = self .validatefeatures (features )
108
121
109
- if self .delimiter in classlabel :
122
+ if self .delimiter in classlabel and self . delimiter != '' :
110
123
raise ValueError ("Class label contains delimiter: " + self .delimiter )
111
124
112
- self .instances .append (self .delimiter .join (features ) + self .delimiter + classlabel )
125
+ self .instances .append (self .delimiter .join (features ) + ( self .delimiter if not self . delimiter == '' else ' ' ) + classlabel )
113
126
if len (self .instances ) >= self .flushthreshold :
114
127
self .flush ()
115
128
116
129
def flush (self ):
117
130
if self .debug : print ("Flushing..." ,file = sys .stderr )
118
131
if len (self .instances ) == 0 : return False
119
132
120
- if self . flushed :
121
- f = io .open (self .fileprefix + ".train" , 'a ' , encoding = self .encoding )
133
+ if hasattr ( self , 'flushfile' ) :
134
+ f = io .open (self .flushfile , 'w ' , encoding = self .encoding )
122
135
else :
123
- f = io .open (self .fileprefix + ".train" ,'w' , encoding = self .encoding )
136
+ if self .flushed :
137
+ f = io .open (self .fileprefix + ".train" ,'a' , encoding = self .encoding )
138
+ else :
139
+ f = io .open (self .fileprefix + ".train" ,'w' , encoding = self .encoding )
124
140
125
141
for instance in self .instances :
126
142
f .write (instance + "\n " )
@@ -135,8 +151,18 @@ def __delete__(self):
135
151
136
152
def train (self , save = False ):
137
153
self .flush ()
138
- if not os .path .exists (self .fileprefix + ".train" ):
139
- raise LoadException ("Training file '" + self .fileprefix + ".train' not found. Did you forget to add instances with append()?" )
154
+
155
+ if hasattr (self , 'flushfile' ):
156
+ if not os .path .exists (self .flushfile ):
157
+ raise LoadException ("Training file '" + self .flushfile + "' not found. Did you forget to add instances with append()?" )
158
+ else :
159
+ filepath = self .flushfile
160
+ else :
161
+ if not os .path .exists (self .fileprefix + ".train" ):
162
+ raise LoadException ("Training file '" + self .fileprefix + ".train' not found. Did you forget to add instances with append()?" )
163
+ else :
164
+ filepath = self .fileprefix + '.train'
165
+
140
166
options = "-F " + self .format + " " + self .timbloptions
141
167
if self .dist :
142
168
options += " +v+db +v+di"
@@ -149,7 +175,7 @@ def train(self, save=False):
149
175
print ("Enabling debug for timblapi" ,file = stderr )
150
176
self .api .enableDebug ()
151
177
152
- trainfile = self . fileprefix + ".train"
178
+ trainfile = filepath
153
179
self .api .learn (b (trainfile ))
154
180
if save :
155
181
self .save ()
@@ -168,7 +194,8 @@ def classify(self, features, allowtopdistribution=True):
168
194
169
195
if not self .api :
170
196
self .load ()
171
- testinstance = self .delimiter .join (features ) + self .delimiter + "?"
197
+
198
+ testinstance = self .delimiter .join (features ) + (self .delimiter if not self .delimiter == '' else ' ' ) + "?"
172
199
if self .dist :
173
200
if self .threading :
174
201
result , cls , distribution , distance = self .api .classify3safe (b (testinstance ), self .normalize , int (not allowtopdistribution ))
@@ -347,8 +374,3 @@ def _parsedistribution(self, instance, start=0, end =None):
347
374
348
375
return dist
349
376
350
-
351
-
352
-
353
-
354
-
0 commit comments