Skip to content

Commit 068a0b5

Browse files
committed
Merge branch 'master' of github.com:/proycon/python-timbl
2 parents 629fd8f + e3c8767 commit 068a0b5

File tree

3 files changed

+168
-19
lines changed

3 files changed

+168
-19
lines changed

README.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
:alt: Project Status: Active – The project has reached a stable, usable state and is being actively developed.
66
:target: https://www.repostatus.org/#active
77

8+
.. image:: https://zenodo.org/badge/8136669.svg
9+
:target: https://zenodo.org/badge/latestdoi/8136669
10+
811
======================
912
README: python-timbl
1013
======================
@@ -203,4 +206,6 @@ manually call the ``initthreading()`` method.
203206
Three TiMBL API methods print information to a standard C++ output stream object (ShowBestNeighbors, ShowOptions, ShowSettings, ShowSettings). In the Python interface, these methods will only work with Python (stream) objects that have a fileno method returning a valid file descriptor. Alternatively, three new methods are provided (bestNeighbo(u)rs, options, settings); these methods return the same information as a Python string object.
204207

205208

209+
**scikit-learn wrapper**
206210

211+
A wrapper for use in scikit-learn has been added. It was designed for use in scikit-learn Pipeline objects. The wrapper is not finished and has to date only been tested on sparse data. Note that TiMBL does not work well with large amounts of features. It is suggested to reduce the amount of features to a number below 100 to keep system performance reasonable. Use on servers with large amounts of memory and processing cores advised.

timbl.py

Lines changed: 41 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
stderr = sys.stderr
2222
stdout = sys.stdout
2323

24+
from tempfile import mktemp
2425
import timblapi
2526
import io
2627
import os
@@ -59,15 +60,19 @@ def u(s, encoding = 'utf-8', errors='strict'):
5960

6061

6162
class TimblClassifier(object):
62-
def __init__(self, fileprefix, timbloptions, format = "Tabbed", dist=True, encoding = 'utf-8', overwrite = True, flushthreshold=10000, threading=False, normalize=True, debug=False):
63+
def __init__(self, fileprefix, timbloptions, format = "Tabbed", dist=True, encoding = 'utf-8', overwrite = True, flushthreshold=10000, threading=False, normalize=True, debug=False, sklearn=False, flushdir=None):
6364
if format.lower() == "tabbed":
6465
self.format = "Tabbed"
6566
self.delimiter = "\t"
6667
elif format.lower() == "columns":
6768
self.format = "Columns"
6869
self.delimiter = " "
70+
elif format.lower() == 'sparse': # for sparse arrays, e.g. scipy.sparse.csr
71+
self.format = "Sparse"
72+
self.delimiter = ""
6973
else:
70-
raise ValueError("Only Tabbed and Columns are supported input format for the python wrapper, not " + format)
74+
raise ValueError("Only Tabbed, Columns, and Sparse are supported input format for the python wrapper, not " + format)
75+
7176
self.timbloptions = timbloptions
7277
self.fileprefix = fileprefix
7378

