Skip to content

Commit a995374

Browse files
committed
Applied PR 8631
1 parent a9d2c26 commit a995374

File tree

7 files changed

+1049
-755
lines changed

7 files changed

+1049
-755
lines changed

go/cmd/pfdhcp/api.go

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package main
22

33
import (
4+
"context"
45
"database/sql"
56
"encoding/binary"
67
"encoding/json"
@@ -22,7 +23,8 @@ import (
2223
)
2324

2425
type API struct {
25-
DB *sql.DB
26+
DB *sql.DB
27+
Ctx context.Context
2628
}
2729

2830
// Node struct
@@ -136,7 +138,7 @@ func (a *API) handleAllStats(res http.ResponseWriter, req *http.Request) {
136138
}
137139
for _, i := range interfaces.Element {
138140
if h, ok := intNametoInterface[i]; ok {
139-
stat := h.handleAPIReq(APIReq{Req: "stats", NetInterface: i, NetWork: ""}, a.DB)
141+
stat := h.handleAPIReq(a.Ctx, APIReq{Req: "stats", NetInterface: i, NetWork: ""}, a.DB)
140142
for _, s := range stat.([]Stats) {
141143
result.Items = append(result.Items, s)
142144
}
@@ -159,7 +161,7 @@ func (a *API) handleStats(res http.ResponseWriter, req *http.Request) {
159161
vars := mux.Vars(req)
160162

161163
if h, ok := intNametoInterface[vars["int"]]; ok {
162-
stat := h.handleAPIReq(APIReq{Req: "stats", NetInterface: vars["int"], NetWork: vars["network"]}, a.DB)
164+
stat := h.handleAPIReq(a.Ctx, APIReq{Req: "stats", NetInterface: vars["int"], NetWork: vars["network"]}, a.DB)
163165

164166
outgoingJSON, err := json.Marshal(stat)
165167

@@ -180,7 +182,7 @@ func (a *API) handleDuplicates(res http.ResponseWriter, req *http.Request) {
180182
vars := mux.Vars(req)
181183

182184
if h, ok := intNametoInterface[vars["int"]]; ok {
183-
stat := h.handleAPIReq(APIReq{Req: "duplicates", NetInterface: vars["int"], NetWork: vars["network"]}, a.DB)
185+
stat := h.handleAPIReq(a.Ctx, APIReq{Req: "duplicates", NetInterface: vars["int"], NetWork: vars["network"]}, a.DB)
184186

185187
outgoingJSON, err := json.Marshal(stat)
186188

@@ -201,7 +203,7 @@ func (a *API) handleDebug(res http.ResponseWriter, req *http.Request) {
201203
vars := mux.Vars(req)
202204

203205
if h, ok := intNametoInterface[vars["int"]]; ok {
204-
stat := h.handleAPIReq(APIReq{Req: "debug", NetInterface: vars["int"], Role: vars["role"]}, a.DB)
206+
stat := h.handleAPIReq(a.Ctx, APIReq{Req: "debug", NetInterface: vars["int"], Role: vars["role"]}, a.DB)
205207

206208
outgoingJSON, err := json.Marshal(stat)
207209

@@ -219,14 +221,14 @@ func (a *API) handleDebug(res http.ResponseWriter, req *http.Request) {
219221

220222
func (a *API) handleReleaseIP(res http.ResponseWriter, req *http.Request) {
221223
vars := mux.Vars(req)
222-
_ = InterfaceScopeFromMac(vars["mac"])
224+
_ = InterfaceScopeFromMac(a.Ctx, vars["mac"])
223225

224226
var result = &Info{Mac: vars["mac"], Status: "ACK"}
225227

226228
res.Header().Set("Content-Type", "application/json; charset=UTF-8")
227229
res.WriteHeader(http.StatusOK)
228230
if err := json.NewEncoder(res).Encode(result); err != nil {
229-
log.LoggerWContext(ctx).Error("Error releasing IP: " + err.Error())
231+
log.LoggerWContext(a.Ctx).Error("Error releasing IP: " + err.Error() + " mac=" + vars["mac"])
230232
}
231233
}
232234

@@ -243,14 +245,14 @@ func (a *API) handleOverrideOptions(res http.ResponseWriter, req *http.Request)
243245
}
244246

245247
// Insert information in MySQL
246-
_ = MysqlInsert(vars["mac"], sharedutils.ConvertToString(body), a.DB)
248+
_ = MysqlInsert(a.Ctx, vars["mac"], sharedutils.ConvertToString(body), a.DB)
247249

248250
var result = &Info{Mac: vars["mac"], Status: "ACK"}
249251

250252
res.Header().Set("Content-Type", "application/json; charset=UTF-8")
251253
res.WriteHeader(http.StatusOK)
252254
if err := json.NewEncoder(res).Encode(result); err != nil {
253-
log.LoggerWContext(ctx).Error("Error adding MAC options: " + err.Error())
255+
log.LoggerWContext(a.Ctx).Error("Error adding MAC options: " + err.Error() + " mac=" + vars["mac"])
254256
}
255257
}
256258

@@ -267,7 +269,7 @@ func (a *API) handleOverrideNetworkOptions(res http.ResponseWriter, req *http.Re
267269
}
268270

269271
// Insert information in MySQL
270-
_ = MysqlInsert(vars["network"], sharedutils.ConvertToString(body), a.DB)
272+
_ = MysqlInsert(a.Ctx, vars["network"], sharedutils.ConvertToString(body), a.DB)
271273

272274
var result = &Info{Network: vars["network"], Status: "ACK"}
273275

@@ -291,7 +293,7 @@ func (a *API) handleRemoveOptions(res http.ResponseWriter, req *http.Request) {
291293
res.Header().Set("Content-Type", "application/json; charset=UTF-8")
292294
res.WriteHeader(http.StatusOK)
293295
if err := json.NewEncoder(res).Encode(result); err != nil {
294-
log.LoggerWContext(ctx).Error("Error removing MAC options: " + err.Error())
296+
log.LoggerWContext(ctx).Error("Error removing MAC options: " + err.Error() + " mac=" + vars["mac"])
295297
}
296298
}
297299

@@ -312,9 +314,9 @@ func (a *API) handleRemoveNetworkOptions(res http.ResponseWriter, req *http.Requ
312314
}
313315
}
314316

315-
func decodeOptions(b string, db *sql.DB) (map[dhcp.OptionCode][]byte, error) {
317+
func decodeOptions(ctx context.Context, b string, db *sql.DB) (map[dhcp.OptionCode][]byte, error) {
316318
var options []Options
317-
_, value := MysqlGet(b, db)
319+
_, value := MysqlGet(ctx, b, db)
318320
decodedValue := sharedutils.ConvertToByte(value)
319321
var dhcpOptions = make(map[dhcp.OptionCode][]byte)
320322
if err := json.Unmarshal(decodedValue, &options); err != nil {
@@ -364,7 +366,7 @@ func extractMembers(v Network) ([]Node, []string, int) {
364366
return Members, Macs, Count
365367
}
366368

367-
func (h *Interface) handleAPIReq(Request APIReq, db *sql.DB) interface{} {
369+
func (h *Interface) handleAPIReq(ctx context.Context, Request APIReq, db *sql.DB) interface{} {
368370
var stats []Stats
369371

370372
if Request.Req == "duplicates" {
@@ -403,7 +405,7 @@ func (h *Interface) handleAPIReq(Request APIReq, db *sql.DB) interface{} {
403405
}
404406

405407
// Add network options on the fly
406-
x, err := decodeOptions(v.network.IP.String(), db)
408+
x, err := decodeOptions(ctx, v.network.IP.String(), db)
407409
if err == nil {
408410
for key, value := range x {
409411
Options[key.String()] = Tlv.Tlvlist[int(key)].Transform.String(value)

go/cmd/pfdhcp/config.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package main
22

33
import (
4+
"context"
45
"database/sql"
56
"encoding/binary"
67
"math"
@@ -69,7 +70,7 @@ func newDHCPConfig() *Interfaces {
6970
return &p
7071
}
7172

72-
func (d *Interfaces) readConfig(MyDB *sql.DB) {
73+
func (d *Interfaces) readConfig(ctx context.Context, MyDB *sql.DB) {
7374
interfaces := pfconfigdriver.GetType[pfconfigdriver.ListenInts](ctx)
7475
DHCPinterfaces := pfconfigdriver.GetType[pfconfigdriver.DHCPInts](ctx)
7576
portal := pfconfigdriver.GetType[pfconfigdriver.PfConfCaptivePortal](ctx)
@@ -268,7 +269,7 @@ func (d *Interfaces) readConfig(MyDB *sql.DB) {
268269
DHCPScope.xid = xid
269270
wg.Add(1)
270271
go func() {
271-
initiaLease(DHCPScope, ConfNet, MyDB)
272+
initiaLease(ctx, DHCPScope, ConfNet, MyDB)
272273
wg.Done()
273274
}()
274275
var options = make(map[dhcp.OptionCode][]byte)
@@ -348,15 +349,15 @@ func (d *Interfaces) readConfig(MyDB *sql.DB) {
348349
DHCPScope.xid = xid
349350
wg.Add(1)
350351
go func() {
351-
initiaLease(DHCPScope, ConfNet, MyDB)
352+
initiaLease(ctx, DHCPScope, ConfNet, MyDB)
352353
wg.Done()
353354
}()
354355

355356
var options = make(map[dhcp.OptionCode][]byte)
356357

357358
options[dhcp.OptionSubnetMask] = []byte(net.ParseIP(ConfNet.Netmask).To4())
358-
options[dhcp.OptionDomainNameServer] = ShuffleDNS(ConfNet)
359-
options[dhcp.OptionRouter] = ShuffleGateway(ConfNet)
359+
options[dhcp.OptionDomainNameServer] = ShuffleDNS(ctx, ConfNet)
360+
options[dhcp.OptionRouter] = ShuffleGateway(ctx, ConfNet)
360361
options[dhcp.OptionDomainName] = []byte(ConfNet.DomainName)
361362
if portal.SecureRedirect == "enabled" {
362363
options[dhcp.OptionCaptivePortal] = []byte(detectPortalURL(ConfNet, general))

go/cmd/pfdhcp/keysoption.go

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,22 @@
11
package main
22

33
import (
4+
"context"
45
"database/sql"
6+
"time"
57

68
"github.com/inverse-inc/go-utils/log"
79
)
810

911
// MysqlInsert function
10-
func MysqlInsert(key string, value string, db *sql.DB) bool {
11-
if err := db.PingContext(ctx); err != nil {
12+
func MysqlInsert(ctx context.Context, key string, value string, db *sql.DB) bool {
13+
dbCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
14+
defer cancel()
15+
if err := db.PingContext(dbCtx); err != nil {
1216
log.LoggerWContext(ctx).Error("Unable to ping database, reconnect: " + err.Error())
1317
}
18+
_, err := db.ExecContext(dbCtx,
1419

15-
_, err := db.Exec(
1620
`
1721
INSERT into key_value_storage values(?,?)
1822
ON DUPLICATE KEY UPDATE value = VALUES(value)
@@ -30,11 +34,14 @@ ON DUPLICATE KEY UPDATE value = VALUES(value)
3034
}
3135

3236
// MysqlGet function
33-
func MysqlGet(key string, db *sql.DB) (string, string) {
34-
if err := db.PingContext(ctx); err != nil {
37+
func MysqlGet(ctx context.Context, key string, db *sql.DB) (string, string) {
38+
dbCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
39+
defer cancel()
40+
if err := db.PingContext(dbCtx); err != nil {
3541
log.LoggerWContext(ctx).Error("Unable to ping database, reconnect: " + err.Error())
3642
}
37-
rows, err := db.Query("select id, value from key_value_storage where id = ?", "/dhcpd/"+key)
43+
44+
rows, err := db.QueryContext(dbCtx, "select id, value from key_value_storage where id = ?", "/dhcpd/"+key)
3845
defer rows.Close()
3946
if err != nil {
4047
log.LoggerWContext(ctx).Debug("Error while getting MySQL '" + key + "': " + err.Error())
@@ -55,10 +62,12 @@ func MysqlGet(key string, db *sql.DB) (string, string) {
5562

5663
// MysqlDel function
5764
func MysqlDel(key string, db *sql.DB) bool {
58-
if err := db.PingContext(ctx); err != nil {
65+
dbCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
66+
defer cancel()
67+
if err := db.PingContext(dbCtx); err != nil {
5968
log.LoggerWContext(ctx).Error("Unable to ping database, reconnect: " + err.Error())
6069
}
61-
rows, err := db.Query("delete from key_value_storage where id = ?", "/dhcpd/"+key)
70+
rows, err := db.QueryContext(dbCtx, "delete from key_value_storage where id = ?", "/dhcpd/"+key)
6271
defer rows.Close()
6372
if err != nil {
6473
log.LoggerWContext(ctx).Error("Error while deleting MySQL key '" + key + "': " + err.Error())

0 commit comments

Comments
 (0)