@@ -80,11 +85,17 @@ def __init__(self, fileprefix, timbloptions, format = "Tabbed", dist=True, encod
8085
self.instances = []
8186
self.api = None
8287
self.debug = debug
88+
self.sklearn = sklearn
8389

84-
if os.path.exists(self.fileprefix + ".train") and overwrite:
90+
if sklearn:
91+
import scipy as sp
92+
self.flushfile = mktemp(prefix=self.fileprefix, dir=flushdir)
8593
self.flushed = 0
8694
else:
87-
self.flushed = 1
95+
if os.path.exists(self.fileprefix + ".train") and overwrite:
96+
self.flushed = 0
97+
else:
98+
self.flushed = 1
8899

89100
self.threading = threading
90101

@@ -94,8 +105,10 @@ def validatefeatures(self,features):
94105
for feature in features:
95106
if isinstance(feature, int) or isinstance(feature, float):
96107
validatedfeatures.append( str(feature) )
97-
elif self.delimiter in feature:
108+
elif self.delimiter in feature and not self.sklearn:
98109
raise ValueError("Feature contains delimiter: " + feature)
110+
elif self.sklearn and isinstance(feature, str): #then is sparse added together
111+
validatedfeatures.append(feature)
99112
else:
100113
validatedfeatures.append(feature)
101114
return validatedfeatures
@@ -106,21 +119,24 @@ def append(self, features, classlabel):
106119

107120
features = self.validatefeatures(features)
108121

109-
if self.delimiter in classlabel:
122+
if self.delimiter in classlabel and self.delimiter != '':
110123
raise ValueError("Class label contains delimiter: " + self.delimiter)
111124

112-
self.instances.append(self.delimiter.join(features) + self.delimiter + classlabel)
125+
self.instances.append(self.delimiter.join(features) + (self.delimiter if not self.delimiter == '' else ' ') + classlabel)
113126
if len(self.instances) >= self.flushthreshold:
114127
self.flush()
115128

116129
def flush(self):
117130
if self.debug: print("Flushing...",file=sys.stderr)
118131
if len(self.instances) == 0: return False
119132

120-
if self.flushed:
121-
f = io.open(self.fileprefix + ".train",'a', encoding=self.encoding)
133+
if hasattr(self, 'flushfile'):
134+
f = io.open(self.flushfile,'w', encoding=self.encoding)
122135
else:
123-
f = io.open(self.fileprefix + ".train",'w', encoding=self.encoding)
136+
if self.flushed:
137+
f = io.open(self.fileprefix + ".train",'a', encoding=self.encoding)
138+
else:
139+
f = io.open(self.fileprefix + ".train",'w', encoding=self.encoding)
124140

125141
for instance in self.instances:
126142
f.write(instance + "\n")
@@ -135,8 +151,18 @@ def __delete__(self):
135151

136152
def train(self, save=False):
137153
self.flush()
138-
if not os.path.exists(self.fileprefix + ".train"):
139-
raise LoadException("Training file '"+self.fileprefix+".train' not found. Did you forget to add instances with append()?")
154+
155+
if hasattr(self, 'flushfile'):
156+
if not os.path.exists(self.flushfile):
157+
raise LoadException("Training file '"+self.flushfile+"' not found. Did you forget to add instances with append()?")
158+
else:
159+
filepath = self.flushfile
160+
else:
161+
if not os.path.exists(self.fileprefix + ".train"):
162+
raise LoadException("Training file '"+self.fileprefix+".train' not found. Did you forget to add instances with append()?")
163+
else:
164+
filepath = self.fileprefix + '.train'
165+
140166
options = "-F " + self.format + " " + self.timbloptions
141167
if self.dist:
142168
options += " +v+db +v+di"
@@ -149,7 +175,7 @@ def train(self, save=False):
149175
print("Enabling debug for timblapi",file=stderr)
150176
self.api.enableDebug()
151177

152-
trainfile = self.fileprefix + ".train"
178+
trainfile = filepath
153179
self.api.learn(b(trainfile))
154180
if save:
155181
self.save()
@@ -168,7 +194,8 @@ def classify(self, features, allowtopdistribution=True):
168194

169195
if not self.api:
170196
self.load()
171-
testinstance = self.delimiter.join(features) + self.delimiter + "?"
197+
198+
testinstance = self.delimiter.join(features) + (self.delimiter if not self.delimiter == '' else ' ') + "?"
172199
if self.dist:
173200
if self.threading:
174201
result, cls, distribution, distance = self.api.classify3safe(b(testinstance), self.normalize, int(not allowtopdistribution))
@@ -347,8 +374,3 @@ def _parsedistribution(self, instance, start=0, end =None):
347374

348375
return dist
349376

350-
351-
352-
353-
354-

utils.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
from sklearn.base import BaseEstimator, ClassifierMixin
2+
from sklearn.utils import check_X_y, check_array
3+
from timbl import TimblClassifier
4+
import scipy as sp
5+
import numpy as np
6+
7+
class skTiMBL(BaseEstimator, ClassifierMixin):
8+
def __init__(self, prefix='timbl', algorithm=4, dist_metric=None,
9+
k=1, normalize=False, debug=0, flushdir=None):
10+
self.prefix = prefix
11+
self.algorithm = algorithm
12+
self.dist_metric = dist_metric
13+
self.k = k
14+
self.normalize = normalize
15+
self.debug = debug
16+
self.flushdir = flushdir
17+
18+
19+
def _make_timbl_options(self, *options):
20+
"""
21+
-a algorithm
22+
-m metric
23+
-w weighting
24+
-k amount of neighbours
25+
-d class voting weights
26+
-L frequency threshold
27+
-T which feature index is label
28+
-N max number of features
29+
-H turn hashing on/off
30+
31+
This function still has to be made, for now the appropriate arguments
32+
can be passed in fit()
33+
"""
34+
pass
35+
36+
37+
def fit(self, X, y):
38+
X, y = check_X_y(X, y, dtype=np.int64, accept_sparse='csr')
39+
40+
n_rows = X.shape[0]
41+
self.classes_ = np.unique(y)
42+
43+
if sp.sparse.issparse(X):
44+
if self.debug: print('Features are sparse, choosing faster learning')
45+
46+
self.classifier = TimblClassifier(self.prefix, "-a{} -k{} -N{} -vf".format(self.algorithm,self.k, X.shape[1]),
47+
format='Sparse', debug=True, sklearn=True, flushdir=self.flushdir,
48+
flushthreshold=20000, normalize=self.normalize)
49+
50+
for i in range(n_rows):
51+
sparse = ['({},{})'.format(i+1, c) for i,c in zip(X[i].indices, X[i].data)]
52+
self.classifier.append(sparse,str(y[i]))
53+
54+
else:
55+
56+
self.classifier = TimblClassifier(self.prefix, "-a{} -k{} -N{} -vf".format(self.algorithm, self.k, X.shape[1]),
57+
debug=True, sklearn=True, flushdir=self.flushdir, flushthreshold=20000,
58+
normalize=self.normalize)
59+
60+
if y.dtype != 'O':
61+
y = y.astype(str)
62+
63+
for i in range(n_rows):
64+
self.classifier.append(list(X[i].toarray()[0]), y[i])
65+
66+
self.classifier.train()
67+
return self
68+
69+
70+
def _timbl_predictions(self, X, part_index, y=None):
71+
choices = {0 : lambda x : x.append(np.int64(label)),
72+
1 : lambda x : x.append([np.float(distance)]),
73+
}
74+
X = check_array(X, dtype=np.float64, accept_sparse='csr')
75+
76+
n_samples = X.shape[0]
77+
78+
pred = []
79+
func = choices[part_index]
80+
if sp.sparse.issparse(X):
81+
if self.debug: print('Features are sparse, choosing faster predictions')
82+
83+
for i in range(n_samples):
84+
sparse = ['({},{})'.format(i+1, c) for i,c in zip(X[i].indices, X[i].data)]
85+
label,proba, distance = self.classifier.classify(sparse)
86+
func(pred)
87+
88+
else:
89+
for i in range(n_samples):
90+
label,proba, distance = self.classifier.classify(list(X[i].toarray()[0]))
91+
func(pred)
92+
93+
return np.array(pred)
94+
95+
96+
97+
def predict(self, X, y=None):
98+
return self._timbl_predictions(X, part_index=0)
99+
100+
101+
def predict_proba(self, X, y=None):
102+
"""
103+
TIMBL is a discrete classifier. It cannot give probability estimations.
104+
To ensure that scikit-learn functions with TIMBL (and especially metrics
105+
such as ROC_AUC), this method is implemented.
106+
107+
For ROC_AUC, the classifier corresponds to a single point in ROC space,
108+
instead of a probabilistic continuum such as classifiers that can give
109+
a probability estimation (e.g. Linear classifiers). For an explanation,
110+
see Fawcett (2005).
111+
"""
112+
return predict(X)
113+
114+
115+
def decision_function(self, X, y=None):
116+
"""
117+
The decision function is interpreted here as being the distance between
118+
the instance that is being classified and the nearest point in k space.
119+
"""
120+
return self._timbl_predictions(X, part_index=1)
121+
122+

0 commit comments

Comments
 (0)